llama_cpp 0.7.1 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +18 -0
- data/examples/chat.rb +8 -6
- data/ext/llama_cpp/extconf.rb +2 -2
- data/ext/llama_cpp/llama_cpp.cpp +122 -183
- data/ext/llama_cpp/src/ggml-cuda.cu +188 -20
- data/ext/llama_cpp/src/ggml-metal.m +57 -8
- data/ext/llama_cpp/src/ggml-metal.metal +171 -2
- data/ext/llama_cpp/src/ggml-opencl.cpp +188 -222
- data/ext/llama_cpp/src/ggml.c +375 -93
- data/ext/llama_cpp/src/ggml.h +11 -9
- data/ext/llama_cpp/src/k_quants.c +12 -20
- data/ext/llama_cpp/src/llama.cpp +459 -153
- data/ext/llama_cpp/src/llama.h +34 -33
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +4 -4
- data/sig/llama_cpp.rbs +15 -16
- metadata +3 -3
| @@ -29,6 +29,8 @@ | |
| 29 29 | 
             
            #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
         | 
| 30 30 | 
             
            #define cublasCreate hipblasCreate
         | 
| 31 31 | 
             
            #define cublasGemmEx hipblasGemmEx
         | 
| 32 | 
            +
            #define cublasGemmBatchedEx hipblasGemmBatchedEx
         | 
| 33 | 
            +
            #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
         | 
| 32 34 | 
             
            #define cublasHandle_t hipblasHandle_t
         | 
| 33 35 | 
             
            #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
         | 
| 34 36 | 
             
            #define cublasSetStream hipblasSetStream
         | 
| @@ -4326,13 +4328,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous | |
| 4326 4328 |  | 
| 4327 4329 | 
             
                const half * x = (const half *) vx;
         | 
| 4328 4330 |  | 
| 4329 | 
            -
                const int row_x | 
| 4330 | 
            -
                const int channel | 
| 4331 | 
            +
                const int row_x     = blockDim.y*blockIdx.y + threadIdx.y;
         | 
| 4332 | 
            +
                const int channel   = blockDim.z*blockIdx.z + threadIdx.z;
         | 
| 4331 4333 | 
             
                const int channel_x = channel / channel_x_divisor;
         | 
| 4332 4334 |  | 
| 4333 | 
            -
                const int nrows_y | 
| 4335 | 
            +
                const int nrows_y   = ncols_x;
         | 
| 4334 4336 | 
             
                const int nrows_dst = nrows_x;
         | 
| 4335 | 
            -
                const int row_dst | 
| 4337 | 
            +
                const int row_dst   = row_x;
         | 
| 4336 4338 |  | 
| 4337 4339 | 
             
                const int idst = channel*nrows_dst + row_dst;
         | 
| 4338 4340 |  | 
| @@ -4345,13 +4347,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous | |
| 4345 4347 | 
             
                        break;
         | 
| 4346 4348 | 
             
                    }
         | 
| 4347 4349 |  | 
| 4348 | 
            -
                    const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
         | 
| 4349 | 
            -
                    const float xi = __half2float(x[ix]);
         | 
| 4350 | 
            -
             | 
| 4351 4350 | 
             
                    const int row_y = col_x;
         | 
| 4352 4351 |  | 
| 4352 | 
            +
                    const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
         | 
| 4353 4353 | 
             
                    const int iy = channel*nrows_y + row_y;
         | 
| 4354 4354 |  | 
| 4355 | 
            +
                    const float xi = __half2float(x[ix]);
         | 
| 4356 | 
            +
             | 
| 4355 4357 | 
             
                    tmp += xi * y[iy];
         | 
| 4356 4358 | 
             
                }
         | 
