llama_cpp 0.12.3 → 0.12.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/ext/llama_cpp/extconf.rb +1 -0
- data/ext/llama_cpp/llama_cpp.cpp +22 -6
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -2
- data/vendor/tmp/llama.cpp/Makefile +160 -56
- data/vendor/tmp/llama.cpp/ggml-alloc.c +85 -25
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +6 -0
- data/vendor/tmp/llama.cpp/ggml-backend.c +115 -3
- data/vendor/tmp/llama.cpp/ggml-backend.h +3 -0
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +688 -270
- data/vendor/tmp/llama.cpp/ggml-impl.h +2 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +1990 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.h +46 -0
- data/vendor/tmp/llama.cpp/ggml-metal.h +3 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +121 -86
- data/vendor/tmp/llama.cpp/ggml-metal.metal +303 -4
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +95 -3
- data/vendor/tmp/llama.cpp/ggml-opencl.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +745 -109
- data/vendor/tmp/llama.cpp/ggml-quants.h +81 -56
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +15296 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.h +29 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +51714 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +5726 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +39 -0
- data/vendor/tmp/llama.cpp/ggml.c +356 -60
- data/vendor/tmp/llama.cpp/ggml.h +7 -1
- data/vendor/tmp/llama.cpp/llama.cpp +876 -118
- data/vendor/tmp/llama.cpp/llama.h +12 -16
- metadata +9 -2
| @@ -0,0 +1,46 @@ | |
| 1 | 
            +
            #pragma once
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            #include "ggml.h"
         | 
| 4 | 
            +
            #include "ggml-backend.h"
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            #include <stdbool.h>
         | 
| 7 | 
            +
            #include <stddef.h>
         | 
| 8 | 
            +
            #include <stdint.h>
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            #ifdef __cplusplus
         | 
| 11 | 
            +
            extern "C" {
         | 
| 12 | 
            +
            #endif
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            struct ggml_vk_device {
         | 
| 15 | 
            +
                int index;
         | 
| 16 | 
            +
                int type; // same as VkPhysicalDeviceType
         | 
| 17 | 
            +
                size_t heapSize;
         | 
| 18 | 
            +
                const char * name;
         | 
| 19 | 
            +
                const char * vendor;
         | 
| 20 | 
            +
                int subgroupSize;
         | 
| 21 | 
            +
                uint64_t bufferAlignment;
         | 
| 22 | 
            +
                uint64_t maxAlloc;
         | 
| 23 | 
            +
            };
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
         | 
| 26 | 
            +
            bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
         | 
| 27 | 
            +
            bool ggml_vk_has_vulkan(void);
         | 
| 28 | 
            +
            bool ggml_vk_has_device(void);
         | 
| 29 | 
            +
            struct ggml_vk_device ggml_vk_current_device(void);
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            //
         | 
| 32 | 
            +
            // backend API
         | 
| 33 | 
            +
            //
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            // forward declaration
         | 
| 36 | 
            +
            typedef struct ggml_backend * ggml_backend_t;
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            GGML_API ggml_backend_t ggml_backend_kompute_init(int device);
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            #ifdef __cplusplus
         | 
| 45 | 
            +
            }
         | 
| 46 | 
            +
            #endif
         | 
| @@ -57,6 +57,9 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(voi | |
| 57 57 | 
             
            // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
         | 
| 58 58 | 
             
            GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
         | 
| 59 59 |  | 
| 60 | 
            +
            // capture all command buffers committed the next time `ggml_backend_graph_compute` is called
         | 
| 61 | 
            +
            GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
         | 
| 62 | 
            +
             | 
| 60 63 | 
             
            #ifdef __cplusplus
         | 
| 61 64 | 
             
            }
         | 
| 62 65 | 
             
            #endif
         | 
| @@ -24,19 +24,7 @@ | |
| 24 24 |  | 
| 25 25 | 
             
            #define UNUSED(x) (void)(x)
         | 
| 26 26 |  | 
| 27 | 
            -
            #define GGML_METAL_MAX_KERNELS 256
         | 
| 28 | 
            -
             | 
| 29 | 
            -
            struct ggml_metal_buffer {
         | 
| 30 | 
            -
                const char * name;
         | 
| 31 | 
            -
             | 
| 32 | 
            -
                void   * data;
         | 
| 33 | 
            -
                size_t   size;
         | 
| 34 | 
            -
             | 
| 35 | 
            -
                id<MTLBuffer> metal;
         | 
| 36 | 
            -
            };
         | 
| 37 | 
            -
             | 
| 38 27 | 
             
            struct ggml_metal_kernel {
         | 
| 39 | 
            -
                id<MTLFunction>             function;
         | 
| 40 28 | 
             
                id<MTLComputePipelineState> pipeline;
         | 
| 41 29 | 
             
            };
         | 
