llama_cpp 0.5.1 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -3
- data/examples/prompt_jp.txt +1 -1
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +32 -2
- data/ext/llama_cpp/src/ggml-alloc.c +6 -11
- data/ext/llama_cpp/src/ggml-cuda.cu +1108 -699
- data/ext/llama_cpp/src/ggml-metal.m +93 -24
- data/ext/llama_cpp/src/ggml-metal.metal +407 -174
- data/ext/llama_cpp/src/ggml-opencl.cpp +3 -3
- data/ext/llama_cpp/src/ggml.c +75 -43
- data/ext/llama_cpp/src/ggml.h +42 -32
- data/ext/llama_cpp/src/k_quants.c +4 -1
- data/ext/llama_cpp/src/llama.cpp +1040 -201
- data/ext/llama_cpp/src/llama.h +13 -7
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +4 -0
- metadata +2 -2
| @@ -38,7 +38,7 @@ kernel void kernel_add_row( | |
| 38 38 | 
             
                    device const float4 * src0,
         | 
| 39 39 | 
             
                    device const float4 * src1,
         | 
| 40 40 | 
             
                    device       float4 * dst,
         | 
| 41 | 
            -
                    constant | 
| 41 | 
            +
                    constant    int64_t & nb,
         | 
| 42 42 | 
             
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 43 43 | 
             
                dst[tpig] = src0[tpig] + src1[tpig % nb];
         | 
| 44 44 | 
             
            }
         | 
| @@ -63,18 +63,18 @@ kernel void kernel_mul_row( | |
| 63 63 | 
             
            }
         | 
| 64 64 |  | 
| 65 65 | 
             
            kernel void kernel_scale(
         | 
| 66 | 
            -
                    device const  | 
| 67 | 
            -
                    device        | 
| 66 | 
            +
                    device const float4 * src0,
         | 
| 67 | 
            +
                    device       float4 * dst,
         | 
| 68 68 | 
             
                    constant     float & scale,
         | 
| 69 69 | 
             
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 70 70 | 
             
                dst[tpig] = src0[tpig] * scale;
         | 
| 71 71 | 
             
            }
         | 
| 72 72 |  | 
| 73 73 | 
             
            kernel void kernel_silu(
         | 
| 74 | 
            -
                    device const  | 
| 75 | 
            -
                    device        | 
| 74 | 
            +
                    device const float4 * src0,
         | 
| 75 | 
            +
                    device       float4 * dst,
         | 
| 76 76 | 
             
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 77 | 
            -
                 | 
| 77 | 
            +
                device const float4 & x = src0[tpig];
         | 
| 78 78 | 
             
                dst[tpig] = x / (1.0f + exp(-x));
         | 
| 79 79 | 
             
            }
         | 
| 80 80 |  | 
| @@ -89,10 +89,10 @@ constant float GELU_COEF_A    = 0.044715f; | |
| 89 89 | 
             
            constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
         | 
| 90 90 |  | 
| 91 91 | 
             
            kernel void kernel_gelu(
         | 
| 92 | 
            -
                device const  | 
| 93 | 
            -
                device        | 
| 92 | 
            +
                device const float4 * src0,
         | 
| 93 | 
            +
                device       float4 * dst,
         | 
| 94 94 | 
             
                uint tpig[[thread_position_in_grid]]) {
         | 
| 95 | 
            -
                 | 
| 95 | 
            +
                device const float4 & x = src0[tpig];
         | 
| 96 96 |  | 
| 97 97 | 
             
                // BEWARE !!!
         | 
| 98 98 | 
             
                // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
         | 
| @@ -107,7 +107,6 @@ kernel void kernel_soft_max( | |
| 107 107 | 
             
                    constant   int64_t & ne00,
         | 
| 108 108 | 
             
                    constant   int64_t & ne01,
         | 
| 109 109 | 
             
                    constant   int64_t & ne02,
         | 
| 110 | 
            -
                    threadgroup float  * buf [[threadgroup(0)]],
         | 
| 111 110 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 112 111 | 
             
                    uint3 tpitg[[thread_position_in_threadgroup]],
         | 
| 113 112 | 
             
                    uint3   ntg[[threads_per_threadgroup]]) {
         | 
| @@ -119,61 +118,67 @@ kernel void kernel_soft_max( | |
| 119 118 | 
             
                device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
         | 
| 120 119 |  | 
| 121 120 | 
             
                // parallel max
         | 
| 122 | 
            -
                 | 
| 123 | 
            -
                for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         | 
| 124 | 
            -
                     | 
| 121 | 
            +
                float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
         | 
| 122 | 
            +
                for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
         | 
| 123 | 
            +
                    lmax = MAX(lmax, psrc0[i00]);
         | 
| 125 124 | 
             
                }
         | 
| 126 | 
            -
             | 
| 127 | 
            -
                // reduce
         | 
| 128 | 
            -
                threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 129 | 
            -
                for (uint i = ntg[0]/2; i > 0; i /= 2) {
         | 
| 130 | 
            -
                    if (tpitg[0] < i) {
         | 
| 131 | 
            -
                        buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
         | 
| 132 | 
            -
                    }
         | 
| 133 | 
            -
                    threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 134 | 
            -
                }
         | 
| 135 | 
            -
             | 
| 136 | 
            -
                //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
         | 
| 137 | 
            -
                //               the loop, and when that is done, buf[0] has the correct (synchronized) value
         | 
| 138 | 
            -
                //if (tpitg[0] == 0) {
         | 
| 139 | 
            -
                //    buf[0] = buf[0];
         | 
| 140 | 
            -
                //}
         | 
| 141 | 
            -
             | 
| 142 | 
            -
                //threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 143 | 
            -
             | 
| 144 | 
            -
                const float max = buf[0];
         | 
| 125 | 
            +
                const float max = simd_max(lmax);
         | 
| 145 126 |  | 
| 146 127 | 
             
                // parallel sum
         | 
| 147 | 
            -
                 | 
| 128 | 
            +
                float lsum = 0.0f;
         | 
| 148 129 | 
             
                for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         | 
| 149 130 | 
             
                    const float exp_psrc0 = exp(psrc0[i00] - max);
         | 
| 150 | 
            -
                     | 
| 131 | 
            +
                    lsum += exp_psrc0;
         | 
| 151 132 | 
             
                    // Remember the result of exp here. exp is expensive, so we really do not
         | 
| 152 133 | 
             
                    // wish to compute it twice.
         | 
| 153 134 | 
             
                    pdst[i00] = exp_psrc0;
         | 
| 154 135 | 
             
                }
         | 
| 155 136 |  | 
| 156 | 
            -
                 | 
| 157 | 
            -
             | 
| 158 | 
            -
                for ( | 
| 159 | 
            -
                     | 
| 160 | 
            -
             | 
| 161 | 
            -
             | 
| 162 | 
            -
             | 
| 137 | 
            +
                const float sum = simd_sum(lsum);
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         | 
| 140 | 
            +
                    pdst[i00] /= sum;
         | 
| 141 | 
            +
                }
         | 
| 142 | 
            +
            }
         | 
