llama_cpp 0.3.8 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/README.md +1 -1
- data/examples/chat.rb +4 -6
- data/ext/llama_cpp/extconf.rb +3 -3
- data/ext/llama_cpp/llama_cpp.cpp +129 -124
- data/ext/llama_cpp/src/ggml-alloc.c +90 -113
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-cuda.cu +350 -77
- data/ext/llama_cpp/src/ggml-cuda.h +13 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +226 -121
- data/ext/llama_cpp/src/ggml-metal.metal +157 -35
- data/ext/llama_cpp/src/ggml.c +2724 -584
- data/ext/llama_cpp/src/ggml.h +282 -31
- data/ext/llama_cpp/src/k_quants.c +112 -56
- data/ext/llama_cpp/src/llama.cpp +4857 -2986
- data/ext/llama_cpp/src/llama.h +180 -126
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +2 -2
- data/sig/llama_cpp.rbs +12 -11
- metadata +2 -2
| @@ -18,10 +18,16 @@ typedef struct { | |
| 18 18 | 
             
                uint8_t qs[QK4_1 / 2];  // nibbles / quants
         | 
| 19 19 | 
             
            } block_q4_1;
         | 
| 20 20 |  | 
| 21 | 
            +
            #define QK8_0 32
         | 
| 22 | 
            +
            typedef struct {
         | 
| 23 | 
            +
                half    d;         // delta
         | 
| 24 | 
            +
                int8_t  qs[QK8_0]; // quants
         | 
| 25 | 
            +
            } block_q8_0;
         | 
| 26 | 
            +
             | 
| 21 27 | 
             
            kernel void kernel_add(
         | 
| 22 | 
            -
                    device const  | 
| 23 | 
            -
                    device const  | 
| 24 | 
            -
                    device        | 
| 28 | 
            +
                    device const float4 * src0,
         | 
| 29 | 
            +
                    device const float4 * src1,
         | 
| 30 | 
            +
                    device       float4 * dst,
         | 
| 25 31 | 
             
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 26 32 | 
             
                dst[tpig] = src0[tpig] + src1[tpig];
         | 
| 27 33 | 
             
            }
         | 
