llama_cpp 0.3.7 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/CHANGELOG.md +17 -0
 - data/README.md +1 -1
 - data/examples/chat.rb +2 -4
 - data/ext/llama_cpp/extconf.rb +3 -3
 - data/ext/llama_cpp/llama_cpp.cpp +118 -117
 - data/ext/llama_cpp/src/ggml-alloc.c +97 -53
 - data/ext/llama_cpp/src/ggml-alloc.h +4 -0
 - data/ext/llama_cpp/src/ggml-cuda.cu +1010 -497
 - data/ext/llama_cpp/src/ggml-cuda.h +32 -23
 - data/ext/llama_cpp/src/ggml-metal.h +9 -3
 - data/ext/llama_cpp/src/ggml-metal.m +142 -161
 - data/ext/llama_cpp/src/ggml-metal.metal +577 -500
 - data/ext/llama_cpp/src/ggml.c +2064 -233
 - data/ext/llama_cpp/src/ggml.h +238 -13
 - data/ext/llama_cpp/src/k_quants.c +110 -54
 - data/ext/llama_cpp/src/llama-util.h +10 -8
 - data/ext/llama_cpp/src/llama.cpp +4544 -2890
 - data/ext/llama_cpp/src/llama.h +133 -123
 - data/lib/llama_cpp/version.rb +2 -2
 - data/lib/llama_cpp.rb +1 -1
 - data/sig/llama_cpp.rbs +8 -8
 - metadata +2 -2
 
| 
         @@ -5,7 +5,6 @@ 
     | 
|
| 
       5 
5 
     | 
    
         
             
            #import <Foundation/Foundation.h>
         
     | 
| 
       6 
6 
     | 
    
         | 
| 
       7 
7 
     | 
    
         
             
            #import <Metal/Metal.h>
         
     | 
| 
       8 
     | 
    
         
            -
            #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
         
     | 
| 
       9 
8 
     | 
    
         | 
| 
       10 
9 
     | 
    
         
             
            #undef MIN
         
     | 
| 
       11 
10 
     | 
    
         
             
            #undef MAX
         
     | 
