llama_cpp 0.12.2 → 0.12.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -0
- data/README.md +2 -2
- data/ext/llama_cpp/extconf.rb +1 -0
- data/ext/llama_cpp/llama_cpp.cpp +68 -6
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -2
- data/vendor/tmp/llama.cpp/Makefile +25 -3
- data/vendor/tmp/llama.cpp/ggml-alloc.c +87 -27
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +6 -0
- data/vendor/tmp/llama.cpp/ggml-backend.c +176 -18
- data/vendor/tmp/llama.cpp/ggml-backend.h +14 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +1990 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.h +46 -0
- data/vendor/tmp/llama.cpp/ggml-metal.h +3 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +144 -113
- data/vendor/tmp/llama.cpp/ggml-metal.metal +303 -4
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +95 -3
- data/vendor/tmp/llama.cpp/ggml-opencl.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +736 -59
- data/vendor/tmp/llama.cpp/ggml-quants.h +20 -1
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +15255 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.h +29 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +60854 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +5270 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +34 -0
- data/vendor/tmp/llama.cpp/ggml.c +664 -117
- data/vendor/tmp/llama.cpp/ggml.h +46 -11
- data/vendor/tmp/llama.cpp/llama.cpp +1426 -341
- data/vendor/tmp/llama.cpp/llama.h +24 -15
- data/vendor/tmp/llama.cpp/unicode.h +2 -1
- metadata +10 -3
@@ -24,19 +24,7 @@
|
|
24
24
|
|
25
25
|
#define UNUSED(x) (void)(x)
|
26
26
|
|
27
|
-
#define GGML_METAL_MAX_KERNELS 256
|
28
|
-
|
29
|
-
struct ggml_metal_buffer {
|
30
|
-
const char * name;
|
31
|
-
|
32
|
-
void * data;
|
33
|
-
size_t size;
|
34
|
-
|
35
|
-
id<MTLBuffer> metal;
|
36
|
-
};
|
37
|
-
|
38
27
|
struct ggml_metal_kernel {
|
39
|
-
id<MTLFunction> function;
|
40
28
|
id<MTLComputePipelineState> pipeline;
|
41
29
|
};
|
42
30
|
|
@@ -72,6 +60,7 @@ enum ggml_metal_kernel_type {
|
|
72
60
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
|
73
61
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
74
62
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
63
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
|
75
64
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
76
65
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
77
66
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
@@ -93,6 +82,7 @@ enum ggml_metal_kernel_type {
|
|
93
82
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
|
94
83
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
95
84
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
85
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
|
96
86
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
97
87
|
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
98
88
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
@@ -110,6 +100,7 @@ enum ggml_metal_kernel_type {
|
|
110
100
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
|
111
101
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
112
102
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
103
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
|
113
104
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
114
105
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
115
106
|
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
@@ -124,6 +115,7 @@ enum ggml_metal_kernel_type {
|
|
124
115
|
GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
|
125
116
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
126
117
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
118
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
|
127
119
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
128
120
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
129
121
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
@@ -138,10 +130,12 @@ enum ggml_metal_kernel_type {
|
|
138
130
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
|
139
131
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
140
132
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
133
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
141
134
|
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
142
135
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
143
136
|
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
144
137
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
138
|
+
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
145
139
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
146
140
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
147
141
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
@@ -168,17 +162,15 @@ struct ggml_metal_context {
|
|
168
162
|
|
169
163
|
id<MTLDevice> device;
|
170
164
|
id<MTLCommandQueue> queue;
|
171
|
-
id<MTLLibrary> library;
|
172
165
|
|
173
166
|
dispatch_queue_t d_queue;
|
174
167
|
|
175
|
-
|
176
|
-
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
177
|
-
|
178
|
-
struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
|
168
|
+
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
|
179
169
|
|
180
170
|
bool support_simdgroup_reduction;
|
181
171
|
bool support_simdgroup_mm;
|
172
|
+
|
173
|
+
bool should_capture_next_compute;
|
182
174
|
};
|
183
175
|
|
184
176
|
// MSL code
|
@@ -238,32 +230,28 @@ static void * ggml_metal_host_malloc(size_t n) {
|
|
238
230
|
static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
239
231
|
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
|
240
232
|
|
241
|
-
|
242
|
-
NSString * s;
|
243
|
-
|
244
|
-
#if TARGET_OS_OSX
|
233
|
+
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
|
245
234
|
// Show all the Metal device instances in the system
|
246
235
|
NSArray * devices = MTLCopyAllDevices();
|
247
|
-
for (device in devices) {
|
248
|
-
s
|
249
|
-
GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
|
236
|
+
for (id<MTLDevice> device in devices) {
|
237
|
+
GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
|
250
238
|
}
|
239
|
+
[devices release]; // since it was created by a *Copy* C method
|
251
240
|
#endif
|
252
241
|
|
253
242
|
// Pick and show default Metal device
|
254
|
-
device = MTLCreateSystemDefaultDevice();
|
255
|
-
s
|
256
|
-
GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
|
243
|
+
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
244
|
+
GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
257
245
|
|
258
246
|
// Configure context
|
259
247
|
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
|
260
248
|
ctx->device = device;
|
261
249
|
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
262
250
|
ctx->queue = [ctx->device newCommandQueue];
|
263
|
-
ctx->n_buffers = 0;
|
264
|
-
|
265
251
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
266
252
|
|
253
|
+
id<MTLLibrary> metal_library;
|
254
|
+
|
267
255
|
// load library
|
268
256
|
{
|
269
257
|
NSBundle * bundle = nil;
|
@@ -278,7 +266,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
278
266
|
// pre-compiled library found
|
279
267
|
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
280
268
|
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
|
281
|
-
|
269
|
+
metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
|
270
|
+
if (error) {
|
271
|
+
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
272
|
+
return NULL;
|
273
|
+
}
|
282
274
|
} else {
|
283
275
|
GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
|
284
276
|
|
@@ -303,27 +295,25 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
303
295
|
return NULL;
|
304
296
|
}
|
305
297
|
|
306
|
-
|
307
|
-
|
298
|
+
@autoreleasepool {
|
299
|
+
// dictionary of preprocessor macros
|
300
|
+
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
308
301
|
|
309
302
|
#ifdef GGML_QKK_64
|
310
|
-
|
303
|
+
prep[@"QK_K"] = @(64);
|
311
304
|
#endif
|
312
305
|
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
//[options setFastMathEnabled:false];
|
306
|
+
MTLCompileOptions* options = [MTLCompileOptions new];
|
307
|
+
options.preprocessorMacros = prep;
|
317
308
|
|
318
|
-
|
319
|
-
|
320
|
-
[options release];
|
321
|
-
[prep release];
|
322
|
-
}
|
309
|
+
//[options setFastMathEnabled:false];
|
323
310
|
|
324
|
-
|
325
|
-
|
326
|
-
|
311
|
+
metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
|
312
|
+
if (error) {
|
313
|
+
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
314
|
+
return NULL;
|
315
|
+
}
|
316
|
+
}
|
327
317
|
}
|
328
318
|
}
|
329
319
|
|
@@ -367,6 +357,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
367
357
|
GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
|
368
358
|
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
369
359
|
|
360
|
+
ctx->should_capture_next_compute = false;
|
361
|
+
|
370
362
|
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
371
363
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
372
364
|
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
|
@@ -383,8 +375,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
383
375
|
{
|
384
376
|
NSError * error = nil;
|
385
377
|
|
386
|
-
for (int i = 0; i <
|
387
|
-
ctx->kernels[i].function = nil;
|
378
|
+
for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
|
388
379
|
ctx->kernels[i].pipeline = nil;
|
389
380
|
}
|
390
381
|
|
@@ -396,10 +387,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
396
387
|
#define GGML_METAL_ADD_KERNEL(e, name, supported) \
|
397
388
|
if (supported) { \
|
398
389
|
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
|
399
|
-
|
400
|
-
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:
|
390
|
+
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
|
391
|
+
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
|
392
|
+
[metal_function release]; \
|
401
393
|
if (error) { \
|
402
394
|
GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
395
|
+
[metal_library release]; \
|
403
396
|
return NULL; \
|
404
397
|
} \
|
405
398
|
} else { \
|
@@ -439,6 +432,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
439
432
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
440
433
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
441
434
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
435
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
442
436
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
443
437
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
444
438
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
@@ -460,6 +454,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
460
454
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
|
461
455
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
462
456
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
457
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
463
458
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
464
459
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
465
460
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
@@ -477,6 +472,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
477
472
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
478
473
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
479
474
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
475
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
480
476
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
481
477
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
482
478
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
@@ -491,6 +487,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
491
487
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
492
488
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
493
489
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
490
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
494
491
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
495
492
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
496
493
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
@@ -505,10 +502,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
505
502
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
506
503
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
507
504
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
505
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
508
506
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
509
507
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
510
508
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
511
509
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
510
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
512
511
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
513
512
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
514
513
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
@@ -528,27 +527,17 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
528
527
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
529
528
|
}
|
530
529
|
|
530
|
+
[metal_library release];
|
531
531
|
return ctx;
|
532
532
|
}
|
533
533
|
|
534
534
|
static void ggml_metal_free(struct ggml_metal_context * ctx) {
|
535
535
|
GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
|
536
536
|
|
537
|
-
for (int i = 0; i <
|
538
|
-
[ctx->
|
537
|
+
for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
|
538
|
+
[ctx->kernels[i].pipeline release];
|
539
539
|
}
|
540
540
|
|
541
|
-
for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
|
542
|
-
if (ctx->kernels[i].pipeline) {
|
543
|
-
[ctx->kernels[i].pipeline release];
|
544
|
-
}
|
545
|
-
|
546
|
-
if (ctx->kernels[i].function) {
|
547
|
-
[ctx->kernels[i].function release];
|
548
|
-
}
|
549
|
-
}
|
550
|
-
|
551
|
-
[ctx->library release];
|
552
541
|
[ctx->queue release];
|
553
542
|
[ctx->device release];
|
554
543
|
|
@@ -580,51 +569,30 @@ struct ggml_backend_metal_buffer_context {
|
|
580
569
|
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
581
570
|
// Metal buffer based on the host memory pointer
|
582
571
|
//
|
583
|
-
static id<MTLBuffer> ggml_metal_get_buffer(struct
|
572
|
+
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
|
584
573
|
//GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
|
585
574
|
|
586
575
|
const int64_t tsize = ggml_nbytes(t);
|
587
576
|
|
588
577
|
ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
|
589
578
|
|
590
|
-
|
591
|
-
if (buffer && buffer->buft == ggml_backend_metal_buffer_type()) {
|
592
|
-
struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
|
593
|
-
|
594
|
-
// find the view that contains the tensor fully
|
595
|
-
for (int i = 0; i < buf_ctx->n_buffers; ++i) {
|
596
|
-
const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
|
597
|
-
|
598
|
-
//GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
|
599
|
-
if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
|
600
|
-
*offs = (size_t) ioffs;
|
601
|
-
|
602
|
-
//GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
|
603
|
-
|
604
|
-
return buf_ctx->buffers[i].metal;
|
605
|
-
}
|
606
|
-
}
|
607
|
-
|
608
|
-
GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
|
609
|
-
|
610
|
-
return nil;
|
611
|
-
}
|
579
|
+
struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
|
612
580
|
|
613
581
|
// find the view that contains the tensor fully
|
614
|
-
for (int i = 0; i <
|
615
|
-
const int64_t ioffs = (int64_t) t->data - (int64_t)
|
582
|
+
for (int i = 0; i < buf_ctx->n_buffers; ++i) {
|
583
|
+
const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
|
616
584
|
|
617
|
-
//GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld,
|
618
|
-
if (ioffs >= 0 && ioffs + tsize <= (int64_t)
|
585
|
+
//GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
|
586
|
+
if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
|
619
587
|
*offs = (size_t) ioffs;
|
620
588
|
|
621
|
-
//GGML_METAL_LOG_INFO("%s:
|
589
|
+
//GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
|
622
590
|
|
623
|
-
return
|
591
|
+
return buf_ctx->buffers[i].metal;
|
624
592
|
}
|
625
593
|
}
|
626
594
|
|
627
|
-
GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
|
595
|
+
GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
|
628
596
|
|
629
597
|
return nil;
|
630
598
|
}
|
@@ -664,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
664
632
|
case GGML_OP_ALIBI:
|
665
633
|
case GGML_OP_ROPE:
|
666
634
|
case GGML_OP_IM2COL:
|
635
|
+
return true;
|
636
|
+
case GGML_OP_POOL_1D:
|
637
|
+
case GGML_OP_POOL_2D:
|
638
|
+
return false;
|
667
639
|
case GGML_OP_UPSCALE:
|
668
640
|
case GGML_OP_PAD:
|
669
641
|
case GGML_OP_ARGSORT:
|
@@ -671,7 +643,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
671
643
|
return true;
|
672
644
|
case GGML_OP_MUL_MAT:
|
673
645
|
case GGML_OP_MUL_MAT_ID:
|
674
|
-
return ctx->support_simdgroup_reduction
|
646
|
+
return ctx->support_simdgroup_reduction &&
|
647
|
+
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
|
675
648
|
case GGML_OP_CPY:
|
676
649
|
case GGML_OP_DUP:
|
677
650
|
case GGML_OP_CONT:
|
@@ -713,7 +686,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
713
686
|
static bool ggml_metal_graph_compute(
|
714
687
|
struct ggml_metal_context * ctx,
|
715
688
|
struct ggml_cgraph * gf) {
|
716
|
-
@autoreleasepool {
|
717
689
|
|
718
690
|
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
719
691
|
edesc.dispatchType = MTLDispatchTypeSerial;
|
@@ -725,6 +697,20 @@ static bool ggml_metal_graph_compute(
|
|
725
697
|
const int n_cb = ctx->n_cb;
|
726
698
|
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
727
699
|
|
700
|
+
const bool should_capture = ctx->should_capture_next_compute;
|
701
|
+
if (should_capture) {
|
702
|
+
ctx->should_capture_next_compute = false;
|
703
|
+
|
704
|
+
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
|
705
|
+
descriptor.captureObject = ctx->queue;
|
706
|
+
|
707
|
+
NSError * error = nil;
|
708
|
+
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
|
709
|
+
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
|
710
|
+
GGML_ASSERT(!"capture failed");
|
711
|
+
}
|
712
|
+
}
|
713
|
+
|
728
714
|
id<MTLCommandBuffer> command_buffer_builder[n_cb];
|
729
715
|
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
730
716
|
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
@@ -733,6 +719,7 @@ static bool ggml_metal_graph_compute(
|
|
733
719
|
// enqueue the command buffers in order to specify their execution order
|
734
720
|
[command_buffer enqueue];
|
735
721
|
}
|
722
|
+
|
736
723
|
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
|
737
724
|
|
738
725
|
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
|
@@ -779,9 +766,9 @@ static bool ggml_metal_graph_compute(
|
|
779
766
|
GGML_ASSERT(!"unsupported op");
|
780
767
|
}
|
781
768
|
|
782
|
-
|
783
|
-
|
784
|
-
|
769
|
+
if (should_capture) {
|
770
|
+
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
|
771
|
+
}
|
785
772
|
|
786
773
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
787
774
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
@@ -817,9 +804,9 @@ static bool ggml_metal_graph_compute(
|
|
817
804
|
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
818
805
|
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
|
819
806
|
|
820
|
-
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(
|
821
|
-
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(
|
822
|
-
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(
|
807
|
+
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
|
808
|
+
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
|
809
|
+
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
823
810
|
|
824
811
|
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
825
812
|
//if (src0) {
|
@@ -1308,6 +1295,7 @@ static bool ggml_metal_graph_compute(
|
|
1308
1295
|
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
1309
1296
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
1310
1297
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
1298
|
+
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
1311
1299
|
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
1312
1300
|
}
|
1313
1301
|
|
@@ -1436,6 +1424,12 @@ static bool ggml_metal_graph_compute(
|
|
1436
1424
|
nth1 = 16;
|
1437
1425
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
1438
1426
|
} break;
|
1427
|
+
case GGML_TYPE_IQ3_XXS:
|
1428
|
+
{
|
1429
|
+
nth0 = 4;
|
1430
|
+
nth1 = 16;
|
1431
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
1432
|
+
} break;
|
1439
1433
|
default:
|
1440
1434
|
{
|
1441
1435
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
@@ -1478,6 +1472,11 @@ static bool ggml_metal_graph_compute(
|
|
1478
1472
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1479
1473
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1480
1474
|
}
|
1475
|
+
else if (src0t == GGML_TYPE_IQ3_XXS) {
|
1476
|
+
const int mem_size = 256*4+128;
|
1477
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1478
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1479
|
+
}
|
1481
1480
|
else if (src0t == GGML_TYPE_Q4_K) {
|
1482
1481
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1483
1482
|
}
|
@@ -1572,6 +1571,7 @@ static bool ggml_metal_graph_compute(
|
|
1572
1571
|
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
|
1573
1572
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
1574
1573
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
1574
|
+
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
1575
1575
|
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
1576
1576
|
}
|
1577
1577
|
|
@@ -1601,7 +1601,7 @@ static bool ggml_metal_graph_compute(
|
|
1601
1601
|
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
1602
1602
|
|
1603
1603
|
size_t offs_src_cur = 0;
|
1604
|
-
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(
|
1604
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
|
1605
1605
|
|
1606
1606
|
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
1607
1607
|
}
|
@@ -1703,6 +1703,12 @@ static bool ggml_metal_graph_compute(
|
|
1703
1703
|
nth1 = 16;
|
1704
1704
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
|
1705
1705
|
} break;
|
1706
|
+
case GGML_TYPE_IQ3_XXS:
|
1707
|
+
{
|
1708
|
+
nth0 = 4;
|
1709
|
+
nth1 = 16;
|
1710
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
1711
|
+
} break;
|
1706
1712
|
default:
|
1707
1713
|
{
|
1708
1714
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
@@ -1746,7 +1752,7 @@ static bool ggml_metal_graph_compute(
|
|
1746
1752
|
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
1747
1753
|
|
1748
1754
|
size_t offs_src_cur = 0;
|
1749
|
-
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(
|
1755
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
|
1750
1756
|
|
1751
1757
|
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
1752
1758
|
}
|
@@ -1761,6 +1767,11 @@ static bool ggml_metal_graph_compute(
|
|
1761
1767
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1762
1768
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1763
1769
|
}
|
1770
|
+
else if (src2t == GGML_TYPE_IQ3_XXS) {
|
1771
|
+
const int mem_size = 256*4+128;
|
1772
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1773
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1774
|
+
}
|
1764
1775
|
else if (src2t == GGML_TYPE_Q4_K) {
|
1765
1776
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1766
1777
|
}
|
@@ -1801,6 +1812,7 @@ static bool ggml_metal_graph_compute(
|
|
1801
1812
|
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
|
1802
1813
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
1803
1814
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
1815
|
+
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
|
1804
1816
|
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
1805
1817
|
default: GGML_ASSERT(false && "not implemented");
|
1806
1818
|
}
|
@@ -2009,7 +2021,7 @@ static bool ggml_metal_graph_compute(
|
|
2009
2021
|
{
|
2010
2022
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
2011
2023
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
2012
|
-
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
2024
|
+
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
2013
2025
|
|
2014
2026
|
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
2015
2027
|
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
@@ -2017,6 +2029,7 @@ static bool ggml_metal_graph_compute(
|
|
2017
2029
|
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
2018
2030
|
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
2019
2031
|
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
2032
|
+
|
2020
2033
|
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
2021
2034
|
|
2022
2035
|
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
@@ -2037,8 +2050,8 @@ static bool ggml_metal_graph_compute(
|
|
2037
2050
|
|
2038
2051
|
id<MTLComputePipelineState> pipeline = nil;
|
2039
2052
|
|
2040
|
-
switch (
|
2041
|
-
case GGML_TYPE_F32:
|
2053
|
+
switch (dst->type) {
|
2054
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
|
2042
2055
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
2043
2056
|
default: GGML_ASSERT(false);
|
2044
2057
|
};
|
@@ -2231,15 +2244,12 @@ static bool ggml_metal_graph_compute(
|
|
2231
2244
|
}
|
2232
2245
|
}
|
2233
2246
|
|
2234
|
-
|
2235
|
-
|
2236
|
-
|
2247
|
+
if (should_capture) {
|
2248
|
+
[encoder popDebugGroup];
|
2249
|
+
}
|
2237
2250
|
}
|
2238
2251
|
|
2239
|
-
|
2240
|
-
[encoder endEncoding];
|
2241
|
-
encoder = nil;
|
2242
|
-
}
|
2252
|
+
[encoder endEncoding];
|
2243
2253
|
|
2244
2254
|
[command_buffer commit];
|
2245
2255
|
});
|
@@ -2258,8 +2268,11 @@ static bool ggml_metal_graph_compute(
|
|
2258
2268
|
}
|
2259
2269
|
}
|
2260
2270
|
|
2261
|
-
|
2271
|
+
if (should_capture) {
|
2272
|
+
[[MTLCaptureManager sharedCaptureManager] stopCapture];
|
2262
2273
|
}
|
2274
|
+
|
2275
|
+
return true;
|
2263
2276
|
}
|
2264
2277
|
|
2265
2278
|
////////////////////////////////////////////////////////////////////////////////
|
@@ -2427,6 +2440,16 @@ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backen
|
|
2427
2440
|
UNUSED(buft);
|
2428
2441
|
}
|
2429
2442
|
|
2443
|
+
GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
2444
|
+
id<MTLDevice> device = ggml_backend_metal_get_device();
|
2445
|
+
size_t max_size = device.maxBufferLength;
|
2446
|
+
ggml_backend_metal_free_device();
|
2447
|
+
|
2448
|
+
return max_size;
|
2449
|
+
|
2450
|
+
UNUSED(buft);
|
2451
|
+
}
|
2452
|
+
|
2430
2453
|
GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
|
2431
2454
|
return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
|
2432
2455
|
|
@@ -2445,6 +2468,7 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
|
2445
2468
|
/* .get_name = */ ggml_backend_metal_buffer_type_get_name,
|
2446
2469
|
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
|
2447
2470
|
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
|
2471
|
+
/* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
|
2448
2472
|
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
2449
2473
|
/* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
|
2450
2474
|
/* .is_host = */ ggml_backend_metal_buffer_type_is_host,
|
@@ -2619,6 +2643,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
|
2619
2643
|
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
2620
2644
|
}
|
2621
2645
|
|
2646
|
+
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
|
2647
|
+
GGML_ASSERT(ggml_backend_is_metal(backend));
|
2648
|
+
|
2649
|
+
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
2650
|
+
ctx->should_capture_next_compute = true;
|
2651
|
+
}
|
2652
|
+
|
2622
2653
|
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
2623
2654
|
|
2624
2655
|
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
|