| 42 30 |  | 
| @@ -72,6 +60,7 @@ enum ggml_metal_kernel_type { | |
| 72 60 | 
             
                GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
         | 
| 73 61 | 
             
                GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
         | 
| 74 62 | 
             
                GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
         | 
| 63 | 
            +
                GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
         | 
| 75 64 | 
             
                GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
         | 
| 76 65 | 
             
                GGML_METAL_KERNEL_TYPE_RMS_NORM,
         | 
| 77 66 | 
             
                GGML_METAL_KERNEL_TYPE_GROUP_NORM,
         | 
| @@ -93,6 +82,7 @@ enum ggml_metal_kernel_type { | |
| 93 82 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
         | 
| 94 83 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
         | 
| 95 84 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
         | 
| 85 | 
            +
                GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
         | 
| 96 86 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
         | 
| 97 87 | 
             
              //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
         | 
| 98 88 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
         | 
| @@ -110,6 +100,7 @@ enum ggml_metal_kernel_type { | |
| 110 100 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
         | 
| 111 101 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
         | 
| 112 102 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
         | 
| 103 | 
            +
                GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
         | 
| 113 104 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
         | 
| 114 105 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
         | 
| 115 106 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
         | 
| @@ -124,6 +115,7 @@ enum ggml_metal_kernel_type { | |
| 124 115 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
         | 
| 125 116 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
         | 
| 126 117 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
         | 
| 118 | 
            +
                GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
         | 
| 127 119 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
         | 
| 128 120 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
         | 
| 129 121 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
         | 
| @@ -138,10 +130,12 @@ enum ggml_metal_kernel_type { | |
| 138 130 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
         | 
| 139 131 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
         | 
| 140 132 | 
             
                GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
         | 
| 133 | 
            +
                GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
         | 
| 141 134 | 
             
                GGML_METAL_KERNEL_TYPE_ROPE_F32,
         | 
| 142 135 | 
             
                GGML_METAL_KERNEL_TYPE_ROPE_F16,
         | 
| 143 136 | 
             
                GGML_METAL_KERNEL_TYPE_ALIBI_F32,
         | 
| 144 137 | 
             
                GGML_METAL_KERNEL_TYPE_IM2COL_F16,
         | 
| 138 | 
            +
                GGML_METAL_KERNEL_TYPE_IM2COL_F32,
         | 
| 145 139 | 
             
                GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
         | 
| 146 140 | 
             
                GGML_METAL_KERNEL_TYPE_PAD_F32,
         | 
| 147 141 | 
             
                GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
         | 
| @@ -168,17 +162,15 @@ struct ggml_metal_context { | |
| 168 162 |  | 
| 169 163 | 
             
                id<MTLDevice>       device;
         | 
| 170 164 | 
             
                id<MTLCommandQueue> queue;
         | 
| 171 | 
            -
                id<MTLLibrary>      library;
         | 
| 172 165 |  | 
| 173 166 | 
             
                dispatch_queue_t d_queue;
         | 
| 174 167 |  | 
| 175 | 
            -
                 | 
| 176 | 
            -
                struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
         | 
| 177 | 
            -
             | 
| 178 | 
            -
                struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
         | 
| 168 | 
            +
                struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
         | 
| 179 169 |  | 
| 180 170 | 
             
                bool support_simdgroup_reduction;
         | 
| 181 171 | 
             
                bool support_simdgroup_mm;
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                bool should_capture_next_compute;
         | 
| 182 174 | 
             
            };
         | 
| 183 175 |  | 
| 184 176 | 
             
            // MSL code
         | 
| @@ -242,26 +234,24 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 242 234 | 
             
                // Show all the Metal device instances in the system
         | 
| 243 235 | 
             
                NSArray * devices = MTLCopyAllDevices();
         | 
| 244 236 | 
             
                for (id<MTLDevice> device in devices) {
         | 
| 245 | 
            -
                     | 
| 246 | 
            -
                    GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
         | 
| 237 | 
            +
                    GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
         | 
| 247 238 | 
             
                }
         | 
| 248 239 | 
             
                [devices release]; // since it was created by a *Copy* C method
         | 
| 249 240 | 
             
            #endif
         | 
| 250 241 |  | 
| 251 242 | 
             
                // Pick and show default Metal device
         | 
| 252 243 | 
             
                id<MTLDevice> device = MTLCreateSystemDefaultDevice();
         | 
| 253 | 
            -
                 | 
| 254 | 
            -
                GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
         | 
| 244 | 
            +
                GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
         | 
| 255 245 |  | 
| 256 246 | 
             
                // Configure context
         | 
| 257 247 | 
             
                struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
         | 
| 258 248 | 
             
                ctx->device = device;
         | 
| 259 249 | 
             
                ctx->n_cb   = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
         | 
| 260 250 | 
             
                ctx->queue  = [ctx->device newCommandQueue];
         | 
| 261 | 
            -
                ctx->n_buffers = 0;
         | 
| 262 | 
            -
             | 
| 263 251 | 
             
                ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
         | 
| 264 252 |  | 
| 253 | 
            +
                id<MTLLibrary> metal_library;
         | 
| 254 | 
            +
             | 
| 265 255 | 
             
                // load library
         | 
| 266 256 | 
             
                {
         | 
| 267 257 | 
             
                    NSBundle * bundle = nil;
         | 
| @@ -276,7 +266,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 276 266 | 
             
                        // pre-compiled library found
         | 
| 277 267 | 
             
                        NSURL * libURL = [NSURL fileURLWithPath:libPath];
         | 
| 278 268 | 
             
                        GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
         | 
| 279 | 
            -
                         | 
| 269 | 
            +
                        metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
         | 
| 280 270 | 
             
                        if (error) {
         | 
| 281 271 | 
             
                            GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
         | 
| 282 272 | 
             
                            return NULL;
         | 
| @@ -318,7 +308,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 318 308 |  | 
| 319 309 | 
             
                            //[options setFastMathEnabled:false];
         | 
| 320 310 |  | 
| 321 | 
            -
                             | 
| 311 | 
            +
                            metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
         | 
| 322 312 | 
             
                            if (error) {
         | 
| 323 313 | 
             
                                GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
         | 
| 324 314 | 
             
                                return NULL;
         | 
| @@ -367,6 +357,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 367 357 | 
             
                GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n",       __func__, ctx->support_simdgroup_mm ? "true" : "false");
         | 
| 368 358 | 
             
                GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
         | 
| 369 359 |  | 
| 360 | 
            +
                ctx->should_capture_next_compute = false;
         | 
| 361 | 
            +
             | 
| 370 362 | 
             
            #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
         | 
| 371 363 | 
             
                if (@available(macOS 10.12, iOS 16.0, *)) {
         | 
| 372 364 | 
             
                    GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
         | 
| @@ -383,8 +375,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 383 375 | 
             
                {
         | 
| 384 376 | 
             
                    NSError * error = nil;
         | 
| 385 377 |  | 
| 386 | 
            -
                    for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
         | 
| 387 | 
            -
                        ctx->kernels[i].function = nil;
         | 
| 378 | 
            +
                    for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
         | 
| 388 379 | 
             
                        ctx->kernels[i].pipeline = nil;
         | 
| 389 380 | 
             
                    }
         | 
| 390 381 |  | 
| @@ -396,10 +387,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 396 387 | 
             
            #define GGML_METAL_ADD_KERNEL(e, name, supported) \
         | 
| 397 388 | 
             
                    if (supported) { \
         | 
| 398 389 | 
             
                        struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
         | 
| 399 | 
            -
                         | 
| 400 | 
            -
                        kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
         | 
| 390 | 
            +
                        id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
         | 
| 391 | 
            +
                        kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
         | 
| 392 | 
            +
                        [metal_function release]; \
         | 
| 401 393 | 
             
                        if (error) { \
         | 
| 402 394 | 
             
                            GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
         | 
| 395 | 
            +
                            [metal_library release]; \
         | 
| 403 396 | 
             
                            return NULL; \
         | 
| 404 397 | 
             
                        } \
         | 
| 405 398 | 
             
                    } else { \
         | 
| @@ -439,6 +432,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 439 432 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,             get_rows_q6_K,          true);
         | 
| 440 433 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,          get_rows_iq2_xxs,       true);
         | 
| 441 434 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,           get_rows_iq2_xs,        true);
         | 
| 435 | 
            +
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,          get_rows_iq3_xxs,       true);
         | 
| 442 436 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,              get_rows_i32,           true);
         | 
| 443 437 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                  rms_norm,               ctx->support_simdgroup_reduction);
         | 
| 444 438 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                group_norm,             ctx->support_simdgroup_reduction);
         | 
| @@ -460,6 +454,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 460 454 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,           mul_mv_q6_K_f32,        ctx->support_simdgroup_reduction);
         | 
| 461 455 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,        mul_mv_iq2_xxs_f32,     ctx->support_simdgroup_reduction);
         | 
| 462 456 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,         mul_mv_iq2_xs_f32,      ctx->support_simdgroup_reduction);
         | 