| @@ -29,18 +35,18 @@ kernel void kernel_add( | |
| 29 35 | 
             
            // assumption: src1 is a row
         | 
| 30 36 | 
             
            // broadcast src1 into src0
         | 
| 31 37 | 
             
            kernel void kernel_add_row(
         | 
| 32 | 
            -
                    device const  | 
| 33 | 
            -
                    device const  | 
| 34 | 
            -
                    device        | 
| 35 | 
            -
                    constant   int64_t &  | 
| 38 | 
            +
                    device const float4 * src0,
         | 
| 39 | 
            +
                    device const float4 * src1,
         | 
| 40 | 
            +
                    device       float4 * dst,
         | 
| 41 | 
            +
                    constant   int64_t & nb,
         | 
| 36 42 | 
             
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 37 | 
            -
                dst[tpig] = src0[tpig] + src1[tpig %  | 
| 43 | 
            +
                dst[tpig] = src0[tpig] + src1[tpig % nb];
         | 
| 38 44 | 
             
            }
         | 
| 39 45 |  | 
| 40 46 | 
             
            kernel void kernel_mul(
         | 
| 41 | 
            -
                    device const  | 
| 42 | 
            -
                    device const  | 
| 43 | 
            -
                    device        | 
| 47 | 
            +
                    device const float4 * src0,
         | 
| 48 | 
            +
                    device const float4 * src1,
         | 
| 49 | 
            +
                    device       float4 * dst,
         | 
| 44 50 | 
             
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 45 51 | 
             
                dst[tpig] = src0[tpig] * src1[tpig];
         | 
| 46 52 | 
             
            }
         | 
| @@ -48,12 +54,12 @@ kernel void kernel_mul( | |
| 48 54 | 
             
            // assumption: src1 is a row
         | 
| 49 55 | 
             
            // broadcast src1 into src0
         | 
| 50 56 | 
             
            kernel void kernel_mul_row(
         | 
| 51 | 
            -
                    device const  | 
| 52 | 
            -
                    device const  | 
| 53 | 
            -
                    device        | 
| 54 | 
            -
                    constant | 
| 57 | 
            +
                    device const float4 * src0,
         | 
| 58 | 
            +
                    device const float4 * src1,
         | 
| 59 | 
            +
                    device       float4 * dst,
         | 
| 60 | 
            +
                    constant    int64_t & nb,
         | 
| 55 61 | 
             
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 56 | 
            -
                dst[tpig] = src0[tpig] * src1[tpig %  | 
| 62 | 
            +
                dst[tpig] = src0[tpig] * src1[tpig % nb];
         | 
| 57 63 | 
             
            }
         | 
| 58 64 |  | 
| 59 65 | 
             
            kernel void kernel_scale(
         | 
| @@ -87,7 +93,12 @@ kernel void kernel_gelu( | |
| 87 93 | 
             
                device       float * dst,
         | 
| 88 94 | 
             
                uint tpig[[thread_position_in_grid]]) {
         | 
| 89 95 | 
             
                float x = src0[tpig];
         | 
| 90 | 
            -
             | 
| 96 | 
            +
             | 
| 97 | 
            +
                // BEWARE !!!
         | 
| 98 | 
            +
            // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
         | 
| 99 | 
            +
                // This was observed with Falcon 7B and 40B models
         | 
| 100 | 
            +
                //
         | 
| 101 | 
            +
                dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
         | 
| 91 102 | 
             
            }
         | 
| 92 103 |  | 
| 93 104 | 
             
            kernel void kernel_soft_max(
         | 
| @@ -352,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device | |
| 352 363 | 
             
                const int first_row = (r0 * nsg + sgitg) * nr;
         | 
| 353 364 | 
             
                const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
         | 
| 354 365 | 
             
                device const block_q_type * x = (device const block_q_type *) src0 + offset0;
         | 
| 355 | 
            -
                device const float | 
| 366 | 
            +
                device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
         | 
| 356 367 | 
             
                float yl[16];       // src1 vector cache
         | 
| 357 368 | 
             
                float sumf[nr]={0.f};
         | 
| 358 369 |  | 
| @@ -424,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32( | |
| 424 435 | 
             
                 mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
         | 
| 425 436 | 
             
            }
         | 
| 426 437 |  | 
| 438 | 
            +
            kernel void kernel_mul_mat_q8_0_f32(
         | 
| 439 | 
            +
                    device const  void * src0,
         | 
| 440 | 
            +
                    device const float * src1,
         | 
| 441 | 
            +
                    device       float * dst,
         | 
| 442 | 
            +
                    constant   int64_t & ne00,
         | 
| 443 | 
            +
                    constant   int64_t & ne01[[buffer(4)]],
         | 
| 444 | 
            +
                    constant   int64_t & ne02[[buffer(5)]],
         | 
| 445 | 
            +
                    constant   int64_t & ne10[[buffer(9)]],
         | 
| 446 | 
            +
                    constant   int64_t & ne12[[buffer(11)]],
         | 
| 447 | 
            +
                    constant   int64_t & ne0[[buffer(15)]],
         | 
| 448 | 
            +
                    constant   int64_t & ne1[[buffer(16)]],
         | 
| 449 | 
            +
                    constant   uint    & gqa[[buffer(17)]],
         | 
| 450 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 451 | 
            +
                    uint tiisg[[thread_index_in_simdgroup]],
         | 
| 452 | 
            +
                    uint sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 453 | 
            +
                const int nr  = N_DST;
         | 
| 454 | 
            +
                const int nsg = N_SIMDGROUP;
         | 
| 455 | 
            +
                const int nw  = N_SIMDWIDTH;
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                const int nb = ne00/QK8_0;
         | 
| 458 | 
            +
                const int r0 = tgpig.x;
         | 
| 459 | 
            +
                const int r1 = tgpig.y;
         | 
| 460 | 
            +
                const int im = tgpig.z;
         | 
| 461 | 
            +
                const int first_row = (r0 * nsg + sgitg) * nr;
         | 
| 462 | 
            +
                const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
         | 
| 463 | 
            +
                device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
         | 
| 464 | 
            +
                device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                float yl[16];
         | 
| 467 | 
            +
                float sumf[nr]={0.f};
         | 
| 468 | 
            +
             | 
| 469 | 
            +
                const int ix = tiisg/2;
         | 
| 470 | 
            +
                const int il = tiisg%2;
         | 
| 471 | 
            +
             | 
| 472 | 
            +
                device const float * yb = y + ix * QK8_0 + 16*il;
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                // each thread in a SIMD group deals with half a block.
         | 
| 475 | 
            +
                for (int ib = ix; ib < nb; ib += nw/2) {
         | 
| 476 | 
            +
                    for (int i = 0; i < 16; ++i) {
         | 
| 477 | 
            +
                        yl[i] = yb[i];
         | 
| 478 | 
            +
                    }
         | 
| 479 | 
            +
             | 
| 480 | 
            +
                    for (int row = 0; row < nr; row++) {
         | 
| 481 | 
            +
                        device const int8_t * qs = x[ib+row*nb].qs + 16*il;
         | 
| 482 | 
            +
                        float sumq = 0.f;
         | 
| 483 | 
            +
                        for (int iq = 0; iq < 16; ++iq) {
         | 
| 484 | 
            +
                            sumq += qs[iq] * yl[iq];
         | 
| 485 | 
            +
                        }
         | 
| 486 | 
            +
                        sumf[row] += sumq*x[ib+row*nb].d;
         | 
| 487 | 
            +
                    }
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                    yb += QK8_0 * 16;
         | 
| 490 | 
            +
                }
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                for (int row = 0; row < nr; ++row) {
         | 
| 493 | 
            +
                    const float tot = simd_sum(sumf[row]);
         | 
| 494 | 
            +
                    if (tiisg == 0 && first_row + row < ne01) {
         | 
| 495 | 
            +
                        dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
         | 
| 496 | 
            +
                    }
         | 
| 497 | 
            +
                }
         | 
| 498 | 
            +
            }
         | 
| 499 | 
            +
             | 
| 427 500 | 
             
            kernel void kernel_mul_mat_f16_f32(
         | 
| 428 501 | 
             
                    device const  char * src0,
         | 
| 429 502 | 
             
                    device const  char * src1,
         | 
| @@ -455,26 +528,43 @@ kernel void kernel_mul_mat_f16_f32( | |
| 455 528 | 
             
                device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
         | 
| 456 529 | 
             
                device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
         | 
| 457 530 |  | 
| 458 | 
            -
                 | 
| 531 | 
            +
                uint ith = tpitg.x;
         | 
| 532 | 
            +
                uint nth = tptg.x;
         | 
| 459 533 |  | 
| 460 | 
            -
                 | 
| 461 | 
            -
             | 
| 534 | 
            +
                sum[ith] = 0.0f;
         | 
| 535 | 
            +
             | 
| 536 | 
            +
                for (int i = ith; i < ne00; i += nth) {
         | 
| 537 | 
            +
                    sum[ith] += (float) x[i] * (float) y[i];
         | 
| 462 538 | 
             
                }
         | 
| 463 539 |  | 
| 464 540 | 
             
                // accumulate the sum from all threads in the threadgroup
         | 
| 465 541 | 
             
                threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 466 | 
            -
                 | 
| 467 | 
            -
                     | 
| 468 | 
            -
                        sum[tpitg.x] += sum[tpitg.x + i];
         | 
| 469 | 
            -
                    }
         | 
| 470 | 
            -
                    threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 542 | 
            +
                if (ith%4 == 0) {
         | 
| 543 | 
            +
                    for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
         | 
| 471 544 | 
             
                }
         | 
| 472 | 
            -
             | 
| 473 | 
            -
                if ( | 
| 545 | 
            +
                threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 546 | 
            +
                if (ith%16 == 0) {
         | 
| 547 | 
            +
                    for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
         | 
| 548 | 
            +
                }
         | 
| 549 | 
            +
                threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 550 | 
            +
                if (ith == 0) {
         | 
| 551 | 
            +
                    for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         | 
| 474 552 | 
             
                    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
         | 
| 475 553 | 
             
                }
         | 
| 476 | 
            -
            }
         | 
| 477 554 |  | 
| 555 | 
            +
                // Original implementation. Left behind commented out for now
         | 
| 556 | 
            +
                //threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 557 | 
            +
                //for (uint i = tptg.x/2; i > 0; i /= 2) {
         | 
| 558 | 
            +
                //    if (tpitg.x < i) {
         | 
| 559 | 
            +
                //        sum[tpitg.x] += sum[tpitg.x + i];
         | 
| 560 | 
            +
                //    }
         | 
| 561 | 
            +
                //    threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 562 | 
            +
                //}
         | 
| 563 | 
            +
                //
         | 
| 564 | 
            +
                //if (tpitg.x == 0) {
         | 
| 565 | 
            +
                //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
         | 
| 566 | 
            +
                //}
         | 
| 567 | 
            +
            }
         | 
| 478 568 |  | 
| 479 569 | 
             
            kernel void kernel_alibi_f32(
         | 
| 480 570 | 
             
                    device const float * src0,
         | 
| @@ -571,7 +661,25 @@ kernel void kernel_rope( | |
| 571 661 | 
             
                        dst_data[1] = x0*sin_theta + x1*cos_theta;
         | 
| 572 662 | 
             
                    }
         | 
| 573 663 | 
             
                } else {
         | 
| 574 | 
            -
                     | 
| 664 | 
            +
                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
         | 
| 665 | 
            +
                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
         | 
| 666 | 
            +
                            const float cos_theta = cos(theta);
         | 
| 667 | 
            +
                            const float sin_theta = sin(theta);
         | 
| 668 | 
            +
             | 
| 669 | 
            +
                            theta *= theta_scale;
         | 
| 670 | 
            +
             | 
| 671 | 
            +
                            const int64_t i0 = ib*n_dims + ic/2;
         | 
| 672 | 
            +
             | 
| 673 | 
            +
                            device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
         | 
| 674 | 
            +
                            device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
         | 
| 675 | 
            +
             | 
| 676 | 
            +
                            const float x0 = src[0];
         | 
| 677 | 
            +
                            const float x1 = src[n_dims/2];
         | 
| 678 | 
            +
             | 
| 679 | 
            +
                            dst_data[0]        = x0*cos_theta - x1*sin_theta;
         | 
| 680 | 
            +
                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
         | 
| 681 | 
            +
                        }
         | 
| 682 | 
            +
                    }
         | 
| 575 683 | 
             
                }
         | 
| 576 684 | 
             
            }
         | 
| 577 685 |  | 
| @@ -1598,12 +1706,12 @@ template <typename type4x4> | |
| 1598 1706 | 
             
            void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
         | 
| 1599 1707 | 
             
                device const uint16_t * qs = ((device const uint16_t *)xb + 1);
         | 
| 1600 1708 | 
             
                const half d = il ? (xb->d / 16.h) : xb->d;
         | 
| 1601 | 
            -
                const half m = il ? (-8.h * 16.h) : -8.h;
         | 
| 1709 | 
            +
                const half m = il ? ( -8.h * 16.h) : -8.h;
         | 
| 1602 1710 | 
             
                const ushort mask0 = il ? 0x00F0 : 0x000F;
         | 
| 1603 1711 | 
             
                const ushort mask1 = il ? 0xF000 : 0x0F00;
         | 
| 1604 1712 |  | 
| 1605 1713 | 
             
                for (int i=0;i<8;i++) {
         | 
| 1606 | 
            -
                    reg[i/2][2*(i%2)] | 
| 1714 | 
            +
                    reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) + m) * d;
         | 
| 1607 1715 | 
             
                    reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
         | 
| 1608 1716 | 
             
                }
         | 
| 1609 1717 | 
             
            }
         | 