| 143 | 
            +
             | 
| 144 | 
            +
            kernel void kernel_soft_max_4(
         | 
| 145 | 
            +
                    device const float * src0,
         | 
| 146 | 
            +
                    device       float * dst,
         | 
| 147 | 
            +
                    constant   int64_t & ne00,
         | 
| 148 | 
            +
                    constant   int64_t & ne01,
         | 
| 149 | 
            +
                    constant   int64_t & ne02,
         | 
| 150 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 151 | 
            +
                    uint3 tpitg[[thread_position_in_threadgroup]],
         | 
| 152 | 
            +
                    uint3   ntg[[threads_per_threadgroup]]) {
         | 
| 153 | 
            +
                const int64_t i03 = tgpig[2];
         | 
| 154 | 
            +
                const int64_t i02 = tgpig[1];
         | 
| 155 | 
            +
                const int64_t i01 = tgpig[0];
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
         | 
| 158 | 
            +
                device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                // parallel max
         | 
| 161 | 
            +
                float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
         | 
| 162 | 
            +
                for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
         | 
| 163 | 
            +
                    lmax4 = fmax(lmax4, psrc4[i00]);
         | 
| 163 164 | 
             
                }
         | 
| 165 | 
            +
                float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
         | 
| 164 166 |  | 
| 165 | 
            -
                 | 
| 166 | 
            -
                //// broadcast
         | 
| 167 | 
            -
                //if (tpitg[0] == 0) {
         | 
| 168 | 
            -
                //    buf[0] = buf[0];
         | 
| 169 | 
            -
                //}
         | 
| 167 | 
            +
                const float max = simd_max(lmax);
         | 
| 170 168 |  | 
| 171 | 
            -
                // | 
| 169 | 
            +
                // parallel sum
         | 
| 170 | 
            +
                float4 lsum4 = 0.0f;
         | 
| 171 | 
            +
                for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
         | 
| 172 | 
            +
                    const float4 exp_psrc4 = exp(psrc4[i00] - max);
         | 
| 173 | 
            +
                    lsum4 += exp_psrc4;
         | 
| 174 | 
            +
                    pdst4[i00] = exp_psrc4;
         | 
| 175 | 
            +
                }
         | 
| 176 | 
            +
                float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
         | 
| 172 177 |  | 
| 173 | 
            -
                const float sum =  | 
| 178 | 
            +
                const float sum = simd_sum(lsum);
         | 
| 174 179 |  | 
| 175 | 
            -
                for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         | 
| 176 | 
            -
                     | 
| 180 | 
            +
                for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
         | 
| 181 | 
            +
                    pdst4[i00] /= sum;
         | 
| 177 182 | 
             
                }
         | 
| 178 183 | 
             
            }
         | 
| 179 184 |  | 
| @@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf( | |
| 192 197 | 
             
                    dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
         | 
| 193 198 | 
             
                } else {
         | 
| 194 199 | 
             
                    dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
         | 
| 200 | 
            +
                 }
         | 
| 201 | 
            +
            }
         | 
| 202 | 
            +
             | 
| 203 | 
            +
            kernel void kernel_diag_mask_inf_8(
         | 
| 204 | 
            +
                    device const float4 * src0,
         | 
| 205 | 
            +
                    device       float4 * dst,
         | 
| 206 | 
            +
                    constant    int64_t & ne00,
         | 
| 207 | 
            +
                    constant    int64_t & ne01,
         | 
| 208 | 
            +
                    constant        int & n_past,
         | 
| 209 | 
            +
                    uint3 tpig[[thread_position_in_grid]]) {
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                const int64_t i = 2*tpig[0];
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                dst[i+0] = src0[i+0];
         | 
| 214 | 
            +
                dst[i+1] = src0[i+1];
         | 
| 215 | 
            +
                int64_t i4 = 4*i;
         | 
| 216 | 
            +
                const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
         | 
| 217 | 
            +
                const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
         | 
| 218 | 
            +
                const int64_t i00 = i4;
         | 
| 219 | 
            +
                for (int k = 3; k >= 0; --k) {
         | 
| 220 | 
            +
                    if (i00 + 4 + k <= n_past + i01) {
         | 
| 221 | 
            +
                        break;
         | 
| 222 | 
            +
                    }
         | 
| 223 | 
            +
                    dst[i+1][k] = -INFINITY;
         | 
| 224 | 
            +
                    if (i00 + k > n_past + i01) {
         | 
| 225 | 
            +
                        dst[i][k] = -INFINITY;
         | 
| 226 | 
            +
                    }
         | 
| 195 227 | 
             
                }
         | 
| 196 228 | 
             
            }
         | 
| 197 229 |  | 
| @@ -491,6 +523,79 @@ kernel void kernel_mul_mat_q8_0_f32( | |
| 491 523 | 
             
                }
         | 
| 492 524 | 
             
            }
         | 
| 493 525 |  | 
| 526 | 
            +
            #define N_F32_F32 4
         | 
| 527 | 
            +
             | 