| 457 | 
            +
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,        mul_mv_iq3_xxs_f32,     ctx->support_simdgroup_reduction);
         | 
| 463 458 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,         mul_mv_id_f32_f32,      ctx->support_simdgroup_reduction);
         | 
| 464 459 | 
             
                  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,         mul_mv_id_f16_f16,      ctx->support_simdgroup_reduction);
         | 
| 465 460 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,         mul_mv_id_f16_f32,      ctx->support_simdgroup_reduction);
         | 
| @@ -477,6 +472,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 477 472 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,        mul_mv_id_q6_K_f32,     ctx->support_simdgroup_reduction);
         | 
| 478 473 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,     mul_mv_id_iq2_xxs_f32,  ctx->support_simdgroup_reduction);
         | 
| 479 474 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,      mul_mv_id_iq2_xs_f32,   ctx->support_simdgroup_reduction);
         | 
| 475 | 
            +
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,     mul_mv_id_iq3_xxs_f32,  ctx->support_simdgroup_reduction);
         | 
| 480 476 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,            mul_mm_f32_f32,         ctx->support_simdgroup_mm);
         | 
| 481 477 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,            mul_mm_f16_f32,         ctx->support_simdgroup_mm);
         | 
| 482 478 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,           mul_mm_q4_0_f32,        ctx->support_simdgroup_mm);
         | 