| @@ -1617,11 +1725,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg | |
| 1617 1725 | 
             
                const ushort mask1 = il ? 0xF000 : 0x0F00;
         | 
| 1618 1726 |  | 
| 1619 1727 | 
             
                for (int i=0;i<8;i++) {
         | 
| 1620 | 
            -
                    reg[i/2][2*(i%2)] | 
| 1728 | 
            +
                    reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) * d) + m;
         | 
| 1621 1729 | 
             
                    reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
         | 
| 1622 1730 | 
             
                }
         | 
| 1623 1731 | 
             
            }
         | 
| 1624 1732 |  | 
| 1733 | 
            +
            template <typename type4x4>
         | 
| 1734 | 
            +
            void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
         | 
| 1735 | 
            +
                device const int8_t * qs = ((device const int8_t *)xb->qs);
         | 
| 1736 | 
            +
                const half d = xb->d;
         | 
| 1737 | 
            +
             | 
| 1738 | 
            +
                for (int i=0;i<16;i++) {
         | 
| 1739 | 
            +
                    reg[i/4][i%4] = (qs[i + 16*il] * d);
         | 
| 1740 | 
            +
                }
         | 
| 1741 | 
            +
            }
         | 
| 1742 | 
            +
             | 
| 1625 1743 | 
             
            template <typename type4x4>
         | 