| 528 | 
            +
            kernel void kernel_mul_mat_f32_f32(
         | 
| 529 | 
            +
                    device const  char * src0,
         | 
| 530 | 
            +
                    device const  char * src1,
         | 
| 531 | 
            +
                    device       float * dst,
         | 
| 532 | 
            +
                    constant   int64_t & ne00,
         | 
| 533 | 
            +
                    constant   int64_t & ne01,
         | 
| 534 | 
            +
                    constant   int64_t & ne02,
         | 
| 535 | 
            +
                    constant  uint64_t & nb00,
         | 
| 536 | 
            +
                    constant  uint64_t & nb01,
         | 
| 537 | 
            +
                    constant  uint64_t & nb02,
         | 
| 538 | 
            +
                    constant   int64_t & ne10,
         | 
| 539 | 
            +
                    constant   int64_t & ne11,
         | 
| 540 | 
            +
                    constant   int64_t & ne12,
         | 
| 541 | 
            +
                    constant  uint64_t & nb10,
         | 
| 542 | 
            +
                    constant  uint64_t & nb11,
         | 
| 543 | 
            +
                    constant  uint64_t & nb12,
         | 
| 544 | 
            +
                    constant   int64_t & ne0,
         | 
| 545 | 
            +
                    constant   int64_t & ne1,
         | 
| 546 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 547 | 
            +
                    uint tiisg[[thread_index_in_simdgroup]]) {
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                const int64_t r0 = tgpig.x;
         | 
| 550 | 
            +
                const int64_t rb = tgpig.y*N_F32_F32;
         | 
| 551 | 
            +
                const int64_t im = tgpig.z;
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
         | 
| 554 | 
            +
             | 
| 555 | 
            +
                if (ne00 < 128) {
         | 
| 556 | 
            +
                    for (int row = 0; row < N_F32_F32; ++row) {
         | 
| 557 | 
            +
                        int r1 = rb + row;
         | 
| 558 | 
            +
                        if (r1 >= ne11) {
         | 
| 559 | 
            +
                            break;
         | 
| 560 | 
            +
                        }
         | 
| 561 | 
            +
             | 
| 562 | 
            +
                        device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
         | 
| 563 | 
            +
             | 
| 564 | 
            +
                        float sumf = 0;
         | 
| 565 | 
            +
                        for (int i = tiisg; i < ne00; i += 32) {
         | 
| 566 | 
            +
                            sumf += (float) x[i] * (float) y[i];
         | 
| 567 | 
            +
                        }
         | 
| 568 | 
            +
             | 
| 569 | 
            +
                        float all_sum = simd_sum(sumf);
         | 
| 570 | 
            +
                        if (tiisg == 0) {
         | 
| 571 | 
            +
                            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
         | 
| 572 | 
            +
                        }
         | 
| 573 | 
            +
                    }
         | 
| 574 | 
            +
                } else {
         | 
| 575 | 
            +
                    device const float4 * x4 = (device const float4 *)x;
         | 
| 576 | 
            +
                    for (int row = 0; row < N_F32_F32; ++row) {
         | 
| 577 | 
            +
                        int r1 = rb + row;
         | 
| 578 | 
            +
                        if (r1 >= ne11) {
         | 
| 579 | 
            +
                            break;
         | 
| 580 | 
            +
                        }
         | 
| 581 | 
            +
             | 
| 582 | 
            +
                        device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
         | 
| 583 | 
            +
                        device const float4 * y4 = (device const float4 *) y;
         | 
| 584 | 
            +
             | 
| 585 | 
            +
                        float sumf = 0;
         | 
| 586 | 
            +
                        for (int i = tiisg; i < ne00/4; i += 32) {
         | 
| 587 | 
            +
                            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
         | 
| 588 | 
            +
                        }
         | 
| 589 | 
            +
             | 
| 590 | 
            +
                        float all_sum = simd_sum(sumf);
         | 
| 591 | 
            +
                        if (tiisg == 0) {
         | 
| 592 | 
            +
                            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
         | 
| 593 | 
            +
                            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
         | 
| 594 | 
            +
                        }
         | 
| 595 | 
            +
                    }
         | 
| 596 | 
            +
                }
         | 
| 597 | 
            +
            }
         | 
| 598 | 
            +
             | 
| 494 599 | 
             
            kernel void kernel_mul_mat_f16_f32_1row(
         | 
| 495 600 | 
             
                    device const  char * src0,
         | 
| 496 601 | 
             
                    device const  char * src1,
         | 
| @@ -616,6 +721,49 @@ kernel void kernel_mul_mat_f16_f32( | |
| 616 721 | 
             
                }
         | 
| 617 722 | 
             
            }
         | 
| 618 723 |  | 
| 724 | 
            +
            // Assumes row size (ne00) is a multiple of 4
         | 
| 725 | 
            +
            kernel void kernel_mul_mat_f16_f32_l4(
         | 
| 726 | 
            +
                    device const  char * src0,
         | 
| 727 | 
            +
                    device const  char * src1,
         | 
| 728 | 
            +
                    device       float * dst,
         | 
| 729 | 
            +
                    constant   int64_t & ne00,
         | 
| 730 | 
            +
                    constant   int64_t & ne01,
         | 
| 731 | 
            +
                    constant   int64_t & ne02,
         | 
| 732 | 
            +
                    constant  uint64_t & nb00,
         | 
| 733 | 
            +
                    constant  uint64_t & nb01,
         | 
| 734 | 
            +
                    constant  uint64_t & nb02,
         | 
| 735 | 
            +
                    constant   int64_t & ne10,
         | 
| 736 | 
            +
                    constant   int64_t & ne11,
         | 
| 737 | 
            +
                    constant   int64_t & ne12,
         | 
| 738 | 
            +
                    constant  uint64_t & nb10,
         | 
| 739 | 
            +
                    constant  uint64_t & nb11,
         | 
| 740 | 
            +
                    constant  uint64_t & nb12,
         | 
| 741 | 
            +
                    constant   int64_t & ne0,
         | 
| 742 | 
            +
                    constant   int64_t & ne1,
         | 
| 743 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 744 | 
            +
                    uint tiisg[[thread_index_in_simdgroup]]) {
         | 
| 745 | 
            +
             | 
| 746 | 
            +
                const int nrows = ne11;
         | 
| 747 | 
            +
                const int64_t r0 = tgpig.x;
         | 
| 748 | 
            +
                const int64_t im = tgpig.z;
         | 
| 749 | 
            +
             | 
| 750 | 
            +
                device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
         | 
| 751 | 
            +
             | 
| 752 | 
            +
                for (int r1 = 0; r1 < nrows; ++r1) {
         | 
| 753 | 
            +
                    device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
         | 
| 754 | 
            +
             | 
| 755 | 
            +
                    float sumf = 0;
         | 
| 756 | 
            +
                    for (int i = tiisg; i < ne00/4; i += 32) {
         | 
| 757 | 
            +
                        for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
         | 
| 758 | 
            +
                    }
         | 
| 759 | 
            +
             | 
| 760 | 
            +
                    float all_sum = simd_sum(sumf);
         | 
| 761 | 
            +
                    if (tiisg == 0) {
         | 
| 762 | 
            +
                        dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
         | 
| 763 | 
            +
                    }
         | 
| 764 | 
            +
                }
         | 
| 765 | 
            +
            }
         | 
| 766 | 
            +
             | 