| 
         @@ -64,6 +63,7 @@ struct ggml_metal_context { 
     | 
|
| 
       64 
63 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(get_rows_f16);
         
     | 
| 
       65 
64 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(get_rows_q4_0);
         
     | 
| 
       66 
65 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(get_rows_q4_1);
         
     | 
| 
      
 66 
     | 
    
         
            +
                GGML_METAL_DECL_KERNEL(get_rows_q8_0);
         
     | 
| 
       67 
67 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(get_rows_q2_K);
         
     | 
| 
       68 
68 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(get_rows_q3_K);
         
     | 
| 
       69 
69 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(get_rows_q4_K);
         
     | 
| 
         @@ -74,11 +74,21 @@ struct ggml_metal_context { 
     | 
|
| 
       74 
74 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
         
     | 
| 
       75 
75 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
         
     | 
| 
       76 
76 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
         
     | 
| 
      
 77 
     | 
    
         
            +
                GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
         
     | 
| 
       77 
78 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
         
     | 
| 
       78 
79 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
         
     | 
| 
       79 
80 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
         
     | 
| 
       80 
81 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
         
     | 
| 
       81 
82 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
         
     | 
| 
      
 83 
     | 
    
         
            +
                GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
         
     | 
| 
      
 84 
     | 
    
         
            +
                GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
         
     | 
| 
      
 85 
     | 
    
         
            +
                GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
         
     | 
| 
      
 86 
     | 
    
         
            +
                GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
         
     | 
| 
      
 87 
     | 
    
         
            +
                GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
         
     | 
| 
      
 88 
     | 
    
         
            +
                GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
         
     | 
| 
      
 89 
     | 
    
         
            +
                GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
         
     | 
| 
      
 90 
     | 
    
         
            +
                GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
         
     | 
| 
      
 91 
     | 
    
         
            +
                GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
         
     | 
| 
       82 
92 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(rope);
         
     | 
| 
       83 
93 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(alibi_f32);
         
     | 
| 
       84 
94 
     | 
    
         
             
                GGML_METAL_DECL_KERNEL(cpy_f32_f16);
         
     | 
| 
         @@ -110,13 +120,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { 
     | 
|
| 
       110 
120 
     | 
    
         
             
                ctx->n_buffers = 0;
         
     | 
| 
       111 
121 
     | 
    
         
             
                ctx->concur_list_len = 0;
         
     | 
| 
       112 
122 
     | 
    
         | 
| 
       113 
     | 
    
         
            -
                // determine if we can use MPS
         
     | 
| 
       114 
     | 
    
         
            -
                if (MPSSupportsMTLDevice(ctx->device)) {
         
     | 
| 
       115 
     | 
    
         
            -
                    fprintf(stderr, "%s: using MPS\n", __func__);
         
     | 
| 
       116 
     | 
    
         
            -
                } else {
         
     | 
| 
       117 
     | 
    
         
            -
                    fprintf(stderr, "%s: not using MPS\n", __func__);
         
     | 
| 
       118 
     | 
    
         
            -
                    GGML_ASSERT(false && "MPS not supported");
         
     | 
| 
       119 
     | 
    
         
            -
                }
         
     | 
| 
       120 
123 
     | 
    
         | 
| 
       121 
124 
     | 
    
         
             
            #if 0
         
     | 
| 
       122 
125 
     | 
    
         
             
                // compile from source string and show compile log
         
     | 
| 
         @@ -126,7 +129,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { 
     | 
|
| 
       126 
129 
     | 
    
         
             
                    ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
         
     | 
| 
       127 
130 
     | 
    
         
             
                    if (error) {
         
     | 
| 
       128 
131 
     | 
    
         
             
                        fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
         
     | 
| 
       129 
     | 
    
         
            -
                         
     | 
| 
      
 132 
     | 
    
         
            +
                        return NULL;
         
     | 
| 
       130 
133 
     | 
    
         
             
                    }
         
     | 
| 
       131 
134 
     | 
    
         
             
                }
         
     | 
| 
       132 
135 
     | 
    
         
             
            #else
         
     | 
| 
         @@ -144,7 +147,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { 
     | 
|
| 
       144 
147 
     | 
    
         
             
                    NSString * src  = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         
     | 
| 
       145 
148 
     | 
    
         
             
                    if (error) {
         
     | 
| 
       146 
149 
     | 
    
         
             
                        fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
         
     | 
| 
       147 
     | 
    
         
            -
                         
     | 
| 
      
 150 
     | 
    
         
            +
                        return NULL;
         
     | 
| 
       148 
151 
     | 
    
         
             
                    }
         
     | 
| 
       149 
152 
     | 
    
         | 
| 
       150 
153 
     | 
    
         
             
            #ifdef GGML_QKK_64
         
     | 
| 
         @@ -156,17 +159,24 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { 
     | 
|
| 
       156 
159 
     | 
    
         
             
            #endif
         
     | 
| 
       157 
160 
     | 
    
         
             
                    if (error) {
         
     | 
| 
       158 
161 
     | 
    
         
             
                        fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
         
     | 
| 
       159 
     | 
    
         
            -
                         
     | 
| 
      
 162 
     | 
    
         
            +
                        return NULL;
         
     | 
| 
       160 
163 
     | 
    
         
             
                    }
         
     | 
| 
       161 
164 
     | 
    
         
             
                }
         
     | 
| 
       162 
165 
     | 
    
         
             
            #endif
         
     | 
| 
       163 
166 
     | 
    
         | 
| 
       164 
167 
     | 
    
         
             
                // load kernels
         
     | 
| 
       165 
168 
     | 
    
         
             
                {
         
     | 
| 
      
 169 
     | 
    
         
            +
                    NSError * error = nil;
         
     | 
| 
       166 
170 
     | 
    
         
             
            #define GGML_METAL_ADD_KERNEL(name) \
         
     | 
| 
       167 
171 
     | 
    
         
             
                    ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         
     | 
| 
       168 
     | 
    
         
            -
                    ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error 
     | 
| 
       169 
     | 
    
         
            -
                    fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name 
     | 
| 
      
 172 
     | 
    
         
            +
                    ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
         
     | 
| 
      
 173 
     | 
    
         
            +
                    fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
         
     | 
| 
      
 174 
     | 
    
         
            +
                            (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
         
     | 
| 
      
 175 
     | 
    
         
            +
                            (int) ctx->pipeline_##name.threadExecutionWidth); \
         
     | 
| 
      
 176 
     | 
    
         
            +
                    if (error) { \
         
     | 
| 
      
 177 
     | 
    
         
            +
                        fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
         
     | 
| 
      
 178 
     | 
    
         
            +
                        return NULL; \
         
     | 
| 
      
 179 
     | 
    
         
            +
                    }
         
     | 
| 
       170 
180 
     | 
    
         | 
| 
       171 
181 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(add);
         
     | 
| 
       172 
182 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(add_row);
         
     | 
| 
         @@ -181,6 +191,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { 
     | 
|
| 
       181 
191 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(get_rows_f16);
         
     | 
| 
       182 
192 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         
     | 
| 
       183 
193 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(get_rows_q4_1);
         
     | 
| 
      
 194 
     | 
    
         
            +
                    GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         
     | 
| 
       184 
195 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         
     | 
| 
       185 
196 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(get_rows_q3_K);
         
     | 
| 
       186 
197 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(get_rows_q4_K);
         
     | 
| 
         @@ -191,11 +202,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { 
     | 
|
| 
       191 
202 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         
     | 
| 
       192 
203 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         
     | 
| 
       193 
204 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         
     | 
| 
      
 205 
     | 
    
         
            +
                    GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
         
     | 
| 
       194 
206 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
         
     | 
| 
       195 
207 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
         
     | 
| 
       196 
208 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
         
     | 
| 
       197 
209 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
         
     | 
| 
       198 
210 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
         
     | 
| 
      
 211 
     | 
    
         
            +
                    GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
         
     | 
| 
      
 212 
     | 
    
         
            +
                    GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
         
     | 
| 
      
 213 
     | 
    
         
            +
                    GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
         
     | 
| 
      
 214 
     | 
    
         
            +
                    GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
         
     | 
| 
      
 215 
     | 
    
         
            +
                    GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
         
     | 
| 
      
 216 
     | 
    
         
            +
                    GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
         
     | 
| 
      
 217 
     | 
    
         
            +
                    GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
         
     | 
| 
      
 218 
     | 
    
         
            +
                    GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
         
     | 
| 
      
 219 
     | 
    
         
            +
                    GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
         
     | 
| 
       199 
220 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(rope);
         
     | 
| 
       200 
221 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(alibi_f32);
         
     | 
| 
       201 
222 
     | 
    
         
             
                    GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         
     | 
| 
         @@ -205,12 +226,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { 
     | 
|
| 
       205 
226 
     | 
    
         
             
            #undef GGML_METAL_ADD_KERNEL
         
     | 
| 
       206 
227 
     | 
    
         
             
                }
         
     | 
| 
       207 
228 
     | 
    
         | 
| 
       208 
     | 
    
         
            -
                fprintf(stderr, "%s: recommendedMaxWorkingSetSize 
     | 
| 
       209 
     | 
    
         
            -
                fprintf(stderr, "%s: hasUnifiedMemory 
     | 
| 
      
 229 
     | 
    
         
            +
                fprintf(stderr, "%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
         
     | 
| 
      
 230 
     | 
    
         
            +
                fprintf(stderr, "%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
         
     | 
| 
       210 
231 
     | 
    
         
             
                if (ctx->device.maxTransferRate != 0) {
         
     | 
| 
       211 
     | 
    
         
            -
                    fprintf(stderr, "%s: maxTransferRate 
     | 
| 
      
 232 
     | 
    
         
            +
                    fprintf(stderr, "%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
         
     | 
| 
       212 
233 
     | 
    
         
             
                } else {
         
     | 
| 
       213 
     | 
    
         
            -
                    fprintf(stderr, "%s: maxTransferRate 
     | 
| 
      
 234 
     | 
    
         
            +
                    fprintf(stderr, "%s: maxTransferRate               = built-in GPU\n", __func__);
         
     | 
| 
       214 
235 
     | 
    
         
             
                }
         
     | 
| 
       215 
236 
     | 
    
         | 
| 
       216 
237 
     | 
    
         
             
                return ctx;
         
     | 
| 
         @@ -224,15 +245,31 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { 
     | 
|
| 
       224 
245 
     | 
    
         
             
                free(ctx);
         
     | 
| 
       225 
246 
     | 
    
         
             
            }
         
     | 
| 
       226 
247 
     | 
    
         | 
| 
      
 248 
     | 
    
         
            +
            void * ggml_metal_host_malloc(size_t n) {
         
     | 
| 
      
 249 
     | 
    
         
            +
                void * data = NULL;
         
     | 
| 
      
 250 
     | 
    
         
            +
                const int result = posix_memalign((void **) &data, getpagesize(), n);
         
     | 
| 
      
 251 
     | 
    
         
            +
                if (result != 0) {
         
     | 
| 
      
 252 
     | 
    
         
            +
                    fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
         
     | 
| 
      
 253 
     | 
    
         
            +
                    return NULL;
         
     | 
| 
      
 254 
     | 
    
         
            +
                }
         
     | 
| 
      
 255 
     | 
    
         
            +
             
     | 
| 
      
 256 
     | 
    
         
            +
                return data;
         
     | 
| 
      
 257 
     | 
    
         
            +
            }
         
     | 
| 
      
 258 
     | 
    
         
            +
             
     | 
| 
      
 259 
     | 
    
         
            +
            void ggml_metal_host_free(void * data) {
         
     | 
| 
      
 260 
     | 
    
         
            +
                free(data);
         
     | 
| 
      
 261 
     | 
    
         
            +
            }
         
     | 
| 
      
 262 
     | 
    
         
            +
             
     | 
| 
       227 
263 
     | 
    
         
             
            void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
         
     | 
| 
       228 
264 
     | 
    
         
             
                ctx->n_cb = n_cb;
         
     | 
| 
       229 
265 
     | 
    
         
             
            }
         
     | 
| 
       230 
266 
     | 
    
         | 
| 
       231 
     | 
    
         
            -
             
     | 
| 
       232 
     | 
    
         
            -
                 
     | 
| 
       233 
     | 
    
         
            -
             
     | 
| 
       234 
     | 
    
         
            -
             
     | 
| 
       235 
     | 
    
         
            -
             
     | 
| 
      
 267 
     | 
    
         
            +
            int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
         
     | 
| 
      
 268 
     | 
    
         
            +
                return ctx->concur_list_len;
         
     | 
| 
      
 269 
     | 
    
         
            +
            }
         
     | 
| 
      
 270 
     | 
    
         
            +
             
     | 
| 
      
 271 
     | 
    
         
            +
            int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
         
     | 
| 
      
 272 
     | 
    
         
            +
                return ctx->concur_list;
         
     | 
| 
       236 
273 
     | 
    
         
             
            }
         
     | 
| 
       237 
274 
     | 
    
         | 
| 
       238 
275 
     | 
    
         
             
            // finds the Metal buffer that contains the tensor data on the GPU device
         
     | 
| 
         @@ -375,7 +412,7 @@ void ggml_metal_get_tensor( 
     | 
|
| 
       375 
412 
     | 
    
         | 
| 
       376 
413 
     | 
    
         
             
            void ggml_metal_graph_find_concurrency(
         
     | 
| 
       377 
414 
     | 
    
         
             
                    struct ggml_metal_context * ctx,
         
     | 
| 
       378 
     | 
    
         
            -
                    struct ggml_cgraph * gf) {
         
     | 
| 
      
 415 
     | 
    
         
            +
                    struct ggml_cgraph * gf, bool check_mem) {
         
     | 
| 
       379 
416 
     | 
    
         
             
                int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
         
     | 
| 
       380 
417 
     | 
    
         
             
                int nodes_unused[GGML_MAX_CONCUR];
         
     | 
| 
       381 
418 
     | 
    
         | 
| 
         @@ -422,7 +459,7 @@ void ggml_metal_graph_find_concurrency( 
     | 
|
| 
       422 
459 
     | 
    
         
             
                                    }
         
     | 
| 
       423 
460 
     | 
    
         
             
                                }
         
     | 
| 
       424 
461 
     | 
    
         
             
                            }
         
     | 
| 
       425 
     | 
    
         
            -
                            if (exe_flag) {
         
     | 
| 
      
 462 
     | 
    
         
            +
                            if (exe_flag && check_mem) {
         
     | 
| 
       426 
463 
     | 
    
         
             
                                // check if nodes[i]'s data will be overwritten by a node before nodes[i].
         
     | 
| 
       427 
464 
     | 
    
         
             
                                // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
         
     | 
| 
       428 
465 
     | 
    
         
             
                                int64_t data_start = (int64_t) gf->nodes[i]->data;
         
     | 
| 
         @@ -506,19 +543,15 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       506 
543 
     | 
    
         | 
| 
       507 
544 
     | 
    
         
             
                        id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
         
     | 
| 
       508 
545 
     | 
    
         | 
| 
       509 
     | 
    
         
            -
                        id<MTLComputeCommandEncoder> encoder =  
     | 
| 
      
 546 
     | 
    
         
            +
                        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       510 
547 
     | 
    
         | 
| 
       511 
     | 
    
         
            -
                        const int node_start = 
     | 
| 
       512 
     | 
    
         
            -
                        const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
         
     | 
| 
      
 548 
     | 
    
         
            +
                        const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
         
     | 
| 
      
 549 
     | 
    
         
            +
                        const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
         
     | 
| 
       513 
550 
     | 
    
         | 
| 
       514 
551 
     | 
    
         
             
                        for (int ind = node_start; ind < node_end; ++ind) {
         
     | 
| 
       515 
552 
     | 
    
         
             
                            const int i = has_concur ? ctx->concur_list[ind] : ind;
         
     | 
| 
       516 
553 
     | 
    
         | 
| 
       517 
554 
     | 
    
         
             
                            if (i == -1) {
         
     | 
| 
       518 
     | 
    
         
            -
                                if (encoder == nil) {
         
     | 
| 
       519 
     | 
    
         
            -
                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       520 
     | 
    
         
            -
                                    continue;
         
     | 
| 
       521 
     | 
    
         
            -
                                }
         
     | 
| 
       522 
555 
     | 
    
         
             
                                [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
         
     | 
| 
       523 
556 
     | 
    
         
             
                                continue;
         
     | 
| 
       524 
557 
     | 
    
         
             
                            }
         
     | 
| 
         @@ -592,10 +625,6 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       592 
625 
     | 
    
         
             
                                    } break;
         
     | 
| 
       593 
626 
     | 
    
         
             
                                case GGML_OP_ADD:
         
     | 
| 
       594 
627 
     | 
    
         
             
                                    {
         
     | 
| 
       595 
     | 
    
         
            -
                                        if (encoder == nil) {
         
     | 
| 
       596 
     | 
    
         
            -
                                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       597 
     | 
    
         
            -
                                        }
         
     | 
| 
       598 
     | 
    
         
            -
             
     | 
| 
       599 
628 
     | 
    
         
             
                                        if (ggml_nelements(src1) == ne10) {
         
     | 
| 
       600 
629 
     | 
    
         
             
                                            // src1 is a row
         
     | 
| 
       601 
630 
     | 
    
         
             
                                            [encoder setComputePipelineState:ctx->pipeline_add_row];
         
     | 
| 
         @@ -613,10 +642,6 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       613 
642 
     | 
    
         
             
                                    } break;
         
     | 
| 
       614 
643 
     | 
    
         
             
                                case GGML_OP_MUL:
         
     | 
| 
       615 
644 
     | 
    
         
             
                                    {
         
     | 
| 
       616 
     | 
    
         
            -
                                        if (encoder == nil) {
         
     | 
| 
       617 
     | 
    
         
            -
                                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       618 
     | 
    
         
            -
                                        }
         
     | 
| 
       619 
     | 
    
         
            -
             
     | 
| 
       620 
645 
     | 
    
         
             
                                        if (ggml_nelements(src1) == ne10) {
         
     | 
| 
       621 
646 
     | 
    
         
             
                                            // src1 is a row
         
     | 
| 
       622 
647 
     | 
    
         
             
                                            [encoder setComputePipelineState:ctx->pipeline_mul_row];
         
     | 
| 
         @@ -634,10 +659,6 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       634 
659 
     | 
    
         
             
                                    } break;
         
     | 
| 
       635 
660 
     | 
    
         
             
                                case GGML_OP_SCALE:
         
     | 
| 
       636 
661 
     | 
    
         
             
                                    {
         
     | 
| 
       637 
     | 
    
         
            -
                                        if (encoder == nil) {
         
     | 
| 
       638 
     | 
    
         
            -
                                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       639 
     | 
    
         
            -
                                        }
         
     | 
| 
       640 
     | 
    
         
            -
             
     | 
| 
       641 
662 
     | 
    
         
             
                                        const float scale = *(const float *) src1->data;
         
     | 
| 
       642 
663 
     | 
    
         | 
| 
       643 
664 
     | 
    
         
             
                                        [encoder setComputePipelineState:ctx->pipeline_scale];
         
     | 
| 
         @@ -653,10 +674,6 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       653 
674 
     | 
    
         
             
                                    switch (ggml_get_unary_op(gf->nodes[i])) {
         
     | 
| 
       654 
675 
     | 
    
         
             
                                        case GGML_UNARY_OP_SILU:
         
     | 
| 
       655 
676 
     | 
    
         
             
                                            {
         
     | 
| 
       656 
     | 
    
         
            -
                                                if (encoder == nil) {
         
     | 
| 
       657 
     | 
    
         
            -
                                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       658 
     | 
    
         
            -
                                                }
         
     | 
| 
       659 
     | 
    
         
            -
             
     | 
| 
       660 
677 
     | 
    
         
             
                                                [encoder setComputePipelineState:ctx->pipeline_silu];
         
     | 
| 
       661 
678 
     | 
    
         
             
                                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
         
     | 
| 
       662 
679 
     | 
    
         
             
                                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
         
     | 
| 
         @@ -667,10 +684,6 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       667 
684 
     | 
    
         
             
                                            } break;
         
     | 
| 
       668 
685 
     | 
    
         
             
                                        case GGML_UNARY_OP_RELU:
         
     | 
| 
       669 
686 
     | 
    
         
             
                                            {
         
     | 
| 
       670 
     | 
    
         
            -
                                                if (encoder == nil) {
         
     | 
| 
       671 
     | 
    
         
            -
                                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       672 
     | 
    
         
            -
                                                }
         
     | 
| 
       673 
     | 
    
         
            -
             
     | 
| 
       674 
687 
     | 
    
         
             
                                                [encoder setComputePipelineState:ctx->pipeline_relu];
         
     | 
| 
       675 
688 
     | 
    
         
             
                                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
         
     | 
| 
       676 
689 
     | 
    
         
             
                                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
         
     | 
| 
         @@ -681,10 +694,6 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       681 
694 
     | 
    
         
             
                                            } break;
         
     | 
| 
       682 
695 
     | 
    
         
             
                                        case GGML_UNARY_OP_GELU:
         
     | 
| 
       683 
696 
     | 
    
         
             
                                            {
         
     | 
| 
       684 
     | 
    
         
            -
                                                if (encoder == nil) {
         
     | 
| 
       685 
     | 
    
         
            -
                                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       686 
     | 
    
         
            -
                                                }
         
     | 
| 
       687 
     | 
    
         
            -
             
     | 
| 
       688 
697 
     | 
    
         
             
                                                [encoder setComputePipelineState:ctx->pipeline_gelu];
         
     | 
| 
       689 
698 
     | 
    
         
             
                                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
         
     | 
| 
       690 
699 
     | 
    
         
             
                                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
         
     | 
| 
         @@ -701,10 +710,6 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       701 
710 
     | 
    
         
             
                                    } break;
         
     | 
| 
       702 
711 
     | 
    
         
             
                                case GGML_OP_SOFT_MAX:
         
     | 
| 
       703 
712 
     | 
    
         
             
                                    {
         
     | 
| 
       704 
     | 
    
         
            -
                                        if (encoder == nil) {
         
     | 
| 
       705 
     | 
    
         
            -
                                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       706 
     | 
    
         
            -
                                        }
         
     | 
| 
       707 
     | 
    
         
            -
             
     | 
| 
       708 
713 
     | 
    
         
             
                                        const int nth = 32;
         
     | 
| 
       709 
714 
     | 
    
         | 
| 
       710 
715 
     | 
    
         
             
                                        [encoder setComputePipelineState:ctx->pipeline_soft_max];
         
     | 
| 
         @@ -719,10 +724,6 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       719 
724 
     | 
    
         
             
                                    } break;
         
     | 
| 
       720 
725 
     | 
    
         
             
                                case GGML_OP_DIAG_MASK_INF:
         
     | 
| 
       721 
726 
     | 
    
         
             
                                    {
         
     | 
| 
       722 
     | 
    
         
            -
                                        if (encoder == nil) {
         
     | 
| 
       723 
     | 
    
         
            -
                                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       724 
     | 
    
         
            -
                                        }
         
     | 
| 
       725 
     | 
    
         
            -
             
     | 
| 
       726 
727 
     | 
    
         
             
                                        const int n_past = ((int32_t *)(dst->op_params))[0];
         
     | 
| 
       727 
728 
     | 
    
         | 
| 
       728 
729 
     | 
    
         
             
                                        [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
         
     | 
| 
         @@ -740,53 +741,43 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       740 
741 
     | 
    
         | 
| 
       741 
742 
     | 
    
         
             
                                        GGML_ASSERT(ne00 == ne10);
         
     | 
| 
       742 
743 
     | 
    
         
             
                                        // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
         
     | 
| 
      
 744 
     | 
    
         
            +
                                        uint gqa = ne12/ne02;
         
     | 
| 
       743 
745 
     | 
    
         
             
                                        GGML_ASSERT(ne03 == ne13);
         
     | 
| 
       744 
746 
     | 
    
         | 
| 
      
 747 
     | 
    
         
            +
                                        // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
         
     | 
| 
      
 748 
     | 
    
         
            +
                                        // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
         
     | 
| 
       745 
749 
     | 
    
         
             
                                        if (ggml_is_contiguous(src0) &&
         
     | 
| 
       746 
750 
     | 
    
         
             
                                            ggml_is_contiguous(src1) &&
         
     | 
| 
       747 
     | 
    
         
            -
                                             
     | 
| 
       748 
     | 
    
         
            -
             
     | 
| 
       749 
     | 
    
         
            -
                                             
     | 
| 
       750 
     | 
    
         
            -
             
     | 
| 
       751 
     | 
    
         
            -
             
     | 
| 
       752 
     | 
    
         
            -
             
     | 
| 
       753 
     | 
    
         
            -
             
     | 
| 
       754 
     | 
    
         
            -
             
     | 
| 
       755 
     | 
    
         
            -
             
     | 
| 
       756 
     | 
    
         
            -
             
     | 
| 
       757 
     | 
    
         
            -
             
     | 
| 
       758 
     | 
    
         
            -
             
     | 
| 
       759 
     | 
    
         
            -
                                                 
     | 
| 
       760 
     | 
    
         
            -
             
     | 
| 
       761 
     | 
    
         
            -
             
     | 
| 
       762 
     | 
    
         
            -
                                                matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
         
     | 
| 
       763 
     | 
    
         
            -
             
     | 
| 
       764 
     | 
    
         
            -
                                            MPSMatrixDescriptor * desc  = [MPSMatrixDescriptor
         
     | 
| 
       765 
     | 
    
         
            -
                                                matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
         
     | 
| 
       766 
     | 
    
         
            -
             
     | 
| 
       767 
     | 
    
         
            -
                                            MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
         
     | 
| 
       768 
     | 
    
         
            -
                                                initWithDevice:ctx->device transposeLeft:false transposeRight:true
         
     | 
| 
       769 
     | 
    
         
            -
                                                    resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
         
     | 
| 
       770 
     | 
    
         
            -
             
     | 
| 
       771 
     | 
    
         
            -
                                            // we need to do ne12 multiplications
         
     | 
| 
       772 
     | 
    
         
            -
                                            // TODO: is there a way to do this in parallel - currently very slow ..
         
     | 
| 
       773 
     | 
    
         
            -
                                            // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
         
     | 
| 
       774 
     | 
    
         
            -
                                            for (int64_t i02 = 0; i02 < ne12; ++i02) {
         
     | 
| 
       775 
     | 
    
         
            -
                                                size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
         
     | 
| 
       776 
     | 
    
         
            -
                                                size_t offs_src1_cur = offs_src1 + i02*nb12;
         
     | 
| 
       777 
     | 
    
         
            -
                                                size_t offs_dst_cur  = offs_dst  + i02*nb2;
         
     | 
| 
       778 
     | 
    
         
            -
             
     | 
| 
       779 
     | 
    
         
            -
                                                MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
         
     | 
| 
       780 
     | 
    
         
            -
                                                MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
         
     | 
| 
       781 
     | 
    
         
            -
                                                MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst_cur  descriptor:desc ];
         
     | 
| 
       782 
     | 
    
         
            -
             
     | 
| 
       783 
     | 
    
         
            -
                                                [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
         
     | 
| 
      
 751 
     | 
    
         
            +
                                            src1t == GGML_TYPE_F32 &&
         
     | 
| 
      
 752 
     | 
    
         
            +
                                            [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
         
     | 
| 
      
 753 
     | 
    
         
            +
                                            ne00%32 == 0 &&
         
     | 
| 
      
 754 
     | 
    
         
            +
                                            ne11 > 1) {
         
     | 
| 
      
 755 
     | 
    
         
            +
                                            switch (src0->type) {
         
     | 
| 
      
 756 
     | 
    
         
            +
                                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
         
     | 
| 
      
 757 
     | 
    
         
            +
                                                case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
         
     | 
| 
      
 758 
     | 
    
         
            +
                                                case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
         
     | 
| 
      
 759 
     | 
    
         
            +
                                                case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
         
     | 
| 
      
 760 
     | 
    
         
            +
                                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
         
     | 
| 
      
 761 
     | 
    
         
            +
                                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
         
     | 
| 
      
 762 
     | 
    
         
            +
                                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
         
     | 
| 
      
 763 
     | 
    
         
            +
                                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
         
     | 
| 
      
 764 
     | 
    
         
            +
                                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
         
     | 
| 
      
 765 
     | 
    
         
            +
                                                default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
         
     | 
| 
       784 
766 
     | 
    
         
             
                                            }
         
     | 
| 
      
 767 
     | 
    
         
            +
                                            [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
         
     | 
| 
      
 768 
     | 
    
         
            +
                                            [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
         
     | 
| 
      
 769 
     | 
    
         
            +
                                            [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
         
     | 
| 
      
 770 
     | 
    
         
            +
                                            [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
         
     | 
| 
      
 771 
     | 
    
         
            +
                                            [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4];
         
     | 
| 
      
 772 
     | 
    
         
            +
                                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5];
         
     | 
| 
      
 773 
     | 
    
         
            +
                                            [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6];
         
     | 
| 
      
 774 
     | 
    
         
            +
                                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
         
     | 
| 
      
 775 
     | 
    
         
            +
                                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:8];
         
     | 
| 
      
 776 
     | 
    
         
            +
                                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:9];
         
     | 
| 
      
 777 
     | 
    
         
            +
                                            [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:10];
         
     | 
| 
      
 778 
     | 
    
         
            +
                                            [encoder setThreadgroupMemoryLength:8192 atIndex:0];
         
     | 
| 
      
 779 
     | 
    
         
            +
                                            [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
         
     | 
| 
       785 
780 
     | 
    
         
             
                                        } else {
         
     | 
| 
       786 
     | 
    
         
            -
                                            if (encoder == nil) {
         
     | 
| 
       787 
     | 
    
         
            -
                                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       788 
     | 
    
         
            -
                                            }
         
     | 
| 
       789 
     | 
    
         
            -
             
     | 
| 
       790 
781 
     | 
    
         
             
                                            int nth0 = 32;
         
     | 
| 
       791 
782 
     | 
    
         
             
                                            int nth1 = 1;
         
     | 
| 
       792 
783 
     | 
    
         | 
| 
         @@ -816,6 +807,15 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       816 
807 
     | 
    
         
             
                                                        nth1 = 8;
         
     | 
| 
       817 
808 
     | 
    
         
             
                                                        [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
         
     | 
| 
       818 
809 
     | 
    
         
             
                                                    } break;
         
     | 
| 
      
 810 
     | 
    
         
            +
                                                case GGML_TYPE_Q8_0:
         
     | 
| 
      
 811 
     | 
    
         
            +
                                                    {
         
     | 
| 
      
 812 
     | 
    
         
            +
                                                        GGML_ASSERT(ne02 == 1);
         
     | 
| 
      
 813 
     | 
    
         
            +
                                                        GGML_ASSERT(ne12 == 1);
         
     | 
| 
      
 814 
     | 
    
         
            +
             
     | 
| 
      
 815 
     | 
    
         
            +
                                                        nth0 = 8;
         
     | 
| 
      
 816 
     | 
    
         
            +
                                                        nth1 = 8;
         
     | 
| 
      
 817 
     | 
    
         
            +
                                                        [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
         
     | 
| 
      
 818 
     | 
    
         
            +
                                                    } break;
         
     | 
| 
       819 
819 
     | 
    
         
             
                                                case GGML_TYPE_Q2_K:
         
     | 
| 
       820 
820 
     | 
    
         
             
                                                    {
         
     | 
| 
       821 
821 
     | 
    
         
             
                                                        GGML_ASSERT(ne02 == 1);
         
     | 
| 
         @@ -885,23 +885,24 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       885 
885 
     | 
    
         
             
                                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
         
     | 
| 
       886 
886 
     | 
    
         
             
                                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
         
     | 
| 
       887 
887 
     | 
    
         
             
                                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
         
     | 
| 
      
 888 
     | 
    
         
            +
                                            [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
         
     | 
| 
       888 
889 
     | 
    
         | 
| 
       889 
     | 
    
         
            -
                                            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
         
     | 
| 
      
 890 
     | 
    
         
            +
                                            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
         
     | 
| 
       890 
891 
     | 
    
         
             
                                                src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
         
     | 
| 
       891 
     | 
    
         
            -
                                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) 
     | 
| 
      
 892 
     | 
    
         
            +
                                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         
     | 
| 
       892 
893 
     | 
    
         
             
                                            }
         
     | 
| 
       893 
894 
     | 
    
         
             
                                            else if (src0t == GGML_TYPE_Q3_K) {
         
     | 
| 
       894 
895 
     | 
    
         
             
            #ifdef GGML_QKK_64
         
     | 
| 
       895 
     | 
    
         
            -
                                                [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11,  
     | 
| 
      
 896 
     | 
    
         
            +
                                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         
     | 
| 
       896 
897 
     | 
    
         
             
            #else
         
     | 
| 
       897 
     | 
    
         
            -
                                                [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11,  
     | 
| 
      
 898 
     | 
    
         
            +
                                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         
     | 
| 
       898 
899 
     | 
    
         
             
            #endif
         
     | 
| 
       899 
900 
     | 
    
         
             
                                            }
         
     | 
| 
       900 
901 
     | 
    
         
             
                                            else if (src0t == GGML_TYPE_Q5_K) {
         
     | 
| 
       901 
     | 
    
         
            -
                                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) 
     | 
| 
      
 902 
     | 
    
         
            +
                                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         
     | 
| 
       902 
903 
     | 
    
         
             
                                            }
         
     | 
| 
       903 
904 
     | 
    
         
             
                                            else if (src0t == GGML_TYPE_Q6_K) {
         
     | 
| 
       904 
     | 
    
         
            -
                                                [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11,  
     | 
| 
      
 905 
     | 
    
         
            +
                                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         
     | 
| 
       905 
906 
     | 
    
         
             
                                            } else {
         
     | 
| 
       906 
907 
     | 
    
         
             
                                                [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
         
     | 
| 
       907 
908 
     | 
    
         
             
                                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         
     | 
| 
         @@ -910,14 +911,11 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       910 
911 
     | 
    
         
             
                                    } break;
         
     | 
| 
       911 
912 
     | 
    
         
             
                                case GGML_OP_GET_ROWS:
         
     | 
| 
       912 
913 
     | 
    
         
             
                                    {
         
     | 
| 
       913 
     | 
    
         
            -
                                        if (encoder == nil) {
         
     | 
| 
       914 
     | 
    
         
            -
                                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       915 
     | 
    
         
            -
                                        }
         
     | 
| 
       916 
     | 
    
         
            -
             
     | 
| 
       917 
914 
     | 
    
         
             
                                        switch (src0->type) {
         
     | 
| 
       918 
     | 
    
         
            -
                                            case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; 
     | 
| 
      
 915 
     | 
    
         
            +
                                            case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
         
     | 
| 
       919 
916 
     | 
    
         
             
                                            case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
         
     | 
| 
       920 
917 
     | 
    
         
             
                                            case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
         
     | 
| 
      
 918 
     | 
    
         
            +
                                            case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
         
     | 
| 
       921 
919 
     | 
    
         
             
                                            case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
         
     | 
| 
       922 
920 
     | 
    
         
             
                                            case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
         
     | 
| 
       923 
921 
     | 
    
         
             
                                            case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
         
     | 
| 
         @@ -939,10 +937,6 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       939 
937 
     | 
    
         
             
                                    } break;
         
     | 
| 
       940 
938 
     | 
    
         
             
                                case GGML_OP_RMS_NORM:
         
     | 
| 
       941 
939 
     | 
    
         
             
                                    {
         
     | 
| 
       942 
     | 
    
         
            -
                                        if (encoder == nil) {
         
     | 
| 
       943 
     | 
    
         
            -
                                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       944 
     | 
    
         
            -
                                        }
         
     | 
| 
       945 
     | 
    
         
            -
             
     | 
| 
       946 
940 
     | 
    
         
             
                                        float eps;
         
     | 
| 
       947 
941 
     | 
    
         
             
                                        memcpy(&eps, dst->op_params, sizeof(float));
         
     | 
| 
       948 
942 
     | 
    
         | 
| 
         @@ -962,20 +956,17 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       962 
956 
     | 
    
         
             
                                    } break;
         
     | 
| 
       963 
957 
     | 
    
         
             
                                case GGML_OP_NORM:
         
     | 
| 
       964 
958 
     | 
    
         
             
                                    {
         
     | 
| 
       965 
     | 
    
         
            -
                                         
     | 
| 
       966 
     | 
    
         
            -
             
     | 
| 
       967 
     | 
    
         
            -
                                        }
         
     | 
| 
       968 
     | 
    
         
            -
             
     | 
| 
       969 
     | 
    
         
            -
                                        const float eps = 1e-5f;
         
     | 
| 
      
 959 
     | 
    
         
            +
                                        float eps;
         
     | 
| 
      
 960 
     | 
    
         
            +
                                        memcpy(&eps, dst->op_params, sizeof(float));
         
     | 
| 
       970 
961 
     | 
    
         | 
| 
       971 
962 
     | 
    
         
             
                                        const int nth = 256;
         
     | 
| 
       972 
963 
     | 
    
         | 
| 
       973 
964 
     | 
    
         
             
                                        [encoder setComputePipelineState:ctx->pipeline_norm];
         
     | 
| 
       974 
     | 
    
         
            -
                                        [encoder setBuffer:id_src0 offset:offs_src0 
     | 
| 
       975 
     | 
    
         
            -
                                        [encoder setBuffer:id_dst  offset:offs_dst 
     | 
| 
       976 
     | 
    
         
            -
                                        [encoder setBytes:&ne00 
     | 
| 
       977 
     | 
    
         
            -
                                        [encoder setBytes:&nb01 
     | 
| 
       978 
     | 
    
         
            -
                                        [encoder setBytes:&eps 
     | 
| 
      
 965 
     | 
    
         
            +
                                        [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
         
     | 
| 
      
 966 
     | 
    
         
            +
                                        [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
         
     | 
| 
      
 967 
     | 
    
         
            +
                                        [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
         
     | 
| 
      
 968 
     | 
    
         
            +
                                        [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
         
     | 
| 
      
 969 
     | 
    
         
            +
                                        [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
         
     | 
| 
       979 
970 
     | 
    
         
             
                                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
         
     | 
| 
       980 
971 
     | 
    
         | 
| 
       981 
972 
     | 
    
         
             
                                        const int64_t nrows = ggml_nrows(src0);
         
     | 
| 
         @@ -984,10 +975,6 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       984 
975 
     | 
    
         
             
                                    } break;
         
     | 
| 
       985 
976 
     | 
    
         
             
                                case GGML_OP_ALIBI:
         
     | 
| 
       986 
977 
     | 
    
         
             
                                    {
         
     | 
| 
       987 
     | 
    
         
            -
                                        if (encoder == nil) {
         
     | 
| 
       988 
     | 
    
         
            -
                                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       989 
     | 
    
         
            -
                                        }
         
     | 
| 
       990 
     | 
    
         
            -
             
     | 
| 
       991 
978 
     | 
    
         
             
                                        GGML_ASSERT((src0t == GGML_TYPE_F32));
         
     | 
| 
       992 
979 
     | 
    
         | 
| 
       993 
980 
     | 
    
         
             
                                        const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
         
     | 
| 
         @@ -1022,15 +1009,13 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       1022 
1009 
     | 
    
         
             
                                        [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
         
     | 
| 
       1023 
1010 
     | 
    
         
             
                                        [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
         
     | 
| 
       1024 
1011 
     | 
    
         
             
                                        [encoder setBytes:&m0  length:sizeof(    float) atIndex:18];
         
     | 
| 
      
 1012 
     | 
    
         
            +
             
     | 
| 
       1025 
1013 
     | 
    
         
             
                                        const int nth = 32;
         
     | 
| 
      
 1014 
     | 
    
         
            +
             
     | 
| 
       1026 
1015 
     | 
    
         
             
                                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
         
     | 
| 
       1027 
1016 
     | 
    
         
             
                                    } break;
         
     | 
| 
       1028 
1017 
     | 
    
         
             
                                case GGML_OP_ROPE:
         
     | 
| 
       1029 
1018 
     | 
    
         
             
                                    {
         
     | 
| 
       1030 
     | 
    
         
            -
                                        if (encoder == nil) {
         
     | 
| 
       1031 
     | 
    
         
            -
                                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       1032 
     | 
    
         
            -
                                        }
         
     | 
| 
       1033 
     | 
    
         
            -
             
     | 
| 
       1034 
1019 
     | 
    
         
             
                                        const int n_past = ((int32_t *) dst->op_params)[0];
         
     | 
| 
       1035 
1020 
     | 
    
         
             
                                        const int n_dims = ((int32_t *) dst->op_params)[1];
         
     | 
| 
       1036 
1021 
     | 
    
         
             
                                        const int mode   = ((int32_t *) dst->op_params)[2];
         
     | 
| 
         @@ -1041,8 +1026,8 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       1041 
1026 
     | 
    
         
             
                                        memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
         
     | 
| 
       1042 
1027 
     | 
    
         | 
| 
       1043 
1028 
     | 
    
         
             
                                        [encoder setComputePipelineState:ctx->pipeline_rope];
         
     | 
| 
       1044 
     | 
    
         
            -
                                        [encoder setBuffer:id_src0 offset:offs_src0 
     | 
| 
       1045 
     | 
    
         
            -
                                        [encoder setBuffer:id_dst  offset:offs_dst 
     | 
| 
      
 1029 
     | 
    
         
            +
                                        [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
         
     | 
| 
      
 1030 
     | 
    
         
            +
                                        [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
         
     | 
| 
       1046 
1031 
     | 
    
         
             
                                        [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
         
     | 
| 
       1047 
1032 
     | 
    
         
             
                                        [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
         
     | 
| 
       1048 
1033 
     | 
    
         
             
                                        [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
         
     | 
| 
         @@ -1071,10 +1056,6 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       1071 
1056 
     | 
    
         
             
                                case GGML_OP_CPY:
         
     | 
| 
       1072 
1057 
     | 
    
         
             
                                case GGML_OP_CONT:
         
     | 
| 
       1073 
1058 
     | 
    
         
             
                                    {
         
     | 
| 
       1074 
     | 
    
         
            -
                                        if (encoder == nil) {
         
     | 
| 
       1075 
     | 
    
         
            -
                                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
         
     | 
| 
       1076 
     | 
    
         
            -
                                        }
         
     | 
| 
       1077 
     | 
    
         
            -
             
     | 
| 
       1078 
1059 
     | 
    
         
             
                                        const int nth = 32;
         
     | 
| 
       1079 
1060 
     | 
    
         | 
| 
       1080 
1061 
     | 
    
         
             
                                        switch (src0t) {
         
     | 
| 
         @@ -1097,24 +1078,24 @@ void ggml_metal_graph_compute( 
     | 
|
| 
       1097 
1078 
     | 
    
         
             
                                            default: GGML_ASSERT(false && "not implemented");
         
     | 
| 
       1098 
1079 
     | 
    
         
             
                                        }
         
     | 
| 
       1099 
1080 
     | 
    
         | 
| 
       1100 
     | 
    
         
            -
                                        [encoder setBuffer:id_src0 offset:offs_src0 
     | 
| 
       1101 
     | 
    
         
            -
                                        [encoder setBuffer:id_dst  offset:offs_dst 
     | 
| 
       1102 
     | 
    
         
            -
                                        [encoder setBytes:&ne00 
     | 
| 
       1103 
     | 
    
         
            -
                                        [encoder setBytes:&ne01 
     | 
| 
       1104 
     | 
    
         
            -
                                        [encoder setBytes:&ne02 
     | 
| 
       1105 
     | 
    
         
            -
                                        [encoder setBytes:&ne03 
     | 
| 
       1106 
     | 
    
         
            -
                                        [encoder setBytes:&nb00 
     | 
| 
       1107 
     | 
    
         
            -
                                        [encoder setBytes:&nb01 
     | 
| 
       1108 
     | 
    
         
            -
                                        [encoder setBytes:&nb02 
     | 
| 
       1109 
     | 
    
         
            -
                                        [encoder setBytes:&nb03 
     | 
| 
       1110 
     | 
    
         
            -
                                        [encoder setBytes:&ne0 
     | 
| 
       1111 
     | 
    
         
            -
                                        [encoder setBytes:&ne1 
     | 
| 
       1112 
     | 
    
         
            -
                                        [encoder setBytes:&ne2 
     | 
| 
       1113 
     | 
    
         
            -
                                        [encoder setBytes:&ne3 
     | 
| 
       1114 
     | 
    
         
            -
                                        [encoder setBytes:&nb0 
     | 
| 
       1115 
     | 
    
         
            -
                                        [encoder setBytes:&nb1 
     | 
| 
       1116 
     | 
    
         
            -
                                        [encoder setBytes:&nb2 
     | 
| 
       1117 
     | 
    
         
            -
                                        [encoder setBytes:&nb3 
     | 
| 
      
 1081 
     | 
    
         
            +
                                        [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
         
     | 
| 
      
 1082 
     | 
    
         
            +
                                        [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
         
     | 
| 
      
 1083 
     | 
    
         
            +
                                        [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
         
     | 
| 
      
 1084 
     | 
    
         
            +
                                        [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
         
     | 
| 
      
 1085 
     | 
    
         
            +
                                        [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
         
     | 
| 
      
 1086 
     | 
    
         
            +
                                        [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
         
     | 
| 
      
 1087 
     | 
    
         
            +
                                        [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
         
     | 
| 
      
 1088 
     | 
    
         
            +
                                        [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
         
     | 
| 
      
 1089 
     | 
    
         
            +
                                        [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
         
     | 
| 
      
 1090 
     | 
    
         
            +
                                        [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
         
     | 
| 
      
 1091 
     | 
    
         
            +
                                        [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
         
     | 
| 
      
 1092 
     | 
    
         
            +
                                        [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
         
     | 
| 
      
 1093 
     | 
    
         
            +
                                        [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
         
     | 
| 
      
 1094 
     | 
    
         
            +
                                        [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
         
     | 
| 
      
 1095 
     | 
    
         
            +
                                        [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
         
     | 
| 
      
 1096 
     | 
    
         
            +
                                        [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
         
     | 
| 
      
 1097 
     | 
    
         
            +
                                        [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
         
     | 
| 
      
 1098 
     | 
    
         
            +
                                        [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
         
     | 
| 
       1118 
1099 
     | 
    
         | 
| 
       1119 
1100 
     | 
    
         
             
                                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
         
     | 
| 
       1120 
1101 
     | 
    
         
             
                                    } break;
         
     |