| 1626 1744 | 
             
            void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
         | 
| 1627 1745 | 
             
                const half d = xb->d;
         | 
| @@ -1850,6 +1968,7 @@ kernel void kernel_mul_mm(device const  uchar * src0, | |
| 1850 1968 | 
             
                    //load data and store to threadgroup memory
         | 
| 1851 1969 | 
             
                    half4x4 temp_a;
         | 
| 1852 1970 | 
             
                    dequantize_func(x, il, temp_a);
         | 
| 1971 | 
            +
                    threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 1853 1972 | 
             
                    #pragma unroll(16)
         | 
| 1854 1973 | 
             
                    for (int i = 0; i < 16; i++) {
         | 
| 1855 1974 | 
             
                        *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
         | 
| @@ -1895,6 +2014,7 @@ kernel void kernel_mul_mm(device const  uchar * src0, | |
| 1895 2014 | 
             
                    }
         | 
| 1896 2015 | 
             
                } else {
         | 
| 1897 2016 | 
             
                    // block is smaller than 64x32, we should avoid writing data outside of the matrix
         | 
| 2017 | 
            +
                    threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 1898 2018 | 
             
                    threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
         | 
| 1899 2019 | 
             
                                                  + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         | 
| 1900 2020 | 
             
                    for (int i = 0; i < 8; i++) {
         | 
| @@ -1922,9 +2042,10 @@ kernel void kernel_mul_mm(device const  uchar * src0, | |
| 1922 2042 | 
             
            typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
         | 
| 1923 2043 | 
             
                                      constant uint64_t &, constant uint64_t &, uint, uint, uint);
         | 
| 1924 2044 |  | 
| 1925 | 
            -
            template [[host_name("kernel_get_rows_f16")]] | 
| 2045 | 
            +
            template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
         | 
| 1926 2046 | 
             
            template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
         | 
| 1927 2047 | 
             
            template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
         | 
| 2048 | 
            +
            template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
         | 
| 1928 2049 | 
             
            template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
         | 
| 1929 2050 | 
             
            template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
         | 
| 1930 2051 | 
             
            template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
         | 
| @@ -1935,9 +2056,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float | |
| 1935 2056 | 
             
                                         constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
         | 
| 1936 2057 | 
             
                                         constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
         | 
| 1937 2058 |  | 
| 1938 | 
            -
            template [[host_name("kernel_mul_mm_f16_f32")]] | 
| 2059 | 
            +
            template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
         | 
| 1939 2060 | 
             
            template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
         | 
| 1940 2061 | 
             
            template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
         | 
| 2062 | 
            +
            template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
         | 
| 1941 2063 | 
             
            template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
         | 
| 1942 2064 | 
             
            template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
         | 
| 1943 2065 | 
             
            template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
         |