| 619 767 | 
             
            kernel void kernel_alibi_f32(
         | 
| 620 768 | 
             
                    device const float * src0,
         | 
| 621 769 | 
             
                    device       float * dst,
         | 
| @@ -1123,31 +1271,40 @@ kernel void kernel_mul_mat_q3_K_f32( | |
| 1123 1271 | 
             
                device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
         | 
| 1124 1272 | 
             
                device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
         | 
| 1125 1273 |  | 
| 1126 | 
            -
                float yl[ | 
| 1274 | 
            +
                float yl[32];
         | 
| 1127 1275 |  | 
| 1128 | 
            -
                const uint16_t kmask1 =  | 
| 1276 | 
            +
                const uint16_t kmask1 = 0x3030;
         | 
| 1129 1277 | 
             
                const uint16_t kmask2 = 0x0f0f;
         | 
| 1130 1278 |  | 
| 1131 | 
            -
                const int tid = tiisg/ | 
| 1132 | 
            -
                const int ix  = tiisg% | 
| 1133 | 
            -
                const int ip  = tid/ | 
| 1134 | 
            -
                const int il  = tid/2 | 
| 1279 | 
            +
                const int tid = tiisg/4;
         | 
| 1280 | 
            +
                const int ix  = tiisg%4;
         | 
| 1281 | 
            +
                const int ip  = tid/4;          // 0 or 1
         | 
| 1282 | 
            +
                const int il  = 2*((tid%4)/2);  // 0 or 2
         | 
| 1135 1283 | 
             
                const int ir  = tid%2;
         | 
| 1136 1284 | 
             
                const int n   = 8;
         | 
| 1137 1285 | 
             
                const int l0  = n*ir;
         | 
| 1138 1286 |  | 
| 1139 | 
            -
                 | 
| 1140 | 
            -
                 | 
| 1287 | 
            +
                // One would think that the Metal compiler would figure out that ip and il can only have
         | 
| 1288 | 
            +
                // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
         | 
| 1289 | 
            +
                // with these two tables.
         | 
| 1290 | 
            +
                //
         | 
| 1291 | 
            +
                // Possible masks for the high bit
         | 
| 1292 | 
            +
                const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
         | 
| 1293 | 
            +
                                       {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
         | 
| 1294 | 
            +
                                       {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
         | 
| 1295 | 
            +
                                       {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
         | 
| 1296 | 
            +
             | 
| 1297 | 
            +
                // Possible masks for the low 2 bits
         | 
| 1298 | 
            +
                const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
         | 
| 1299 | 
            +
             | 
| 1300 | 
            +
                const ushort4 hm = mm[2*ip + il/2];
         | 
| 1141 1301 |  | 
| 1142 1302 | 
             
                const int shift = 2*il;
         | 
| 1143 | 
            -
                const  | 
| 1144 | 
            -
                const  | 
| 1145 | 
            -
                const int32_t v1 = 4 << shift;
         | 
| 1146 | 
            -
                const int32_t v2 = 1024 << shift;
         | 
| 1303 | 
            +
                const float    v1 = il == 0 ? 4.f : 64.f;
         | 
| 1304 | 
            +
                const float    v2 = 4.f * v1;
         | 
| 1147 1305 |  | 
| 1148 1306 | 
             
                const uint16_t s_shift1 = 4*ip;
         | 
| 1149 | 
            -
                const uint16_t s_shift2 = s_shift1 +  | 
| 1150 | 
            -
                const int ik = 4 + (il%2);
         | 
| 1307 | 
            +
                const uint16_t s_shift2 = s_shift1 + il;
         | 
| 1151 1308 |  | 
| 1152 1309 | 
             
                const int q_offset = 32*ip + l0;
         | 
| 1153 1310 | 
             
                const int y_offset = 128*ip + 32*il + l0;
         | 
| @@ -1156,12 +1313,19 @@ kernel void kernel_mul_mat_q3_K_f32( | |
| 1156 1313 |  | 
| 1157 1314 | 
             
                device const float * y1 = yy + ix*QK_K + y_offset;
         | 
| 1158 1315 |  | 
| 1159 | 
            -
                 | 
| 1160 | 
            -
                 | 
| 1316 | 
            +
                uint32_t scales32, aux32;
         | 
| 1317 | 
            +
                thread uint16_t * scales16 = (thread uint16_t *)&scales32;
         | 
| 1318 | 
            +
                thread const int8_t * scales = (thread const int8_t *)&scales32;
         | 
| 1319 | 
            +
             | 
| 1320 | 
            +
                float sumf1[2] = {0.f};
         | 
| 1321 | 
            +
                float sumf2[2] = {0.f};
         | 
| 1322 | 
            +
                for (int i = ix; i < nb; i += 4) {
         | 
| 1161 1323 |  | 
| 1162 1324 | 
             
                    for (int l = 0; l < 8; ++l) {
         | 
| 1163 | 
            -
                        yl[l+0] = y1[l+ 0];
         | 
| 1164 | 
            -
                        yl[l+8] = y1[l+16];
         | 
| 1325 | 
            +
                        yl[l+ 0] = y1[l+ 0];
         | 
| 1326 | 
            +
                        yl[l+ 8] = y1[l+16];
         | 
| 1327 | 
            +
                        yl[l+16] = y1[l+32];
         | 
| 1328 | 
            +
                        yl[l+24] = y1[l+48];
         | 
| 1165 1329 | 
             
                    }
         | 
| 1166 1330 |  | 
| 1167 1331 | 
             
                    device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
         | 
| @@ -1172,27 +1336,43 @@ kernel void kernel_mul_mat_q3_K_f32( | |
| 1172 1336 | 
             
                    for (int row = 0; row < 2; ++row) {
         | 
| 1173 1337 |  | 
| 1174 1338 | 
             
                        const float d_all = (float)dh[0];
         | 
| 1175 | 
            -
                        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
         | 
| 1176 1339 |  | 
| 1177 | 
            -
                         | 
| 1340 | 
            +
                        scales16[0] = a[4];
         | 
| 1341 | 
            +
                        scales16[1] = a[5];
         | 
| 1342 | 
            +
                        aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
         | 
| 1343 | 
            +
                        scales16[0] = a[il+0];
         | 
| 1344 | 
            +
                        scales16[1] = a[il+1];
         | 
| 1345 | 
            +
                        scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
         | 
| 1346 | 
            +
             | 
| 1347 | 
            +
                        float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
         | 
| 1178 1348 | 
             
                        for (int l = 0; l < n; l += 2) {
         | 
| 1179 | 
            -
                            const  | 
| 1180 | 
            -
                            s1 += yl[l+0] * ( | 
| 1181 | 
            -
                            s2 += yl[l+1] * ( | 
| 1349 | 
            +
                            const int32_t qs = q[l/2];
         | 
| 1350 | 
            +
                            s1 += yl[l+0] * (qs & qm[il/2][0]);
         | 
| 1351 | 
            +
                            s2 += yl[l+1] * (qs & qm[il/2][1]);
         | 
| 1352 | 
            +
                            s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
         | 
| 1353 | 
            +
                            s4 += yl[l+16] * (qs & qm[il/2][2]);
         | 
| 1354 | 
            +
                            s5 += yl[l+17] * (qs & qm[il/2][3]);
         | 
| 1355 | 
            +
                            s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
         | 
| 1182 1356 | 
             
                        }
         | 
| 1183 | 
            -
                        float  | 
| 1184 | 
            -
                         | 
| 1185 | 
            -
                         | 
| 1357 | 
            +
                        float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
         | 
| 1358 | 
            +
                        float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
         | 
| 1359 | 
            +
                        sumf1[row] += d1 * (scales[0] - 32);
         | 
| 1360 | 
            +
                        sumf2[row] += d2 * (scales[2] - 32);
         | 
| 1186 1361 |  | 
| 1187 | 
            -
                        s1 = s2 = 0;
         | 
| 1362 | 
            +
                        s1 = s2 = s3 = s4 = s5 = s6 = 0;
         | 
| 1188 1363 | 
             
                        for (int l = 0; l < n; l += 2) {
         | 
| 1189 | 
            -
                            const  | 
| 1190 | 
            -
                            s1 += yl[l+8] * ( | 
| 1191 | 
            -
                            s2 += yl[l+9] * ( | 
| 1364 | 
            +
                            const int32_t qs = q[l/2+8];
         | 
| 1365 | 
            +
                            s1 += yl[l+8] * (qs & qm[il/2][0]);
         | 
| 1366 | 
            +
                            s2 += yl[l+9] * (qs & qm[il/2][1]);
         | 
| 1367 | 
            +
                            s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
         | 
| 1368 | 
            +
                            s4 += yl[l+24] * (qs & qm[il/2][2]);
         | 
| 1369 | 
            +
                            s5 += yl[l+25] * (qs & qm[il/2][3]);
         | 
| 1370 | 
            +
                            s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
         | 
| 1192 1371 | 
             
                        }
         | 
| 1193 | 
            -
                         | 
| 1194 | 
            -
                         | 
| 1195 | 
            -
                         | 
| 1372 | 
            +
                        d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
         | 
| 1373 | 
            +
                        d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
         | 
| 1374 | 
            +
                        sumf1[row] += d1 * (scales[1] - 32);
         | 
| 1375 | 
            +
                        sumf2[row] += d2 * (scales[3] - 32);
         | 
| 1196 1376 |  | 
| 1197 1377 | 
             
                        q  += step;
         | 
| 1198 1378 | 
             
                        h  += step;
         | 
| @@ -1201,15 +1381,17 @@ kernel void kernel_mul_mat_q3_K_f32( | |
| 1201 1381 |  | 
| 1202 1382 | 
             
                    }
         | 
| 1203 1383 |  | 
| 1204 | 
            -
                    y1 +=  | 
| 1384 | 
            +
                    y1 += 4 * QK_K;
         | 
| 1205 1385 |  | 
| 1206 1386 | 
             
                }
         | 
| 1207 1387 |  | 
| 1208 1388 | 
             
                for (int row = 0; row < 2; ++row) {
         | 
| 1209 | 
            -
                    const float sumf = (sumf1[row]  | 
| 1210 | 
            -
                     | 
| 1211 | 
            -
             | 
| 1212 | 
            -
             | 
| 1389 | 
            +
                    const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
         | 
| 1390 | 
            +
                    sumf1[row] = simd_sum(sumf);
         | 
| 1391 | 
            +
                }
         | 
| 1392 | 
            +
                if (tiisg == 0) {
         | 
| 1393 | 
            +
                    for (int row = 0; row < 2; ++row) {
         | 
| 1394 | 
            +
                        dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         | 
| 1213 1395 | 
             
                    }
         | 
| 1214 1396 | 
             
                }
         | 
| 1215 1397 | 
             
            }
         | 
| @@ -1290,13 +1472,13 @@ kernel void kernel_mul_mat_q4_K_f32( | |
| 1290 1472 | 
             
                    device const float * src1,
         | 
| 1291 1473 | 
             
                    device       float * dst,
         | 
| 1292 1474 | 
             
                    constant   int64_t & ne00,
         | 
| 1293 | 
            -
                    constant   int64_t & ne01[[buffer(4)]],
         | 
| 1294 | 
            -
                    constant   int64_t & ne02[[buffer(5)]],
         | 
| 1295 | 
            -
                    constant   int64_t & ne10[[buffer(9)]],
         | 
| 1296 | 
            -
                    constant   int64_t & ne12[[buffer(11)]],
         | 
| 1297 | 
            -
                    constant   int64_t & ne0[[buffer(15)]],
         | 
| 1298 | 
            -
                    constant   int64_t & ne1[[buffer(16)]],
         | 
| 1299 | 
            -
                    constant   uint    & gqa[[buffer(17)]],
         | 
| 1475 | 
            +
                    constant   int64_t & ne01 [[buffer(4)]],
         | 
| 1476 | 
            +
                    constant   int64_t & ne02 [[buffer(5)]],
         | 
| 1477 | 
            +
                    constant   int64_t & ne10 [[buffer(9)]],
         | 
| 1478 | 
            +
                    constant   int64_t & ne12 [[buffer(11)]],
         | 
| 1479 | 
            +
                    constant   int64_t & ne0  [[buffer(15)]],
         | 
| 1480 | 
            +
                    constant   int64_t & ne1  [[buffer(16)]],
         | 
| 1481 | 
            +
                    constant   uint    & gqa  [[buffer(17)]],
         | 
| 1300 1482 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 1301 1483 | 
             
                    uint tiisg[[thread_index_in_simdgroup]],
         | 
| 1302 1484 | 
             
                    uint sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| @@ -1564,17 +1746,25 @@ kernel void kernel_mul_mat_q5_K_f32( | |
| 1564 1746 | 
             
                        sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
         | 
| 1565 1747 | 
             
                        sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
         | 
| 1566 1748 |  | 
| 1567 | 
            -
                        float4  | 
| 1749 | 
            +
                        float4 acc1 = {0.f};
         | 
| 1750 | 
            +
                        float4 acc2 = {0.f};
         | 
| 1568 1751 | 
             
                        for (int l = 0; l < n; ++l) {
         | 
| 1569 1752 | 
             
                            uint8_t h = qh[l];
         | 
| 1570 | 
            -
                             | 
| 1571 | 
            -
                             | 
| 1572 | 
            -
                             | 
| 1573 | 
            -
                             | 
| 1753 | 
            +
                            acc1[0] += yl[l+0] * (q1[l] & 0x0F);
         | 
| 1754 | 
            +
                            acc1[1] += yl[l+8] * (q1[l] & 0xF0);
         | 
| 1755 | 
            +
                            acc1[2] += yh[l+0] * (q2[l] & 0x0F);
         | 
| 1756 | 
            +
                            acc1[3] += yh[l+8] * (q2[l] & 0xF0);
         | 
| 1757 | 
            +
                            acc2[0] += h & hm1 ? yl[l+0] : 0.f;
         | 
| 1758 | 
            +
                            acc2[1] += h & hm2 ? yl[l+8] : 0.f;
         | 
| 1759 | 
            +
                            acc2[2] += h & hm3 ? yh[l+0] : 0.f;
         | 
| 1760 | 
            +
                            acc2[3] += h & hm4 ? yh[l+8] : 0.f;
         | 
| 1574 1761 | 
             
                        }
         | 
| 1575 1762 | 
             
                        const float dall = dh[0];
         | 
| 1576 1763 | 
             
                        const float dmin = dh[1];
         | 
| 1577 | 
            -
                        sumf[row] += dall * ( | 
| 1764 | 
            +
                        sumf[row] += dall * (sc8[0] * (acc1[0] +  16.f*acc2[0]) +
         | 
| 1765 | 
            +
                                             sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
         | 
| 1766 | 
            +
                                             sc8[4] * (acc1[2] +  16.f*acc2[2]) +
         | 
| 1767 | 
            +
                                             sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
         | 
| 1578 1768 | 
             
                                     dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
         | 
| 1579 1769 |  | 
| 1580 1770 | 
             
                        q1 += step;
         | 
| @@ -1747,6 +1937,15 @@ kernel void kernel_mul_mat_q6_K_f32( | |
| 1747 1937 |  | 
| 1748 1938 | 
             
            //============================= templates and their specializations =============================
         | 
| 1749 1939 |  | 
| 1940 | 
            +
            // NOTE: this is not dequantizing - we are simply fitting the template
         | 
| 1941 | 
            +
            template <typename type4x4>
         | 
| 1942 | 
            +
            void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
         | 
| 1943 | 
            +
                float4x4 temp = *(((device float4x4 *)src));
         | 
| 1944 | 
            +
                for (int i = 0; i < 16; i++){
         | 
| 1945 | 
            +
                    reg[i/4][i%4] = temp[i/4][i%4];
         | 
| 1946 | 
            +
                }
         | 
| 1947 | 
            +
            }
         | 
| 1948 | 
            +
             | 
| 1750 1949 | 
             
            template <typename type4x4>
         | 
| 1751 1950 | 
             
            void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
         | 
| 1752 1951 | 
             
                half4x4 temp = *(((device half4x4 *)src));
         | 
| @@ -1758,28 +1957,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) | |
| 1758 1957 | 
             
            template <typename type4x4>
         | 
| 1759 1958 | 
             
            void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
         | 
| 1760 1959 | 
             
                device const uint16_t * qs = ((device const uint16_t *)xb + 1);
         | 
| 1761 | 
            -
                const  | 
| 1762 | 
            -
                const  | 
| 1960 | 
            +
                const float d1 = il ? (xb->d / 16.h) : xb->d;
         | 
| 1961 | 
            +
                const float d2 = d1 / 256.f;
         | 
| 1962 | 
            +
                const float md = -8.h * xb->d;
         | 
| 1763 1963 | 
             
                const ushort mask0 = il ? 0x00F0 : 0x000F;
         | 
| 1764 | 
            -
                const ushort mask1 =  | 
| 1964 | 
            +
                const ushort mask1 = mask0 << 8;
         | 
| 1765 1965 |  | 
| 1766 1966 | 
             
                for (int i=0;i<8;i++) {
         | 
| 1767 | 
            -
                    reg[i/2][2*(i%2)] | 
| 1768 | 
            -
                    reg[i/2][2*(i%2)+1] = ( | 
| 1967 | 
            +
                    reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
         | 
| 1968 | 
            +
                    reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
         | 
| 1769 1969 | 
             
                }
         | 
| 1770 1970 | 
             
            }
         | 
| 1771 1971 |  | 
| 1772 1972 | 
             
            template <typename type4x4>
         | 
| 1773 1973 | 
             
            void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
         | 
| 1774 1974 | 
             
                device const uint16_t * qs = ((device const uint16_t *)xb + 2);
         | 
| 1775 | 
            -
                const  | 
| 1776 | 
            -
                const  | 
| 1975 | 
            +
                const float d1 = il ? (xb->d / 16.h) : xb->d;
         | 
| 1976 | 
            +
                const float d2 = d1 / 256.f;
         | 
| 1977 | 
            +
                const float  m = xb->m;
         | 
| 1777 1978 | 
             
                const ushort mask0 = il ? 0x00F0 : 0x000F;
         | 
| 1778 | 
            -
                const ushort mask1 =  | 
| 1979 | 
            +
                const ushort mask1 = mask0 << 8;
         | 
| 1779 1980 |  | 
| 1780 1981 | 
             
                for (int i=0;i<8;i++) {
         | 
| 1781 | 
            -
                    reg[i/2][2*(i%2)] | 
| 1782 | 
            -
                    reg[i/2][2*(i%2)+1] = (( | 
| 1982 | 
            +
                    reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
         | 
| 1983 | 
            +
                    reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
         | 
| 1783 1984 | 
             
                }
         | 
| 1784 1985 | 
             
            }
         | 
| 1785 1986 |  | 
| @@ -1815,7 +2016,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg | |
| 1815 2016 |  | 
| 1816 2017 | 
             
            template <typename type4x4>
         | 
| 1817 2018 | 
             
            void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
         | 
| 1818 | 
            -
                const  | 
| 2019 | 
            +
                const half d_all = xb->d;
         | 
| 1819 2020 | 
             
                device const uint8_t * q = (device const uint8_t *)xb->qs;
         | 
| 1820 2021 | 
             
                device const uint8_t * h = (device const uint8_t *)xb->hmask;
         | 
| 1821 2022 | 
             
                device const int8_t * scales = (device const int8_t *)xb->scales;
         | 
| @@ -1828,16 +2029,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg | |
| 1828 2029 | 
             
                                             ((il/4)>0 ? 12  : 3);
         | 
| 1829 2030 | 
             
                uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
         | 
| 1830 2031 | 
             
                uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
         | 
| 1831 | 
            -
                int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) | 
| 1832 | 
            -
             | 
| 1833 | 
            -
                 | 
| 2032 | 
            +
                int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
         | 
| 2033 | 
            +
                                           : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
         | 
| 2034 | 
            +
                half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
         | 
| 2035 | 
            +
                const half ml = 4.h * dl;
         | 
| 1834 2036 |  | 
| 1835 | 
            -
                il = (il/2) | 
| 1836 | 
            -
                 | 
| 1837 | 
            -
                uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
         | 
| 2037 | 
            +
                il = (il/2) & 3;
         | 
| 2038 | 
            +
                const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
         | 
| 2039 | 
            +
                const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
         | 
| 2040 | 
            +
                dl *= coef;
         | 
| 1838 2041 |  | 
| 1839 2042 | 
             
                for (int i = 0; i < 16; ++i) {
         | 
| 1840 | 
            -
                    reg[i/4][i%4] =  | 
| 2043 | 
            +
                    reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
         | 
| 1841 2044 | 
             
                }
         | 
| 1842 2045 | 
             
            #else
         | 
| 1843 2046 | 
             
                float    kcoef = il&1 ? 1.f/16.f : 1.f;
         | 
| @@ -1852,26 +2055,31 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg | |
| 1852 2055 | 
             
            #endif
         | 
| 1853 2056 | 
             
            }
         | 
| 1854 2057 |  | 
| 2058 | 
            +
            static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
         | 
| 2059 | 
            +
                return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
         | 
| 2060 | 
            +
                             : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
         | 
| 2061 | 
            +
            }
         | 
| 2062 | 
            +
             | 
| 1855 2063 | 
             
            template <typename type4x4>
         | 
| 1856 2064 | 
             
            void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
         | 
| 1857 | 
            -
                device const  | 
| 2065 | 
            +
                device const uchar * q = xb->qs;
         | 
| 1858 2066 |  | 
| 1859 2067 | 
             
            #if QK_K == 256
         | 
| 1860 | 
            -
                const float d = (float)(xb->d);
         | 
| 1861 | 
            -
                const float min = (float)(xb->dmin);
         | 
| 1862 2068 | 
             
                short is = (il/4) * 2;
         | 
| 1863 2069 | 
             
                q = q + (il/4) * 32 + 16 * (il&1);
         | 
| 1864 | 
            -
                il = il | 
| 1865 | 
            -
                const  | 
| 1866 | 
            -
                const  | 
| 1867 | 
            -
                const  | 
| 2070 | 
            +
                il = il & 3;
         | 
| 2071 | 
            +
                const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
         | 
| 2072 | 
            +
                const half d   = il < 2 ? xb->d : xb->d / 16.h;
         | 
| 2073 | 
            +
                const half min = xb->dmin;
         | 
| 2074 | 
            +
                const half dl = d * sc[0];
         | 
| 2075 | 
            +
                const half ml = min * sc[1];
         | 
| 1868 2076 | 
             
            #else
         | 
| 1869 2077 | 
             
                q = q + 16 * (il&1);
         | 
| 1870 2078 | 
             
                device const uint8_t * s = xb->scales;
         | 
| 1871 2079 | 
             
                device const half2 * dh = (device const half2 *)xb->d;
         | 
| 1872 2080 | 
             
                const float2 d = (float2)dh[0];
         | 
| 1873 2081 | 
             
                const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
         | 
| 1874 | 
            -
                const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1  | 
| 2082 | 
            +
                const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1] * (s[1]>>4);
         | 
| 1875 2083 | 
             
            #endif
         | 
| 1876 2084 | 
             
                const ushort mask = il<2 ? 0x0F : 0xF0;
         | 
| 1877 2085 | 
             
                for (int i = 0; i < 16; ++i) {
         | 
| @@ -1885,19 +2093,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg | |
| 1885 2093 | 
             
                device const uint8_t * qh = xb->qh;
         | 
| 1886 2094 |  | 
| 1887 2095 | 
             
            #if QK_K == 256
         | 
| 1888 | 
            -
                const float d = (float)(xb->d);
         | 
| 1889 | 
            -
                const float min = (float)(xb->dmin);
         | 
| 1890 2096 | 
             
                short is = (il/4) * 2;
         | 
| 1891 2097 | 
             
                q  = q + 32 * (il/4) + 16 * (il&1);
         | 
| 1892 2098 | 
             
                qh = qh + 16 * (il&1);
         | 
| 1893 2099 | 
             
                uint8_t ul = 1 << (il/2);
         | 
| 1894 | 
            -
                il = il | 
| 1895 | 
            -
                const  | 
| 1896 | 
            -
                const  | 
| 1897 | 
            -
                const  | 
| 2100 | 
            +
                il = il & 3;
         | 
| 2101 | 
            +
                const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
         | 
| 2102 | 
            +
                const half d = il < 2 ? xb->d : xb->d / 16.h;
         | 
| 2103 | 
            +
                const half min = xb->dmin;
         | 
| 2104 | 
            +
                const half dl = d * sc[0];
         | 
| 2105 | 
            +
                const half ml = min * sc[1];
         | 
| 1898 2106 |  | 
| 1899 | 
            -
                const ushort mask | 
| 1900 | 
            -
                const  | 
| 2107 | 
            +
                const ushort mask = il<2 ? 0x0F : 0xF0;
         | 
| 2108 | 
            +
                const half qh_val = il<2 ? 16.h : 256.h;
         | 
| 1901 2109 | 
             
                for (int i = 0; i < 16; ++i) {
         | 
| 1902 2110 | 
             
                    reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
         | 
| 1903 2111 | 
             
                }
         | 
| @@ -1916,7 +2124,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg | |
| 1916 2124 |  | 
| 1917 2125 | 
             
            template <typename type4x4>
         | 
| 1918 2126 | 
             
            void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
         | 
| 1919 | 
            -
                const  | 
| 2127 | 
            +
                const half d_all = xb->d;
         | 
| 1920 2128 | 
             
                device const uint8_t * ql = (device const uint8_t *)xb->ql;
         | 
| 1921 2129 | 
             
                device const uint8_t * qh = (device const uint8_t *)xb->qh;
         | 
| 1922 2130 | 
             
                device const int8_t * scales = (device const int8_t *)xb->scales;
         | 
| @@ -1924,19 +2132,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg | |
| 1924 2132 | 
             
            #if QK_K == 256
         | 
| 1925 2133 | 
             
                ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
         | 
| 1926 2134 | 
             
                qh = qh + 32*(il/8) + 16*(il&1);
         | 
| 1927 | 
            -
                 | 
| 1928 | 
            -
                il = (il/2) | 
| 2135 | 
            +
                half sc = scales[(il%2) + 2 * ((il/2))];
         | 
| 2136 | 
            +
                il = (il/2) & 3;
         | 
| 1929 2137 | 
             
            #else
         | 
| 1930 2138 | 
             
                ql = ql + 16 * (il&1);
         | 
| 1931 | 
            -
                 | 
| 2139 | 
            +
                half sc = scales[il];
         | 
| 1932 2140 | 
             
            #endif
         | 
| 2141 | 
            +
                const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
         | 
| 2142 | 
            +
                const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
         | 
| 2143 | 
            +
                const half        coef = il>1 ? 1.f/16.h          : 1.h;
         | 
| 2144 | 
            +
                const half ml = d_all * sc * 32.h;
         | 
| 2145 | 
            +
                const half dl = d_all * sc * coef;
         | 
| 1933 2146 | 
             
                for (int i = 0; i < 16; ++i) {
         | 
| 1934 | 
            -
                     | 
| 1935 | 
            -
             | 
| 1936 | 
            -
                     | 
| 1937 | 
            -
                    float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
         | 
| 1938 | 
            -
                                     ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
         | 
| 1939 | 
            -
                    reg[i/4][i%4] = d_all * sc * q * coef;
         | 
| 2147 | 
            +
                    const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
         | 
| 2148 | 
            +
                                        : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
         | 
| 2149 | 
            +
                    reg[i/4][i%4] = dl * q - ml;
         | 
| 1940 2150 | 
             
                }
         | 
| 1941 2151 | 
             
            }
         | 
| 1942 2152 |  | 
| @@ -1976,22 +2186,25 @@ kernel void kernel_get_rows( | |
| 1976 2186 | 
             
            // each block_q contains 16*nl weights
         | 
| 1977 2187 | 
             
            template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
         | 
| 1978 2188 | 
             
            kernel void kernel_mul_mm(device const  uchar * src0,
         | 
| 1979 | 
            -
             | 
| 1980 | 
            -
             | 
| 1981 | 
            -
             | 
| 1982 | 
            -
             | 
| 1983 | 
            -
             | 
| 1984 | 
            -
             | 
| 1985 | 
            -
             | 
| 1986 | 
            -
             | 
| 1987 | 
            -
             | 
| 1988 | 
            -
             | 
| 1989 | 
            -
             | 
| 1990 | 
            -
             | 
| 1991 | 
            -
             | 
| 1992 | 
            -
             | 
| 1993 | 
            -
             | 
| 1994 | 
            -
             | 
| 2189 | 
            +
                                      device const  uchar * src1,
         | 
| 2190 | 
            +
                                      device        float * dst,
         | 
| 2191 | 
            +
                                      constant    int64_t & ne00,
         | 
| 2192 | 
            +
                                      constant    int64_t & ne02,
         | 
| 2193 | 
            +
                                      constant    int64_t & nb01,
         | 
| 2194 | 
            +
                                      constant    int64_t & nb02,
         | 
| 2195 | 
            +
                                      constant    int64_t & ne12,
         | 
| 2196 | 
            +
                                      constant    int64_t & nb10,
         | 
| 2197 | 
            +
                                      constant    int64_t & nb11,
         | 
| 2198 | 
            +
                                      constant    int64_t & nb12,
         | 
| 2199 | 
            +
                                      constant    int64_t & ne0,
         | 
| 2200 | 
            +
                                      constant    int64_t & ne1,
         | 
| 2201 | 
            +
                                      constant       uint & gqa,
         | 
| 2202 | 
            +
                                      threadgroup   uchar * shared_memory [[threadgroup(0)]],
         | 
| 2203 | 
            +
                                      uint3                 tgpig[[threadgroup_position_in_grid]],
         | 
| 2204 | 
            +
                                      uint                  tiitg[[thread_index_in_threadgroup]],
         | 
| 2205 | 
            +
                                      uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 2206 | 
            +
             | 
| 2207 | 
            +
                threadgroup half  * sa = (threadgroup half  *)(shared_memory);
         | 
| 1995 2208 | 
             
                threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
         | 
| 1996 2209 |  | 
| 1997 2210 | 
             
                const uint r0 = tgpig.y;
         | 
| @@ -2004,7 +2217,7 @@ kernel void kernel_mul_mm(device const  uchar * src0, | |
| 2004 2217 | 
             
                short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
         | 
| 2005 2218 | 
             
                short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
         | 
| 2006 2219 |  | 
| 2007 | 
            -
                simdgroup_half8x8 | 
| 2220 | 
            +
                simdgroup_half8x8  ma[4];
         | 
| 2008 2221 | 
             
                simdgroup_float8x8 mb[2];
         | 
| 2009 2222 | 
             
                simdgroup_float8x8 c_res[8];
         | 
| 2010 2223 | 
             
                for (int i = 0; i < 8; i++){
         | 
| @@ -2012,10 +2225,15 @@ kernel void kernel_mul_mm(device const  uchar * src0, | |
| 2012 2225 | 
             
                }
         | 
| 2013 2226 |  | 
| 2014 2227 | 
             
                short il = (tiitg % THREAD_PER_ROW);
         | 
| 2015 | 
            -
             | 
| 2016 | 
            -
                 | 
| 2017 | 
            -
                 | 
| 2018 | 
            -
             | 
| 2228 | 
            +
             | 
| 2229 | 
            +
                uint   offset0 = im/gqa*nb02;
         | 
| 2230 | 
            +
                ushort offset1 = il/nl;
         | 
| 2231 | 
            +
             | 
| 2232 | 
            +
                device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
         | 
| 2233 | 
            +
                device const float   * y = (device const float   *)(src1
         | 
| 2234 | 
            +
                    + nb12 * im
         | 
| 2235 | 
            +
                    + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
         | 
| 2236 | 
            +
                    + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
         | 
| 2019 2237 |  | 
| 2020 2238 | 
             
                for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
         | 
| 2021 2239 | 
             
                    //load data and store to threadgroup memory
         | 
| @@ -2095,6 +2313,7 @@ kernel void kernel_mul_mm(device const  uchar * src0, | |
| 2095 2313 | 
             
            typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
         | 
| 2096 2314 | 
             
                                      constant uint64_t &, constant uint64_t &, uint, uint, uint);
         | 
| 2097 2315 |  | 
| 2316 | 
            +
            template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
         | 
| 2098 2317 | 
             
            template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
         | 
| 2099 2318 | 
             
            template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
         | 
| 2100 2319 | 
             
            template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
         | 
| @@ -2105,14 +2324,28 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows | |
| 2105 2324 | 
             
            template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
         | 
| 2106 2325 | 
             
            template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
         | 
| 2107 2326 |  | 
| 2108 | 
            -
            typedef void (mat_mm_t)( | 
| 2109 | 
            -
             | 
| 2110 | 
            -
             | 
| 2111 | 
            -
             | 
| 2112 | 
            -
             | 
| 2113 | 
            -
             | 
| 2114 | 
            -
             | 
| 2115 | 
            -
             | 
| 2327 | 
            +
            typedef void (mat_mm_t)(
         | 
| 2328 | 
            +
                    device const  uchar * src0,
         | 
| 2329 | 
            +
                    device const  uchar * src1,
         | 
| 2330 | 
            +
                    device        float * dst,
         | 
| 2331 | 
            +
                    constant    int64_t & ne00,
         | 
| 2332 | 
            +
                    constant    int64_t & ne02,
         | 
| 2333 | 
            +
                    constant    int64_t & nb01,
         | 
| 2334 | 
            +
                    constant    int64_t & nb02,
         | 
| 2335 | 
            +
                    constant    int64_t & ne12,
         | 
| 2336 | 
            +
                    constant    int64_t & nb10,
         | 
| 2337 | 
            +
                    constant    int64_t & nb11,
         | 
| 2338 | 
            +
                    constant    int64_t & nb12,
         | 
| 2339 | 
            +
                    constant    int64_t & ne0,
         | 
| 2340 | 
            +
                    constant    int64_t & ne1,
         | 
| 2341 | 
            +
                    constant       uint & gqa,
         | 
| 2342 | 
            +
                    threadgroup uchar *, uint3, uint, uint);
         | 
| 2343 | 
            +
             | 
| 2344 | 
            +
            template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<float4x4,   1,     dequantize_f32>;
         | 
| 2345 | 
            +
            template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>;
         | 
| 2346 | 
            +
            template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2,     dequantize_q4_0>;
         | 
| 2347 | 
            +
            template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2,     dequantize_q4_1>;
         | 
| 2348 | 
            +
            template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2,     dequantize_q8_0>;
         | 
| 2116 2349 | 
             
            template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
         | 
| 2117 2350 | 
             
            template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
         | 
| 2118 2351 | 
             
            template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
         |