| @@ -491,6 +487,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 491 487 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,           mul_mm_q6_K_f32,        ctx->support_simdgroup_mm);
         | 
| 492 488 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,        mul_mm_iq2_xxs_f32,     ctx->support_simdgroup_mm);
         | 
| 493 489 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,         mul_mm_iq2_xs_f32,      ctx->support_simdgroup_mm);
         | 
| 490 | 
            +
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,        mul_mm_iq3_xxs_f32,     ctx->support_simdgroup_mm);
         | 
| 494 491 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,         mul_mm_id_f32_f32,      ctx->support_simdgroup_mm);
         | 
| 495 492 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,         mul_mm_id_f16_f32,      ctx->support_simdgroup_mm);
         | 
| 496 493 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,        mul_mm_id_q4_0_f32,     ctx->support_simdgroup_mm);
         | 
| @@ -505,10 +502,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 505 502 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,        mul_mm_id_q6_K_f32,     ctx->support_simdgroup_mm);
         | 
| 506 503 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,     mul_mm_id_iq2_xxs_f32,  ctx->support_simdgroup_mm);
         | 
| 507 504 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,      mul_mm_id_iq2_xs_f32,   ctx->support_simdgroup_mm);
         | 
| 505 | 
            +
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,     mul_mm_id_iq3_xxs_f32,  ctx->support_simdgroup_mm);
         | 
| 508 506 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                  rope_f32,               true);
         | 
| 509 507 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                  rope_f16,               true);
         | 
| 510 508 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                 alibi_f32,              true);
         | 
| 511 509 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                im2col_f16,             true);
         | 
| 510 | 
            +
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                im2col_f32,             true);
         | 
| 512 511 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,               upscale_f32,            true);
         | 
| 513 512 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                   pad_f32,                true);
         | 
| 514 513 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,       argsort_f32_i32_asc,    true);
         | 
| @@ -528,27 +527,17 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |
| 528 527 | 
             
                    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                  sum_rows,               true);
         | 
| 529 528 | 
             
                }
         | 
| 530 529 |  | 
| 530 | 
            +
                [metal_library release];
         | 
| 531 531 | 
             
                return ctx;
         | 
| 532 532 | 
             
            }
         | 
| 533 533 |  | 
| 534 534 | 
             
            static void ggml_metal_free(struct ggml_metal_context * ctx) {
         | 
| 535 535 | 
             
                GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
         | 
| 536 536 |  | 
| 537 | 
            -
                for (int i = 0; i < ctx->n_buffers; ++i) {
         | 
| 538 | 
            -
                    [ctx->buffers[i].metal release];
         | 
| 537 | 
            +
                for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
         | 
| 538 | 
            +
                    [ctx->kernels[i].pipeline release];
         | 
| 539 539 | 
             
                }
         | 
| 540 540 |  | 
| 541 | 
            -
                for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
         | 
| 542 | 
            -
                    if (ctx->kernels[i].pipeline) {
         | 
| 543 | 
            -
                        [ctx->kernels[i].pipeline release];
         | 
| 544 | 
            -
                    }
         | 
| 545 | 
            -
             | 
| 546 | 
            -
                    if (ctx->kernels[i].function) {
         | 
| 547 | 
            -
                        [ctx->kernels[i].function release];
         | 
| 548 | 
            -
                    }
         | 
| 549 | 
            -
                }
         | 
| 550 | 
            -
             | 
| 551 | 
            -
                [ctx->library release];
         | 
| 552 541 | 
             
                [ctx->queue release];
         | 
| 553 542 | 
             
                [ctx->device release];
         | 