| 4357 4359 |  | 
| @@ -5662,10 +5664,10 @@ void ggml_init_cublas() { | |
| 5662 5664 | 
             
                    GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         | 
| 5663 5665 | 
             
                    int64_t total_vram = 0;
         | 
| 5664 5666 | 
             
                    fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
         | 
| 5665 | 
            -
                    for ( | 
| 5667 | 
            +
                    for (int id = 0; id < g_device_count; ++id) {
         | 
| 5666 5668 | 
             
                        cudaDeviceProp prop;
         | 
| 5667 5669 | 
             
                        CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
         | 
| 5668 | 
            -
                        fprintf(stderr, "  Device % | 
| 5670 | 
            +
                        fprintf(stderr, "  Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
         | 
| 5669 5671 |  | 
| 5670 5672 | 
             
                        g_tensor_split[id] = total_vram;
         | 
| 5671 5673 | 
             
                        total_vram += prop.totalGlobalMem;
         | 
| @@ -5675,15 +5677,15 @@ void ggml_init_cublas() { | |
| 5675 5677 | 
             
                        g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
         | 
| 5676 5678 | 
             
            #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         | 
| 5677 5679 | 
             
                    }
         | 
| 5678 | 
            -
                    for ( | 
| 5680 | 
            +
                    for (int id = 0; id < g_device_count; ++id) {
         | 
| 5679 5681 | 
             
                        g_tensor_split[id] /= total_vram;
         | 
| 5680 5682 | 
             
                    }
         | 
| 5681 5683 |  | 
| 5682 | 
            -
                    for ( | 
| 5684 | 
            +
                    for (int id = 0; id < g_device_count; ++id) {
         | 
| 5683 5685 | 
             
                        CUDA_CHECK(ggml_cuda_set_device(id));
         | 
| 5684 5686 |  | 
| 5685 5687 | 
             
                        // create cuda streams
         | 
| 5686 | 
            -
                        for ( | 
| 5688 | 
            +
                        for (int is = 0; is < MAX_STREAMS; ++is) {
         | 
| 5687 5689 | 
             
                            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[id][is], cudaStreamNonBlocking));
         | 
| 5688 5690 | 
             
                        }
         | 
| 5689 5691 |  | 
| @@ -6252,16 +6254,15 @@ inline void ggml_cuda_op_mul_mat_cublas( | |
| 6252 6254 | 
             
                const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
         | 
| 6253 6255 | 
             
                const int64_t src1_padded_row_size, const cudaStream_t & stream) {
         | 
| 6254 6256 |  | 
| 6255 | 
            -
                GGML_ASSERT(src0_dd_i | 
| 6257 | 
            +
                GGML_ASSERT(src0_dd_i  != nullptr);
         | 
| 6256 6258 | 
             
                GGML_ASSERT(src1_ddf_i != nullptr);
         | 
| 6257 | 
            -
                GGML_ASSERT(dst_dd_i | 
| 6258 | 
            -
             | 
| 6259 | 
            +
                GGML_ASSERT(dst_dd_i   != nullptr);
         | 
| 6259 6260 |  | 
| 6260 6261 | 
             
                const int64_t ne00 = src0->ne[0];
         | 
| 6261 | 
            -
             | 
| 6262 6262 | 
             
                const int64_t ne10 = src1->ne[0];
         | 
| 6263 6263 |  | 
| 6264 6264 | 
             
                const int64_t ne0 = dst->ne[0];
         | 
| 6265 | 
            +
             | 
| 6265 6266 | 
             
                const int64_t row_diff = row_high - row_low;
         | 
| 6266 6267 |  | 
| 6267 6268 | 
             
                int id;
         | 
| @@ -7013,7 +7014,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens | |
| 7013 7014 | 
             
            }
         | 
| 7014 7015 |  | 
| 7015 7016 | 
             
            static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
         | 
| 7016 | 
            -
                GGML_ASSERT(! | 
| 7017 | 
            +
                GGML_ASSERT(!ggml_is_transposed(src0));
         | 
| 7018 | 
            +
                GGML_ASSERT(!ggml_is_transposed(src1));
         | 
| 7017 7019 | 
             
                GGML_ASSERT(!ggml_is_permuted(src0));
         | 
| 7018 7020 | 
             
                GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
         | 
| 7019 7021 | 
             
                GGML_ASSERT(src0->type == GGML_TYPE_F16);
         | 
| @@ -7023,11 +7025,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor | |
| 7023 7025 | 
             
                const int64_t ne01 = src0->ne[1];
         | 
| 7024 7026 | 
             
                const int64_t ne02 = src0->ne[2];
         | 
| 7025 7027 |  | 
| 7026 | 
            -
                const int64_t ne12 = src1->ne[2];
         | 
| 7027 | 
            -
             | 
| 7028 7028 | 
             
                const int64_t nb01 = src0->nb[1];
         | 
| 7029 7029 | 
             
                const int64_t nb02 = src0->nb[2];
         | 
| 7030 7030 |  | 
| 7031 | 
            +
                const int64_t ne12 = src1->ne[2];
         | 
| 7032 | 
            +
             | 
| 7031 7033 | 
             
                CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         | 
| 7032 7034 | 
             
                cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
         | 
| 7033 7035 |  | 
| @@ -7046,6 +7048,159 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor | |
| 7046 7048 | 
             
                ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
         | 
| 7047 7049 | 
             
            }
         | 
| 7048 7050 |  | 
| 7051 | 
            +
            static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
         | 
| 7052 | 
            +
                GGML_ASSERT(!ggml_is_transposed(src0));
         | 
| 7053 | 
            +
                GGML_ASSERT(!ggml_is_transposed(src1));
         | 
| 7054 | 
            +
                GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
         | 
| 7055 | 
            +
                GGML_ASSERT(src0->type == GGML_TYPE_F16);
         | 
| 7056 | 
            +
                GGML_ASSERT(src1->type == GGML_TYPE_F32);
         | 
| 7057 | 
            +
             | 
| 7058 | 
            +
                const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
         | 
| 7059 | 
            +
                const int64_t ne01 = src0->ne[1];
         | 
| 7060 | 
            +
                const int64_t ne02 = src0->ne[2];
         | 
| 7061 | 
            +
                const int64_t ne03 = src0->ne[3];
         | 
| 7062 | 
            +
             | 
| 7063 | 
            +
                const int64_t nb01 = src0->nb[1];
         | 
| 7064 | 
            +
                const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
         | 
| 7065 | 
            +
                const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
         | 
| 7066 | 
            +
             | 
| 7067 | 
            +
                const int64_t ne10 = src1->ne[0];
         | 
| 7068 | 
            +
                const int64_t ne11 = src1->ne[1];
         | 
| 7069 | 
            +
                const int64_t ne12 = src1->ne[2];
         | 
| 7070 | 
            +
                const int64_t ne13 = src1->ne[3];
         | 
| 7071 | 
            +
             | 
| 7072 | 
            +
                const int64_t nb11 = src1->nb[1];
         | 
| 7073 | 
            +
                const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
         | 
| 7074 | 
            +
                const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
         | 
| 7075 | 
            +
             | 
| 7076 | 
            +
                const int64_t ne1 = ggml_nelements(src1);
         | 
| 7077 | 
            +
                const int64_t ne  = ggml_nelements(dst);
         | 
| 7078 | 
            +
             | 
| 7079 | 
            +
                CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         | 
| 7080 | 
            +
                cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
         | 
| 7081 | 
            +
             | 
| 7082 | 
            +
                int id;
         | 
| 7083 | 
            +
                CUDA_CHECK(cudaGetDevice(&id));
         | 
| 7084 | 
            +
                CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
         | 
| 7085 | 
            +
             | 
| 7086 | 
            +
                ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
         | 
| 7087 | 
            +
                void * src0_ddq = src0_extra->data_device[g_main_device];
         | 
| 7088 | 
            +
                half * src0_as_f16 = (half *) src0_ddq;
         | 
| 7089 | 
            +
             | 
| 7090 | 
            +
                ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
         | 
| 7091 | 
            +
                float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
         | 
| 7092 | 
            +
             | 
| 7093 | 
            +
                ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
         | 
| 7094 | 
            +
                float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
         | 
| 7095 | 
            +
             | 
| 7096 | 
            +
                // convert src1 to fp16
         | 
| 7097 | 
            +
                const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
         | 
| 7098 | 
            +
                GGML_ASSERT(to_fp16_cuda != nullptr);
         | 
| 7099 | 
            +
             | 
| 7100 | 
            +
                size_t src1_as = 0;
         | 
| 7101 | 
            +
                half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
         | 
| 7102 | 
            +
                to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
         | 
| 7103 | 
            +
             | 
| 7104 | 
            +
                size_t dst_as = 0;
         | 
| 7105 | 
            +
                half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
         | 
| 7106 | 
            +
             | 
| 7107 | 
            +
                GGML_ASSERT(ne12 % ne02 == 0);
         | 
| 7108 | 
            +
                GGML_ASSERT(ne13 % ne03 == 0);
         | 
| 7109 | 
            +
             | 
| 7110 | 
            +
                // broadcast factors
         | 
| 7111 | 
            +
                const int64_t r2 = ne12/ne02;
         | 
| 7112 | 
            +
                const int64_t r3 = ne13/ne03;
         | 
| 7113 | 
            +
             | 
| 7114 | 
            +
                const half alpha_f16 = 1.0f;
         | 
| 7115 | 
            +
                const half beta_f16  = 0.0f;
         | 
| 7116 | 
            +
             | 
| 7117 | 
            +
            #if 0
         | 
| 7118 | 
            +
                // use cublasGemmEx
         | 
| 7119 | 
            +
                {
         | 
| 7120 | 
            +
                    for (int i13 = 0; i13 < ne13; ++i13) {
         | 
| 7121 | 
            +
                        for (int i12 = 0; i12 < ne12; ++i12) {
         | 
| 7122 | 
            +
                            int i03 = i13 / r3;
         | 
| 7123 | 
            +
                            int i02 = i12 / r2;
         | 
| 7124 | 
            +
             | 
| 7125 | 
            +
                            CUBLAS_CHECK(
         | 
| 7126 | 
            +
                                    cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
         | 
| 7127 | 
            +
                                        ne01, ne11, ne10,
         | 
| 7128 | 
            +
                                        &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
         | 
| 7129 | 
            +
                                                    (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
         | 
| 7130 | 
            +
                                        &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
         | 
| 7131 | 
            +
                                        CUBLAS_COMPUTE_16F,
         | 
| 7132 | 
            +
                                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
         | 
| 7133 | 
            +
                        }
         | 
| 7134 | 
            +
                    }
         | 
| 7135 | 
            +
                }
         | 
| 7136 | 
            +
            #else
         | 
| 7137 | 
            +
                if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
         | 
| 7138 | 
            +
                    // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         | 
| 7139 | 
            +
                    // use cublasGemmStridedBatchedEx
         | 
| 7140 | 
            +
                    CUBLAS_CHECK(
         | 
| 7141 | 
            +
                    cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
         | 
| 7142 | 
            +
                            ne01, ne11, ne10,
         | 
| 7143 | 
            +
                            &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
         | 
| 7144 | 
            +
                                        (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
         | 
| 7145 | 
            +
                            &beta_f16,  (      char *)     dst_f16, CUDA_R_16F, ne01,                dst->nb[2]/sizeof(float), // strideC
         | 
| 7146 | 
            +
                            ne12*ne13,
         | 
| 7147 | 
            +
                            CUBLAS_COMPUTE_16F,
         | 
| 7148 | 
            +
                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
         | 
| 7149 | 
            +
                } else {
         | 
| 7150 | 
            +
                    // use cublasGemmBatchedEx
         | 
| 7151 | 
            +
                    // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
         | 
| 7152 | 
            +
                    const int ne23 = ne12*ne13;
         | 
| 7153 | 
            +
             | 
| 7154 | 
            +
                    // TODO: avoid this alloc
         | 
| 7155 | 
            +
                    void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
         | 
| 7156 | 
            +
             | 
| 7157 | 
            +
                    for (int i13 = 0; i13 < ne13; ++i13) {
         | 
| 7158 | 
            +
                        for (int i12 = 0; i12 < ne12; ++i12) {
         | 
| 7159 | 
            +
                            int i03 = i13 / r3;
         | 
| 7160 | 
            +
                            int i02 = i12 / r2;
         | 
| 7161 | 
            +
             | 
| 7162 | 
            +
                            ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
         | 
| 7163 | 
            +
                            ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
         | 
| 7164 | 
            +
                            ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
         | 
| 7165 | 
            +
                        }
         | 
| 7166 | 
            +
                    }
         | 
| 7167 | 
            +
             | 
| 7168 | 
            +
                    // allocate device memory for pointers
         | 
| 7169 | 
            +
                    void ** ptrs_as = nullptr;
         | 
| 7170 | 
            +
                    CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
         | 
| 7171 | 
            +
             | 
| 7172 | 
            +
                    // TODO: this does not work for some reason -- not sure why?
         | 
| 7173 | 
            +
                    //size_t ptrs_s = 0;
         | 
| 7174 | 
            +
                    //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
         | 
| 7175 | 
            +
             | 
| 7176 | 
            +
                    // copy pointers to device
         | 
| 7177 | 
            +
                    CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
         | 
| 7178 | 
            +
             | 
| 7179 | 
            +
                    free(ptrs);
         | 
| 7180 | 
            +
             | 
| 7181 | 
            +
                    CUBLAS_CHECK(
         | 
| 7182 | 
            +
                    cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
         | 
| 7183 | 
            +
                            ne01, ne11, ne10,
         | 
| 7184 | 
            +
                            &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
         | 
| 7185 | 
            +
                                        (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
         | 
| 7186 | 
            +
                            &beta_f16,  (      void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
         | 
| 7187 | 
            +
                            ne23,
         | 
| 7188 | 
            +
                            CUBLAS_COMPUTE_16F,
         | 
| 7189 | 
            +
                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
         | 
| 7190 | 
            +
             | 
| 7191 | 
            +
                    // free device memory for pointers
         | 
| 7192 | 
            +
                    CUDA_CHECK(cudaFree(ptrs_as));
         | 
| 7193 | 
            +
                    //ggml_cuda_pool_free(ptrs_as, ptrs_s);
         | 
| 7194 | 
            +
                }
         | 
| 7195 | 
            +
            #endif
         | 
| 7196 | 
            +
             | 
| 7197 | 
            +
                const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         | 
| 7198 | 
            +
                to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
         | 
| 7199 | 
            +
             | 
| 7200 | 
            +
                ggml_cuda_pool_free(src1_as_f16, src1_as);
         | 
| 7201 | 
            +
                ggml_cuda_pool_free(dst_f16, dst_as);
         | 
| 7202 | 
            +
            }
         | 
| 7203 | 
            +
             | 
| 7049 7204 | 
             
            static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
         | 
| 7050 7205 | 
             
                bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         | 
| 7051 7206 | 
             
                    src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
         | 
| @@ -7058,10 +7213,23 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 | |
| 7058 7213 | 
             
                    }
         | 
| 7059 7214 | 
             
                }
         | 
| 7060 7215 |  | 
| 7216 | 
            +
                // debug helpers
         | 
| 7217 | 
            +
                //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
         | 
| 7218 | 
            +
                //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
         | 
| 7219 | 
            +
                //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
         | 
| 7220 | 
            +
                //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
         | 
| 7221 | 
            +
                //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
         | 
| 7222 | 
            +
                //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
         | 
| 7223 | 
            +
             | 
| 7061 7224 | 
             
                if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         | 
| 7225 | 
            +
                    // KQ single-batch
         | 
| 7062 7226 | 
             
                    ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
         | 
| 7063 | 
            -
                } else if (all_on_device && !ggml_is_contiguous(src0) &&  | 
| 7227 | 
            +
                } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         | 
| 7228 | 
            +
                    // KQV single-batch
         | 
| 7064 7229 | 
             
                    ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
         | 
| 7230 | 
            +
                } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         | 
| 7231 | 
            +
                    // KQ + KQV multi-batch
         | 
| 7232 | 
            +
                    ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
         | 
| 7065 7233 | 
             
                } else if (src0->type == GGML_TYPE_F32) {
         | 
| 7066 7234 | 
             
                    ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
         | 
| 7067 7235 | 
             
                } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
         | 
| @@ -62,6 +62,7 @@ struct ggml_metal_context { | |
| 62 62 | 
             
                GGML_METAL_DECL_KERNEL(mul);
         | 
| 63 63 | 
             
                GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
         | 
| 64 64 | 
             
                GGML_METAL_DECL_KERNEL(scale);
         | 
| 65 | 
            +
                GGML_METAL_DECL_KERNEL(scale_4);
         | 
| 65 66 | 
             
                GGML_METAL_DECL_KERNEL(silu);
         | 
| 66 67 | 
             
                GGML_METAL_DECL_KERNEL(relu);
         | 
| 67 68 | 
             
                GGML_METAL_DECL_KERNEL(gelu);
         | 
| @@ -73,6 +74,8 @@ struct ggml_metal_context { | |
| 73 74 | 
             
                GGML_METAL_DECL_KERNEL(get_rows_f16);
         | 
| 74 75 | 
             
                GGML_METAL_DECL_KERNEL(get_rows_q4_0);
         | 
| 75 76 | 
             
                GGML_METAL_DECL_KERNEL(get_rows_q4_1);
         | 
| 77 | 
            +
                GGML_METAL_DECL_KERNEL(get_rows_q5_0);
         | 
| 78 | 
            +
                GGML_METAL_DECL_KERNEL(get_rows_q5_1);
         | 
| 76 79 | 
             
                GGML_METAL_DECL_KERNEL(get_rows_q8_0);
         | 
| 77 80 | 
             
                GGML_METAL_DECL_KERNEL(get_rows_q2_K);
         | 
| 78 81 | 
             
                GGML_METAL_DECL_KERNEL(get_rows_q3_K);
         | 
| @@ -87,6 +90,8 @@ struct ggml_metal_context { | |
| 87 90 | 
             
                GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
         | 
| 88 91 | 
             
                GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
         | 
| 89 92 | 
             
                GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
         | 
| 93 | 
            +
                GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
         | 
| 94 | 
            +
                GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
         | 
| 90 95 | 
             
                GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
         | 
| 91 96 | 
             
                GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
         | 
| 92 97 | 
             
                GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
         | 
| @@ -97,6 +102,8 @@ struct ggml_metal_context { | |
| 97 102 | 
             
                GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
         | 
| 98 103 | 
             
                GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
         | 
| 99 104 | 
             
                GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
         | 
| 105 | 
            +
                GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
         | 
| 106 | 
            +
                GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
         | 
| 100 107 | 
             
                GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
         | 
| 101 108 | 
             
                GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
         | 
| 102 109 | 
             
                GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
         | 
| @@ -243,6 +250,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 243 250 | 
             
                    GGML_METAL_ADD_KERNEL(mul);
         | 
| 244 251 | 
             
                    GGML_METAL_ADD_KERNEL(mul_row);
         | 
| 245 252 | 
             
                    GGML_METAL_ADD_KERNEL(scale);
         | 
| 253 | 
            +
                    GGML_METAL_ADD_KERNEL(scale_4);
         | 
| 246 254 | 
             
                    GGML_METAL_ADD_KERNEL(silu);
         | 
| 247 255 | 
             
                    GGML_METAL_ADD_KERNEL(relu);
         | 
| 248 256 | 
             
                    GGML_METAL_ADD_KERNEL(gelu);
         | 
| @@ -254,6 +262,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 254 262 | 
             
                    GGML_METAL_ADD_KERNEL(get_rows_f16);
         | 
| 255 263 | 
             
                    GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         | 
| 256 264 | 
             
                    GGML_METAL_ADD_KERNEL(get_rows_q4_1);
         | 
| 265 | 
            +
                    GGML_METAL_ADD_KERNEL(get_rows_q5_0);
         | 
| 266 | 
            +
                    GGML_METAL_ADD_KERNEL(get_rows_q5_1);
         | 
| 257 267 | 
             
                    GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         | 
| 258 268 | 
             
                    GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         | 
| 259 269 | 
             
                    GGML_METAL_ADD_KERNEL(get_rows_q3_K);
         | 
| @@ -268,6 +278,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 268 278 | 
             
                    GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
         | 
| 269 279 | 
             
                    GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
         | 
| 270 280 | 
             
                    GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
         | 
| 281 | 
            +
                    GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
         | 
| 282 | 
            +
                    GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
         | 
| 271 283 | 
             
                    GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
         | 
| 272 284 | 
             
                    GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
         | 
| 273 285 | 
             
                    GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
         | 
| @@ -278,8 +290,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 278 290 | 
             
                        GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
         | 
| 279 291 | 
             
                        GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
         | 
| 280 292 | 
             
                        GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
         | 
| 281 | 
            -
                        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
         | 
| 282 293 | 
             
                        GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
         | 
| 294 | 
            +
                        GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
         | 
| 295 | 
            +
                        GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
         | 
| 296 | 
            +
                        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
         | 
| 283 297 | 
             
                        GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
         | 
| 284 298 | 
             
                        GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
         | 
| 285 299 | 
             
                        GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
         | 
| @@ -335,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { | |
| 335 349 | 
             
                GGML_METAL_DEL_KERNEL(mul);
         | 
| 336 350 | 
             
                GGML_METAL_DEL_KERNEL(mul_row);
         | 
| 337 351 | 
             
                GGML_METAL_DEL_KERNEL(scale);
         | 
| 352 | 
            +
                GGML_METAL_DEL_KERNEL(scale_4);
         | 
| 338 353 | 
             
                GGML_METAL_DEL_KERNEL(silu);
         | 
| 339 354 | 
             
                GGML_METAL_DEL_KERNEL(relu);
         | 
| 340 355 | 
             
                GGML_METAL_DEL_KERNEL(gelu);
         | 
| @@ -346,6 +361,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { | |
| 346 361 | 
             
                GGML_METAL_DEL_KERNEL(get_rows_f16);
         | 
| 347 362 | 
             
                GGML_METAL_DEL_KERNEL(get_rows_q4_0);
         | 
| 348 363 | 
             
                GGML_METAL_DEL_KERNEL(get_rows_q4_1);
         | 
| 364 | 
            +
                GGML_METAL_DEL_KERNEL(get_rows_q5_0);
         | 
| 365 | 
            +
                GGML_METAL_DEL_KERNEL(get_rows_q5_1);
         | 
| 349 366 | 
             
                GGML_METAL_DEL_KERNEL(get_rows_q8_0);
         | 
| 350 367 | 
             
                GGML_METAL_DEL_KERNEL(get_rows_q2_K);
         | 
| 351 368 | 
             
                GGML_METAL_DEL_KERNEL(get_rows_q3_K);
         | 
| @@ -360,6 +377,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { | |
| 360 377 | 
             
                GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
         | 
| 361 378 | 
             
                GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
         | 
| 362 379 | 
             
                GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
         | 
| 380 | 
            +
                GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
         | 
| 381 | 
            +
                GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
         | 
| 363 382 | 
             
                GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
         | 
| 364 383 | 
             
                GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
         | 
| 365 384 | 
             
                GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
         | 
| @@ -370,8 +389,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { | |
| 370 389 | 
             
                    GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
         | 
| 371 390 | 
             
                    GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
         | 
| 372 391 | 
             
                    GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
         | 
| 373 | 
            -
                    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
         | 
| 374 392 | 
             
                    GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
         | 
| 393 | 
            +
                    GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
         | 
| 394 | 
            +
                    GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
         | 
| 395 | 
            +
                    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
         | 
| 375 396 | 
             
                    GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
         | 
| 376 397 | 
             
                    GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
         | 
| 377 398 | 
             
                    GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
         | 
| @@ -905,15 +926,20 @@ void ggml_metal_graph_compute( | |
| 905 926 |  | 
| 906 927 | 
             
                                        const float scale = *(const float *) src1->data;
         | 
| 907 928 |  | 
| 908 | 
            -
                                         | 
| 929 | 
            +
                                        int64_t n = ggml_nelements(dst);
         | 
| 930 | 
            +
             | 
| 931 | 
            +
                                        if (n % 4 == 0) {
         | 
| 932 | 
            +
                                            n /= 4;
         | 
| 933 | 
            +
                                            [encoder setComputePipelineState:ctx->pipeline_scale_4];
         | 
| 934 | 
            +
                                        } else {
         | 
| 935 | 
            +
                                            [encoder setComputePipelineState:ctx->pipeline_scale];
         | 
| 936 | 
            +
                                        }
         | 
| 937 | 
            +
             | 
| 909 938 | 
             
                                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
         | 
| 910 939 | 
             
                                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
         | 
| 911 940 | 
             
                                        [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
         | 
| 912 941 |  | 
| 913 | 
            -
                                         | 
| 914 | 
            -
                                        GGML_ASSERT(n % 4 == 0);
         | 
| 915 | 
            -
             | 
| 916 | 
            -
                                        [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
         | 
| 942 | 
            +
                                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
         | 
| 917 943 | 
             
                                    } break;
         | 
| 918 944 | 
             
                                case GGML_OP_UNARY:
         | 
| 919 945 | 
             
                                    switch (ggml_get_unary_op(gf->nodes[i])) {
         | 
| @@ -1052,6 +1078,8 @@ void ggml_metal_graph_compute( | |
| 1052 1078 | 
             
                                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
         | 
| 1053 1079 | 
             
                                                case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
         | 
| 1054 1080 | 
             
                                                case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
         | 
| 1081 | 
            +
                                                case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
         | 
| 1082 | 
            +
                                                case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
         | 
| 1055 1083 | 
             
                                                case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
         | 
| 1056 1084 | 
             
                                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
         | 
| 1057 1085 | 
             
                                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
         | 
| @@ -1121,6 +1149,24 @@ void ggml_metal_graph_compute( | |
| 1121 1149 | 
             
                                                        nth1 = 8;
         | 
| 1122 1150 | 
             
                                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
         | 
| 1123 1151 | 
             
                                                    } break;
         | 
| 1152 | 
            +
                                                case GGML_TYPE_Q5_0:
         | 
| 1153 | 
            +
                                                    {
         | 
| 1154 | 
            +
                                                        GGML_ASSERT(ne02 == 1);
         | 
| 1155 | 
            +
                                                        GGML_ASSERT(ne12 == 1);
         | 
| 1156 | 
            +
             | 
| 1157 | 
            +
                                                        nth0 = 8;
         | 
| 1158 | 
            +
                                                        nth1 = 8;
         | 
| 1159 | 
            +
                                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
         | 
| 1160 | 
            +
                                                    } break;
         | 
| 1161 | 
            +
                                                case GGML_TYPE_Q5_1:
         | 
| 1162 | 
            +
                                                    {
         | 
| 1163 | 
            +
                                                        GGML_ASSERT(ne02 == 1);
         | 
| 1164 | 
            +
                                                        GGML_ASSERT(ne12 == 1);
         | 
| 1165 | 
            +
             | 
| 1166 | 
            +
                                                        nth0 = 8;
         | 
| 1167 | 
            +
                                                        nth1 = 8;
         | 
| 1168 | 
            +
                                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
         | 
| 1169 | 
            +
                                                    } break;
         | 
| 1124 1170 | 
             
                                                case GGML_TYPE_Q8_0:
         | 
| 1125 1171 | 
             
                                                    {
         | 
| 1126 1172 | 
             
                                                        GGML_ASSERT(ne02 == 1);
         | 
| @@ -1201,7 +1247,8 @@ void ggml_metal_graph_compute( | |
| 1201 1247 | 
             
                                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
         | 
| 1202 1248 | 
             
                                            [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
         | 
| 1203 1249 |  | 
| 1204 | 
            -
                                            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || | 
| 1250 | 
            +
                                            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
         | 
| 1251 | 
            +
                                                src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
         | 
| 1205 1252 | 
             
                                                src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
         | 
| 1206 1253 | 
             
                                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         | 
| 1207 1254 | 
             
                                            }
         | 
| @@ -1233,6 +1280,8 @@ void ggml_metal_graph_compute( | |
| 1233 1280 | 
             
                                            case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
         | 
| 1234 1281 | 
             
                                            case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
         | 
| 1235 1282 | 
             
                                            case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
         | 
| 1283 | 
            +
                                            case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
         | 
| 1284 | 
            +
                                            case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
         | 
| 1236 1285 | 
             
                                            case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
         | 
| 1237 1286 | 
             
                                            case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
         | 
| 1238 1287 | 
             
                                            case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
         |