llama_cpp 0.12.3 → 0.12.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/ext/llama_cpp/extconf.rb +1 -0
- data/ext/llama_cpp/llama_cpp.cpp +22 -6
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -2
- data/vendor/tmp/llama.cpp/Makefile +160 -56
- data/vendor/tmp/llama.cpp/ggml-alloc.c +85 -25
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +6 -0
- data/vendor/tmp/llama.cpp/ggml-backend.c +115 -3
- data/vendor/tmp/llama.cpp/ggml-backend.h +3 -0
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +688 -270
- data/vendor/tmp/llama.cpp/ggml-impl.h +2 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +1990 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.h +46 -0
- data/vendor/tmp/llama.cpp/ggml-metal.h +3 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +121 -86
- data/vendor/tmp/llama.cpp/ggml-metal.metal +303 -4
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +95 -3
- data/vendor/tmp/llama.cpp/ggml-opencl.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +745 -109
- data/vendor/tmp/llama.cpp/ggml-quants.h +81 -56
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +15296 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.h +29 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +51714 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +5726 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +39 -0
- data/vendor/tmp/llama.cpp/ggml.c +356 -60
- data/vendor/tmp/llama.cpp/ggml.h +7 -1
- data/vendor/tmp/llama.cpp/llama.cpp +876 -118
- data/vendor/tmp/llama.cpp/llama.h +12 -16
- metadata +9 -2
@@ -0,0 +1,46 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#include "ggml.h"
|
4
|
+
#include "ggml-backend.h"
|
5
|
+
|
6
|
+
#include <stdbool.h>
|
7
|
+
#include <stddef.h>
|
8
|
+
#include <stdint.h>
|
9
|
+
|
10
|
+
#ifdef __cplusplus
|
11
|
+
extern "C" {
|
12
|
+
#endif
|
13
|
+
|
14
|
+
struct ggml_vk_device {
|
15
|
+
int index;
|
16
|
+
int type; // same as VkPhysicalDeviceType
|
17
|
+
size_t heapSize;
|
18
|
+
const char * name;
|
19
|
+
const char * vendor;
|
20
|
+
int subgroupSize;
|
21
|
+
uint64_t bufferAlignment;
|
22
|
+
uint64_t maxAlloc;
|
23
|
+
};
|
24
|
+
|
25
|
+
struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
|
26
|
+
bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
|
27
|
+
bool ggml_vk_has_vulkan(void);
|
28
|
+
bool ggml_vk_has_device(void);
|
29
|
+
struct ggml_vk_device ggml_vk_current_device(void);
|
30
|
+
|
31
|
+
//
|
32
|
+
// backend API
|
33
|
+
//
|
34
|
+
|
35
|
+
// forward declaration
|
36
|
+
typedef struct ggml_backend * ggml_backend_t;
|
37
|
+
|
38
|
+
GGML_API ggml_backend_t ggml_backend_kompute_init(int device);
|
39
|
+
|
40
|
+
GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
|
41
|
+
|
42
|
+
GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
|
43
|
+
|
44
|
+
#ifdef __cplusplus
|
45
|
+
}
|
46
|
+
#endif
|
@@ -57,6 +57,9 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(voi
|
|
57
57
|
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
58
58
|
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
|
59
59
|
|
60
|
+
// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
|
61
|
+
GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
|
62
|
+
|
60
63
|
#ifdef __cplusplus
|
61
64
|
}
|
62
65
|
#endif
|
@@ -24,19 +24,7 @@
|
|
24
24
|
|
25
25
|
#define UNUSED(x) (void)(x)
|
26
26
|
|
27
|
-
#define GGML_METAL_MAX_KERNELS 256
|
28
|
-
|
29
|
-
struct ggml_metal_buffer {
|
30
|
-
const char * name;
|
31
|
-
|
32
|
-
void * data;
|
33
|
-
size_t size;
|
34
|
-
|
35
|
-
id<MTLBuffer> metal;
|
36
|
-
};
|
37
|
-
|
38
27
|
struct ggml_metal_kernel {
|
39
|
-
id<MTLFunction> function;
|
40
28
|
id<MTLComputePipelineState> pipeline;
|
41
29
|
};
|
42
30
|
|
@@ -72,6 +60,7 @@ enum ggml_metal_kernel_type {
|
|
72
60
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
|
73
61
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
74
62
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
63
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
|
75
64
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
76
65
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
77
66
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
@@ -93,6 +82,7 @@ enum ggml_metal_kernel_type {
|
|
93
82
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
|
94
83
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
95
84
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
85
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
|
96
86
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
97
87
|
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
98
88
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
@@ -110,6 +100,7 @@ enum ggml_metal_kernel_type {
|
|
110
100
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
|
111
101
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
112
102
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
103
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
|
113
104
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
114
105
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
115
106
|
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
@@ -124,6 +115,7 @@ enum ggml_metal_kernel_type {
|
|
124
115
|
GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
|
125
116
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
126
117
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
118
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
|
127
119
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
128
120
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
129
121
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
@@ -138,10 +130,12 @@ enum ggml_metal_kernel_type {
|
|
138
130
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
|
139
131
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
140
132
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
133
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
141
134
|
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
142
135
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
143
136
|
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
144
137
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
138
|
+
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
145
139
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
146
140
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
147
141
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
@@ -168,17 +162,15 @@ struct ggml_metal_context {
|
|
168
162
|
|
169
163
|
id<MTLDevice> device;
|
170
164
|
id<MTLCommandQueue> queue;
|
171
|
-
id<MTLLibrary> library;
|
172
165
|
|
173
166
|
dispatch_queue_t d_queue;
|
174
167
|
|
175
|
-
|
176
|
-
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
177
|
-
|
178
|
-
struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
|
168
|
+
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
|
179
169
|
|
180
170
|
bool support_simdgroup_reduction;
|
181
171
|
bool support_simdgroup_mm;
|
172
|
+
|
173
|
+
bool should_capture_next_compute;
|
182
174
|
};
|
183
175
|
|
184
176
|
// MSL code
|
@@ -242,26 +234,24 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
242
234
|
// Show all the Metal device instances in the system
|
243
235
|
NSArray * devices = MTLCopyAllDevices();
|
244
236
|
for (id<MTLDevice> device in devices) {
|
245
|
-
|
246
|
-
GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
|
237
|
+
GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
|
247
238
|
}
|
248
239
|
[devices release]; // since it was created by a *Copy* C method
|
249
240
|
#endif
|
250
241
|
|
251
242
|
// Pick and show default Metal device
|
252
243
|
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
253
|
-
|
254
|
-
GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
|
244
|
+
GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
255
245
|
|
256
246
|
// Configure context
|
257
247
|
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
|
258
248
|
ctx->device = device;
|
259
249
|
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
260
250
|
ctx->queue = [ctx->device newCommandQueue];
|
261
|
-
ctx->n_buffers = 0;
|
262
|
-
|
263
251
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
264
252
|
|
253
|
+
id<MTLLibrary> metal_library;
|
254
|
+
|
265
255
|
// load library
|
266
256
|
{
|
267
257
|
NSBundle * bundle = nil;
|
@@ -276,7 +266,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
276
266
|
// pre-compiled library found
|
277
267
|
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
278
268
|
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
|
279
|
-
|
269
|
+
metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
|
280
270
|
if (error) {
|
281
271
|
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
282
272
|
return NULL;
|
@@ -318,7 +308,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
318
308
|
|
319
309
|
//[options setFastMathEnabled:false];
|
320
310
|
|
321
|
-
|
311
|
+
metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
|
322
312
|
if (error) {
|
323
313
|
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
324
314
|
return NULL;
|
@@ -367,6 +357,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
367
357
|
GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
|
368
358
|
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
369
359
|
|
360
|
+
ctx->should_capture_next_compute = false;
|
361
|
+
|
370
362
|
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
371
363
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
372
364
|
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
|
@@ -383,8 +375,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
383
375
|
{
|
384
376
|
NSError * error = nil;
|
385
377
|
|
386
|
-
for (int i = 0; i <
|
387
|
-
ctx->kernels[i].function = nil;
|
378
|
+
for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
|
388
379
|
ctx->kernels[i].pipeline = nil;
|
389
380
|
}
|
390
381
|
|
@@ -396,10 +387,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
396
387
|
#define GGML_METAL_ADD_KERNEL(e, name, supported) \
|
397
388
|
if (supported) { \
|
398
389
|
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
|
399
|
-
|
400
|
-
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:
|
390
|
+
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
|
391
|
+
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
|
392
|
+
[metal_function release]; \
|
401
393
|
if (error) { \
|
402
394
|
GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
395
|
+
[metal_library release]; \
|
403
396
|
return NULL; \
|
404
397
|
} \
|
405
398
|
} else { \
|
@@ -439,6 +432,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
439
432
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
440
433
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
441
434
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
435
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
442
436
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
443
437
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
444
438
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
@@ -460,6 +454,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
460
454
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
|
461
455
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
462
456
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
457
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
463
458
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
464
459
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
465
460
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
@@ -477,6 +472,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
477
472
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
478
473
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
479
474
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
475
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
480
476
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
481
477
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
482
478
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
@@ -491,6 +487,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
491
487
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
492
488
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
493
489
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
490
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
494
491
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
495
492
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
496
493
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
@@ -505,10 +502,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
505
502
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
506
503
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
507
504
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
505
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
508
506
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
509
507
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
510
508
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
511
509
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
510
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
512
511
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
513
512
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
514
513
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
@@ -528,27 +527,17 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
528
527
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
529
528
|
}
|
530
529
|
|
530
|
+
[metal_library release];
|
531
531
|
return ctx;
|
532
532
|
}
|
533
533
|
|
534
534
|
static void ggml_metal_free(struct ggml_metal_context * ctx) {
|
535
535
|
GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
|
536
536
|
|
537
|
-
for (int i = 0; i <
|
538
|
-
[ctx->
|
537
|
+
for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
|
538
|
+
[ctx->kernels[i].pipeline release];
|
539
539
|
}
|
540
540
|
|
541
|
-
for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
|
542
|
-
if (ctx->kernels[i].pipeline) {
|
543
|
-
[ctx->kernels[i].pipeline release];
|
544
|
-
}
|
545
|
-
|
546
|
-
if (ctx->kernels[i].function) {
|
547
|
-
[ctx->kernels[i].function release];
|
548
|
-
}
|
549
|
-
}
|
550
|
-
|
551
|
-
[ctx->library release];
|
552
541
|
[ctx->queue release];
|
553
542
|
[ctx->device release];
|
554
543
|
|
@@ -580,51 +569,30 @@ struct ggml_backend_metal_buffer_context {
|
|
580
569
|
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
581
570
|
// Metal buffer based on the host memory pointer
|
582
571
|
//
|
583
|
-
static id<MTLBuffer> ggml_metal_get_buffer(struct
|
572
|
+
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
|
584
573
|
//GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
|
585
574
|
|
586
575
|
const int64_t tsize = ggml_nbytes(t);
|
587
576
|
|
588
577
|
ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
|
589
578
|
|
590
|
-
|
591
|
-
if (buffer && buffer->buft == ggml_backend_metal_buffer_type()) {
|
592
|
-
struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
|
593
|
-
|
594
|
-
// find the view that contains the tensor fully
|
595
|
-
for (int i = 0; i < buf_ctx->n_buffers; ++i) {
|
596
|
-
const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
|
597
|
-
|
598
|
-
//GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
|
599
|
-
if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
|
600
|
-
*offs = (size_t) ioffs;
|
601
|
-
|
602
|
-
//GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
|
603
|
-
|
604
|
-
return buf_ctx->buffers[i].metal;
|
605
|
-
}
|
606
|
-
}
|
607
|
-
|
608
|
-
GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
|
609
|
-
|
610
|
-
return nil;
|
611
|
-
}
|
579
|
+
struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
|
612
580
|
|
613
581
|
// find the view that contains the tensor fully
|
614
|
-
for (int i = 0; i <
|
615
|
-
const int64_t ioffs = (int64_t) t->data - (int64_t)
|
582
|
+
for (int i = 0; i < buf_ctx->n_buffers; ++i) {
|
583
|
+
const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
|
616
584
|
|
617
|
-
//GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld,
|
618
|
-
if (ioffs >= 0 && ioffs + tsize <= (int64_t)
|
585
|
+
//GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
|
586
|
+
if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
|
619
587
|
*offs = (size_t) ioffs;
|
620
588
|
|
621
|
-
//GGML_METAL_LOG_INFO("%s:
|
589
|
+
//GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
|
622
590
|
|
623
|
-
return
|
591
|
+
return buf_ctx->buffers[i].metal;
|
624
592
|
}
|
625
593
|
}
|
626
594
|
|
627
|
-
GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
|
595
|
+
GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
|
628
596
|
|
629
597
|
return nil;
|
630
598
|
}
|
@@ -664,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
664
632
|
case GGML_OP_ALIBI:
|
665
633
|
case GGML_OP_ROPE:
|
666
634
|
case GGML_OP_IM2COL:
|
635
|
+
return true;
|
636
|
+
case GGML_OP_POOL_1D:
|
637
|
+
case GGML_OP_POOL_2D:
|
638
|
+
return false;
|
667
639
|
case GGML_OP_UPSCALE:
|
668
640
|
case GGML_OP_PAD:
|
669
641
|
case GGML_OP_ARGSORT:
|
@@ -725,6 +697,20 @@ static bool ggml_metal_graph_compute(
|
|
725
697
|
const int n_cb = ctx->n_cb;
|
726
698
|
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
727
699
|
|
700
|
+
const bool should_capture = ctx->should_capture_next_compute;
|
701
|
+
if (should_capture) {
|
702
|
+
ctx->should_capture_next_compute = false;
|
703
|
+
|
704
|
+
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
|
705
|
+
descriptor.captureObject = ctx->queue;
|
706
|
+
|
707
|
+
NSError * error = nil;
|
708
|
+
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
|
709
|
+
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
|
710
|
+
GGML_ASSERT(!"capture failed");
|
711
|
+
}
|
712
|
+
}
|
713
|
+
|
728
714
|
id<MTLCommandBuffer> command_buffer_builder[n_cb];
|
729
715
|
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
730
716
|
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
@@ -733,6 +719,7 @@ static bool ggml_metal_graph_compute(
|
|
733
719
|
// enqueue the command buffers in order to specify their execution order
|
734
720
|
[command_buffer enqueue];
|
735
721
|
}
|
722
|
+
|
736
723
|
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
|
737
724
|
|
738
725
|
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
|
@@ -779,9 +766,9 @@ static bool ggml_metal_graph_compute(
|
|
779
766
|
GGML_ASSERT(!"unsupported op");
|
780
767
|
}
|
781
768
|
|
782
|
-
|
783
|
-
|
784
|
-
|
769
|
+
if (should_capture) {
|
770
|
+
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
|
771
|
+
}
|
785
772
|
|
786
773
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
787
774
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
@@ -817,9 +804,9 @@ static bool ggml_metal_graph_compute(
|
|
817
804
|
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
818
805
|
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
|
819
806
|
|
820
|
-
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(
|
821
|
-
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(
|
822
|
-
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(
|
807
|
+
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
|
808
|
+
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
|
809
|
+
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
823
810
|
|
824
811
|
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
825
812
|
//if (src0) {
|
@@ -1308,6 +1295,7 @@ static bool ggml_metal_graph_compute(
|
|
1308
1295
|
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
1309
1296
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
1310
1297
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
1298
|
+
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
1311
1299
|
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
1312
1300
|
}
|
1313
1301
|
|
@@ -1436,6 +1424,12 @@ static bool ggml_metal_graph_compute(
|
|
1436
1424
|
nth1 = 16;
|
1437
1425
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
1438
1426
|
} break;
|
1427
|
+
case GGML_TYPE_IQ3_XXS:
|
1428
|
+
{
|
1429
|
+
nth0 = 4;
|
1430
|
+
nth1 = 16;
|
1431
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
1432
|
+
} break;
|
1439
1433
|
default:
|
1440
1434
|
{
|
1441
1435
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
@@ -1478,6 +1472,11 @@ static bool ggml_metal_graph_compute(
|
|
1478
1472
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1479
1473
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1480
1474
|
}
|
1475
|
+
else if (src0t == GGML_TYPE_IQ3_XXS) {
|
1476
|
+
const int mem_size = 256*4+128;
|
1477
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1478
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1479
|
+
}
|
1481
1480
|
else if (src0t == GGML_TYPE_Q4_K) {
|
1482
1481
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1483
1482
|
}
|
@@ -1572,6 +1571,7 @@ static bool ggml_metal_graph_compute(
|
|
1572
1571
|
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
|
1573
1572
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
1574
1573
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
1574
|
+
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
1575
1575
|
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
1576
1576
|
}
|
1577
1577
|
|
@@ -1601,7 +1601,7 @@ static bool ggml_metal_graph_compute(
|
|
1601
1601
|
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
1602
1602
|
|
1603
1603
|
size_t offs_src_cur = 0;
|
1604
|
-
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(
|
1604
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
|
1605
1605
|
|
1606
1606
|
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
1607
1607
|
}
|
@@ -1703,6 +1703,12 @@ static bool ggml_metal_graph_compute(
|
|
1703
1703
|
nth1 = 16;
|
1704
1704
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
|
1705
1705
|
} break;
|
1706
|
+
case GGML_TYPE_IQ3_XXS:
|
1707
|
+
{
|
1708
|
+
nth0 = 4;
|
1709
|
+
nth1 = 16;
|
1710
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
1711
|
+
} break;
|
1706
1712
|
default:
|
1707
1713
|
{
|
1708
1714
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
@@ -1746,7 +1752,7 @@ static bool ggml_metal_graph_compute(
|
|
1746
1752
|
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
1747
1753
|
|
1748
1754
|
size_t offs_src_cur = 0;
|
1749
|
-
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(
|
1755
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
|
1750
1756
|
|
1751
1757
|
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
1752
1758
|
}
|
@@ -1761,6 +1767,11 @@ static bool ggml_metal_graph_compute(
|
|
1761
1767
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1762
1768
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1763
1769
|
}
|
1770
|
+
else if (src2t == GGML_TYPE_IQ3_XXS) {
|
1771
|
+
const int mem_size = 256*4+128;
|
1772
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1773
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1774
|
+
}
|
1764
1775
|
else if (src2t == GGML_TYPE_Q4_K) {
|
1765
1776
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1766
1777
|
}
|
@@ -1801,6 +1812,7 @@ static bool ggml_metal_graph_compute(
|
|
1801
1812
|
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
|
1802
1813
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
1803
1814
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
1815
|
+
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
|
1804
1816
|
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
1805
1817
|
default: GGML_ASSERT(false && "not implemented");
|
1806
1818
|
}
|
@@ -2009,7 +2021,7 @@ static bool ggml_metal_graph_compute(
|
|
2009
2021
|
{
|
2010
2022
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
2011
2023
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
2012
|
-
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
2024
|
+
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
2013
2025
|
|
2014
2026
|
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
2015
2027
|
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
@@ -2017,6 +2029,7 @@ static bool ggml_metal_graph_compute(
|
|
2017
2029
|
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
2018
2030
|
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
2019
2031
|
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
2032
|
+
|
2020
2033
|
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
2021
2034
|
|
2022
2035
|
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
@@ -2037,8 +2050,8 @@ static bool ggml_metal_graph_compute(
|
|
2037
2050
|
|
2038
2051
|
id<MTLComputePipelineState> pipeline = nil;
|
2039
2052
|
|
2040
|
-
switch (
|
2041
|
-
case GGML_TYPE_F32:
|
2053
|
+
switch (dst->type) {
|
2054
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
|
2042
2055
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
2043
2056
|
default: GGML_ASSERT(false);
|
2044
2057
|
};
|
@@ -2231,9 +2244,9 @@ static bool ggml_metal_graph_compute(
|
|
2231
2244
|
}
|
2232
2245
|
}
|
2233
2246
|
|
2234
|
-
|
2235
|
-
|
2236
|
-
|
2247
|
+
if (should_capture) {
|
2248
|
+
[encoder popDebugGroup];
|
2249
|
+
}
|
2237
2250
|
}
|
2238
2251
|
|
2239
2252
|
[encoder endEncoding];
|
@@ -2255,6 +2268,10 @@ static bool ggml_metal_graph_compute(
|
|
2255
2268
|
}
|
2256
2269
|
}
|
2257
2270
|
|
2271
|
+
if (should_capture) {
|
2272
|
+
[[MTLCaptureManager sharedCaptureManager] stopCapture];
|
2273
|
+
}
|
2274
|
+
|
2258
2275
|
return true;
|
2259
2276
|
}
|
2260
2277
|
|
@@ -2423,6 +2440,16 @@ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backen
|
|
2423
2440
|
UNUSED(buft);
|
2424
2441
|
}
|
2425
2442
|
|
2443
|
+
GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
2444
|
+
id<MTLDevice> device = ggml_backend_metal_get_device();
|
2445
|
+
size_t max_size = device.maxBufferLength;
|
2446
|
+
ggml_backend_metal_free_device();
|
2447
|
+
|
2448
|
+
return max_size;
|
2449
|
+
|
2450
|
+
UNUSED(buft);
|
2451
|
+
}
|
2452
|
+
|
2426
2453
|
GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
|
2427
2454
|
return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
|
2428
2455
|
|
@@ -2441,6 +2468,7 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
|
2441
2468
|
/* .get_name = */ ggml_backend_metal_buffer_type_get_name,
|
2442
2469
|
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
|
2443
2470
|
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
|
2471
|
+
/* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
|
2444
2472
|
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
2445
2473
|
/* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
|
2446
2474
|
/* .is_host = */ ggml_backend_metal_buffer_type_is_host,
|
@@ -2615,6 +2643,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
|
2615
2643
|
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
2616
2644
|
}
|
2617
2645
|
|
2646
|
+
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
|
2647
|
+
GGML_ASSERT(ggml_backend_is_metal(backend));
|
2648
|
+
|
2649
|
+
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
2650
|
+
ctx->should_capture_next_compute = true;
|
2651
|
+
}
|
2652
|
+
|
2618
2653
|
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
2619
2654
|
|
2620
2655
|
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
|