| 554 543 |  | 
| @@ -580,51 +569,30 @@ struct ggml_backend_metal_buffer_context { | |
| 580 569 | 
             
            // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
         | 
| 581 570 | 
             
            // Metal buffer based on the host memory pointer
         | 
| 582 571 | 
             
            //
         | 
| 583 | 
            -
            static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
         | 
| 572 | 
            +
            static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
         | 
| 584 573 | 
             
                //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
         | 
| 585 574 |  | 
| 586 575 | 
             
                const int64_t tsize = ggml_nbytes(t);
         | 
| 587 576 |  | 
| 588 577 | 
             
                ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
         | 
| 589 578 |  | 
| 590 | 
            -
                 | 
| 591 | 
            -
                if (buffer && buffer->buft == ggml_backend_metal_buffer_type()) {
         | 
| 592 | 
            -
                    struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
         | 
| 593 | 
            -
             | 
| 594 | 
            -
                    // find the view that contains the tensor fully
         | 
| 595 | 
            -
                    for (int i = 0; i < buf_ctx->n_buffers; ++i) {
         | 
| 596 | 
            -
                        const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
         | 
| 597 | 
            -
             | 
| 598 | 
            -
                        //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
         | 
| 599 | 
            -
                        if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
         | 
| 600 | 
            -
                            *offs = (size_t) ioffs;
         | 
| 601 | 
            -
             | 
| 602 | 
            -
                            //GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
         | 
| 603 | 
            -
             | 
| 604 | 
            -
                            return buf_ctx->buffers[i].metal;
         | 
| 605 | 
            -
                        }
         | 
| 606 | 
            -
                    }
         | 
| 607 | 
            -
             | 
| 608 | 
            -
                    GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
         | 
| 609 | 
            -
             | 
| 610 | 
            -
                    return nil;
         | 
| 611 | 
            -
                }
         | 
| 579 | 
            +
                struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
         | 
| 612 580 |  | 
| 613 581 | 
             
                // find the view that contains the tensor fully
         | 
| 614 | 
            -
                for (int i = 0; i <  | 
| 615 | 
            -
                    const int64_t ioffs = (int64_t) t->data - (int64_t)  | 
| 582 | 
            +
                for (int i = 0; i < buf_ctx->n_buffers; ++i) {
         | 
| 583 | 
            +
                    const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
         | 
| 616 584 |  | 
| 617 | 
            -
                    //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld,  | 
| 618 | 
            -
                    if (ioffs >= 0 && ioffs + tsize <= (int64_t)  | 
| 585 | 
            +
                    //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
         | 
| 586 | 
            +
                    if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
         | 
| 619 587 | 
             
                        *offs = (size_t) ioffs;
         | 
| 620 588 |  | 
| 621 | 
            -
                        //GGML_METAL_LOG_INFO("%s:  | 
| 589 | 
            +
                        //GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
         | 
| 622 590 |  | 
| 623 | 
            -
                        return  | 
| 591 | 
            +
                        return buf_ctx->buffers[i].metal;
         | 
| 624 592 | 
             
                    }
         | 
| 625 593 | 
             
                }
         | 
| 626 594 |  | 
| 627 | 
            -
                GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
         | 
| 595 | 
            +
                GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
         | 
| 628 596 |  | 
| 629 597 | 
             
                return nil;
         | 
| 630 598 | 
             
            }
         | 
| @@ -664,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const | |
| 664 632 | 
             
                    case GGML_OP_ALIBI:
         | 
| 665 633 | 
             
                    case GGML_OP_ROPE:
         | 
| 666 634 | 
             
                    case GGML_OP_IM2COL:
         | 
| 635 | 
            +
                        return true;
         | 
| 636 | 
            +
                    case GGML_OP_POOL_1D:
         | 
| 637 | 
            +
                    case GGML_OP_POOL_2D:
         | 
| 638 | 
            +
                        return false;
         | 
| 667 639 | 
             
                    case GGML_OP_UPSCALE:
         | 
| 668 640 | 
             
                    case GGML_OP_PAD:
         | 
| 669 641 | 
             
                    case GGML_OP_ARGSORT:
         | 
| @@ -725,6 +697,20 @@ static bool ggml_metal_graph_compute( | |
| 725 697 | 
             
                const int n_cb = ctx->n_cb;
         | 
| 726 698 | 
             
                const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
         | 
| 727 699 |  | 
| 700 | 
            +
                const bool should_capture = ctx->should_capture_next_compute;
         | 
| 701 | 
            +
                if (should_capture) {
         | 
| 702 | 
            +
                    ctx->should_capture_next_compute = false;
         | 
| 703 | 
            +
             | 
| 704 | 
            +
                    MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
         | 
| 705 | 
            +
                    descriptor.captureObject = ctx->queue;
         | 
| 706 | 
            +
             | 
| 707 | 
            +
                    NSError * error = nil;
         | 
| 708 | 
            +
                    if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
         | 
| 709 | 
            +
                        GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
         | 
| 710 | 
            +
                        GGML_ASSERT(!"capture failed");
         | 
| 711 | 
            +
                    }
         | 
| 712 | 
            +
                }
         | 
| 713 | 
            +
             | 
| 728 714 | 
             
                id<MTLCommandBuffer> command_buffer_builder[n_cb];
         | 
| 729 715 | 
             
                for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
         | 
| 730 716 | 
             
                    id<MTLCommandBuffer> command_buffer  = [ctx->queue commandBufferWithUnretainedReferences];
         | 
| @@ -733,6 +719,7 @@ static bool ggml_metal_graph_compute( | |
| 733 719 | 
             
                    // enqueue the command buffers in order to specify their execution order
         | 
| 734 720 | 
             
                    [command_buffer enqueue];
         | 
| 735 721 | 
             
                }
         | 
| 722 | 
            +
             | 
| 736 723 | 
             
                const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
         | 
| 737 724 |  | 
| 738 725 | 
             
                dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
         | 
| @@ -779,9 +766,9 @@ static bool ggml_metal_graph_compute( | |
| 779 766 | 
             
                            GGML_ASSERT(!"unsupported op");
         | 
| 780 767 | 
             
                        }
         | 
| 781 768 |  | 
| 782 | 
            -
             | 
| 783 | 
            -
             | 
| 784 | 
            -
             | 
| 769 | 
            +
                        if (should_capture) {
         | 
| 770 | 
            +
                            [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
         | 
| 771 | 
            +
                        }
         | 
| 785 772 |  | 
| 786 773 | 
             
                        const int64_t  ne00 = src0 ? src0->ne[0] : 0;
         | 
| 787 774 | 
             
                        const int64_t  ne01 = src0 ? src0->ne[1] : 0;
         | 
| @@ -817,9 +804,9 @@ static bool ggml_metal_graph_compute( | |
| 817 804 | 
             
                        const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
         | 
| 818 805 | 
             
                        const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
         | 
| 819 806 |  | 
| 820 | 
            -
                        id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer( | 
| 821 | 
            -
                        id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer( | 
| 822 | 
            -
                        id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer( | 
| 807 | 
            +
                        id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
         | 
| 808 | 
            +
                        id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
         | 
| 809 | 
            +
                        id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
         | 
| 823 810 |  | 
| 824 811 | 
             
                        //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
         | 
| 825 812 | 
             
                        //if (src0) {
         | 
| @@ -1308,6 +1295,7 @@ static bool ggml_metal_graph_compute( | |
| 1308 1295 | 
             
                                            case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32   ].pipeline; break;
         | 
| 1309 1296 | 
             
                                            case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
         | 
| 1310 1297 | 
             
                                            case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
         | 
| 1298 | 
            +
                                            case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
         | 
| 1311 1299 | 
             
                                            default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
         | 
| 1312 1300 | 
             
                                        }
         | 
| 1313 1301 |  | 
| @@ -1436,6 +1424,12 @@ static bool ggml_metal_graph_compute( | |
| 1436 1424 | 
             
                                                    nth1 = 16;
         | 
| 1437 1425 | 
             
                                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
         | 
| 1438 1426 | 
             
                                                } break;
         | 
| 1427 | 
            +
                                            case GGML_TYPE_IQ3_XXS:
         | 
| 1428 | 
            +
                                                {
         | 
| 1429 | 
            +
                                                    nth0 = 4;
         | 
| 1430 | 
            +
                                                    nth1 = 16;
         | 
| 1431 | 
            +
                                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
         | 
| 1432 | 
            +
                                                } break;
         | 
| 1439 1433 | 
             
                                            default:
         | 
| 1440 1434 | 
             
                                                {
         | 
| 1441 1435 | 
             
                                                    GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
         | 
| @@ -1478,6 +1472,11 @@ static bool ggml_metal_graph_compute( | |
| 1478 1472 | 
             
                                            [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
         | 
| 1479 1473 | 
             
                                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         | 
| 1480 1474 | 
             
                                        }
         | 
| 1475 | 
            +
                                        else if (src0t == GGML_TYPE_IQ3_XXS) {
         | 
| 1476 | 
            +
                                            const int mem_size = 256*4+128;
         | 
| 1477 | 
            +
                                            [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
         | 
| 1478 | 
            +
                                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         | 
| 1479 | 
            +
                                        }
         | 
| 1481 1480 | 
             
                                        else if (src0t == GGML_TYPE_Q4_K) {
         | 
| 1482 1481 | 
             
                                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         | 
| 1483 1482 | 
             
                                        }
         | 
| @@ -1572,6 +1571,7 @@ static bool ggml_metal_graph_compute( | |
| 1572 1571 | 
             
                                            case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32   ].pipeline; break;
         | 
| 1573 1572 | 
             
                                            case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
         | 
| 1574 1573 | 
             
                                            case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
         | 
| 1574 | 
            +
                                            case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
         | 
| 1575 1575 | 
             
                                            default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
         | 
| 1576 1576 | 
             
                                        }
         | 
| 1577 1577 |  | 
| @@ -1601,7 +1601,7 @@ static bool ggml_metal_graph_compute( | |
| 1601 1601 | 
             
                                            struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
         | 
| 1602 1602 |  | 
| 1603 1603 | 
             
                                            size_t offs_src_cur = 0;
         | 
| 1604 | 
            -
                                            id<MTLBuffer> id_src_cur = ggml_metal_get_buffer( | 
| 1604 | 
            +
                                            id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
         | 
| 1605 1605 |  | 
| 1606 1606 | 
             
                                            [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
         | 
| 1607 1607 | 
             
                                        }
         | 
| @@ -1703,6 +1703,12 @@ static bool ggml_metal_graph_compute( | |
| 1703 1703 | 
             
                                                    nth1 = 16;
         | 
| 1704 1704 | 
             
                                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
         | 
| 1705 1705 | 
             
                                                } break;
         | 
| 1706 | 
            +
                                            case GGML_TYPE_IQ3_XXS:
         | 
| 1707 | 
            +
                                                {
         | 
| 1708 | 
            +
                                                    nth0 = 4;
         | 
| 1709 | 
            +
                                                    nth1 = 16;
         | 
| 1710 | 
            +
                                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
         | 
| 1711 | 
            +
                                                } break;
         | 
| 1706 1712 | 
             
                                            default:
         | 
| 1707 1713 | 
             
                                                {
         | 
| 1708 1714 | 
             
                                                    GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
         | 
| @@ -1746,7 +1752,7 @@ static bool ggml_metal_graph_compute( | |
| 1746 1752 | 
             
                                            struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
         | 
| 1747 1753 |  | 
| 1748 1754 | 
             
                                            size_t offs_src_cur = 0;
         | 
| 1749 | 
            -
                                            id<MTLBuffer> id_src_cur = ggml_metal_get_buffer( | 
| 1755 | 
            +
                                            id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
         | 
| 1750 1756 |  | 
| 1751 1757 | 
             
                                            [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
         | 
| 1752 1758 | 
             
                                        }
         | 
| @@ -1761,6 +1767,11 @@ static bool ggml_metal_graph_compute( | |
| 1761 1767 | 
             
                                            [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
         | 
| 1762 1768 | 
             
                                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         | 
| 1763 1769 | 
             
                                        }
         | 
| 1770 | 
            +
                                        else if (src2t == GGML_TYPE_IQ3_XXS) {
         | 
| 1771 | 
            +
                                            const int mem_size = 256*4+128;
         | 
| 1772 | 
            +
                                            [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
         | 
| 1773 | 
            +
                                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         | 
| 1774 | 
            +
                                        }
         | 
| 1764 1775 | 
             
                                        else if (src2t == GGML_TYPE_Q4_K) {
         | 
| 1765 1776 | 
             
                                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         | 
| 1766 1777 | 
             
                                        }
         | 
| @@ -1801,6 +1812,7 @@ static bool ggml_metal_graph_compute( | |
| 1801 1812 | 
             
                                        case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K   ].pipeline; break;
         | 
| 1802 1813 | 
             
                                        case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
         | 
| 1803 1814 | 
             
                                        case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
         | 
| 1815 | 
            +
                                        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
         | 
| 1804 1816 | 
             
                                        case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
         | 
| 1805 1817 | 
             
                                        default: GGML_ASSERT(false && "not implemented");
         | 
| 1806 1818 | 
             
                                    }
         | 
| @@ -2009,7 +2021,7 @@ static bool ggml_metal_graph_compute( | |
| 2009 2021 | 
             
                                {
         | 
| 2010 2022 | 
             
                                    GGML_ASSERT(src0->type == GGML_TYPE_F16);
         | 
| 2011 2023 | 
             
                                    GGML_ASSERT(src1->type == GGML_TYPE_F32);
         | 
| 2012 | 
            -
                                    GGML_ASSERT( dst->type == GGML_TYPE_F16);
         | 
| 2024 | 
            +
                                    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
         | 
| 2013 2025 |  | 
| 2014 2026 | 
             
                                    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
         | 
| 2015 2027 | 
             
                                    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
         | 
| @@ -2017,6 +2029,7 @@ static bool ggml_metal_graph_compute( | |
| 2017 2029 | 
             
                                    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
         | 
| 2018 2030 | 
             
                                    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
         | 
| 2019 2031 | 
             
                                    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
         | 
| 2032 | 
            +
             | 
| 2020 2033 | 
             
                                    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
         | 
| 2021 2034 |  | 
| 2022 2035 | 
             
                                    const int32_t N  = src1->ne[is_2D ? 3 : 2];
         | 
| @@ -2037,8 +2050,8 @@ static bool ggml_metal_graph_compute( | |
| 2037 2050 |  | 
| 2038 2051 | 
             
                                    id<MTLComputePipelineState> pipeline = nil;
         | 
| 2039 2052 |  | 
| 2040 | 
            -
                                    switch ( | 
| 2041 | 
            -
                                        case GGML_TYPE_F32:  | 
| 2053 | 
            +
                                    switch (dst->type) {
         | 
| 2054 | 
            +
                                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
         | 
| 2042 2055 | 
             
                                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
         | 
| 2043 2056 | 
             
                                        default: GGML_ASSERT(false);
         | 
| 2044 2057 | 
             
                                    };
         | 
| @@ -2231,9 +2244,9 @@ static bool ggml_metal_graph_compute( | |
| 2231 2244 | 
             
                                }
         | 
| 2232 2245 | 
             
                        }
         | 
| 2233 2246 |  | 
| 2234 | 
            -
             | 
| 2235 | 
            -
             | 
| 2236 | 
            -
             | 
| 2247 | 
            +
                        if (should_capture) {
         | 
| 2248 | 
            +
                            [encoder popDebugGroup];
         | 
| 2249 | 
            +
                        }
         | 
| 2237 2250 | 
             
                    }
         | 
| 2238 2251 |  | 
| 2239 2252 | 
             
                    [encoder endEncoding];
         | 
| @@ -2255,6 +2268,10 @@ static bool ggml_metal_graph_compute( | |
| 2255 2268 | 
             
                    }
         | 
| 2256 2269 | 
             
                }
         | 
| 2257 2270 |  | 
| 2271 | 
            +
                if (should_capture) {
         | 
| 2272 | 
            +
                    [[MTLCaptureManager sharedCaptureManager] stopCapture];
         | 
| 2273 | 
            +
                }
         | 
| 2274 | 
            +
             | 
| 2258 2275 | 
             
                return true;
         | 
| 2259 2276 | 
             
            }
         | 
| 2260 2277 |  | 
| @@ -2423,6 +2440,16 @@ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backen | |
| 2423 2440 | 
             
                UNUSED(buft);
         | 
| 2424 2441 | 
             
            }
         | 
| 2425 2442 |  | 
| 2443 | 
            +
            GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
         | 
| 2444 | 
            +
                id<MTLDevice> device = ggml_backend_metal_get_device();
         | 
| 2445 | 
            +
                size_t max_size = device.maxBufferLength;
         | 
| 2446 | 
            +
                ggml_backend_metal_free_device();
         | 
| 2447 | 
            +
             | 
| 2448 | 
            +
                return max_size;
         | 
| 2449 | 
            +
             | 
| 2450 | 
            +
                UNUSED(buft);
         | 
| 2451 | 
            +
            }
         | 
| 2452 | 
            +
             | 
| 2426 2453 | 
             
            GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
         | 
| 2427 2454 | 
             
                return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
         | 
| 2428 2455 |  | 
| @@ -2441,6 +2468,7 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { | |
| 2441 2468 | 
             
                        /* .get_name         = */ ggml_backend_metal_buffer_type_get_name,
         | 
| 2442 2469 | 
             
                        /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_alloc_buffer,
         | 
| 2443 2470 | 
             
                        /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
         | 
| 2471 | 
            +
                        /* .get_max_size     = */ ggml_backend_metal_buffer_type_get_max_size,
         | 
| 2444 2472 | 
             
                        /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
         | 
| 2445 2473 | 
             
                        /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
         | 
| 2446 2474 | 
             
                        /* .is_host          = */ ggml_backend_metal_buffer_type_is_host,
         | 
| @@ -2615,6 +2643,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { | |
| 2615 2643 | 
             
                return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
         | 
| 2616 2644 | 
             
            }
         | 
| 2617 2645 |  | 
| 2646 | 
            +
            void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
         | 
| 2647 | 
            +
                GGML_ASSERT(ggml_backend_is_metal(backend));
         | 
| 2648 | 
            +
             | 
| 2649 | 
            +
                struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
         | 
| 2650 | 
            +
                ctx->should_capture_next_compute = true;
         | 
| 2651 | 
            +
            }
         | 
| 2652 | 
            +
             | 
| 2618 2653 | 
             
            GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
         | 
| 2619 2654 |  | 
| 2620 2655 | 
             
            GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
         |