llama_cpp 0.7.0 → 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -73,6 +73,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_f16);
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_1);
  GGML_METAL_DECL_KERNEL(get_rows_q8_0);
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -81,22 +83,26 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
  GGML_METAL_DECL_KERNEL(rms_norm);
  GGML_METAL_DECL_KERNEL(norm);
- GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
@@ -109,6 +115,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DECL_KERNEL(concat);
+ GGML_METAL_DECL_KERNEL(sqr);

  #undef GGML_METAL_DECL_KERNEL
  };
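
Note: these hunks are from ggml-metal.m (llama.cpp's Metal backend, as shipped in the llama_cpp package). Each GGML_METAL_DECL_KERNEL entry pairs a kernel name with a Metal function handle and a compiled pipeline state on the context struct; a minimal sketch of the expansion (upstream defines the macro along these lines, shown for orientation rather than as the exact shipped code):

    #define GGML_METAL_DECL_KERNEL(name) \
        id<MTLFunction>             function_##name; \
        id<MTLComputePipelineState> pipeline_##name

So the new get_rows_q5_0 line above adds function_get_rows_q5_0 and pipeline_get_rows_q5_0 fields, which the GGML_METAL_ADD_KERNEL calls later in this diff populate from the compiled library.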
@@ -183,56 +191,44 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

- #ifdef GGML_SWIFT
- // load the default.metallib file
+ // load library
  {
- NSError * error = nil;
-
- NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
- NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
- NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
- NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
- NSURL * libURL = [NSURL fileURLWithPath:libPath];
-
- // Load the metallib file into a Metal library
- ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
-
- if (error) {
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
- return NULL;
- }
- }
+ NSBundle * bundle = nil;
+ #ifdef SWIFT_PACKAGE
+ bundle = SWIFTPM_MODULE_BUNDLE;
  #else
- UNUSED(msl_library_source);
-
- // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
- {
+ bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+ #endif
  NSError * error = nil;
+ NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
+ if (libPath != nil) {
+ NSURL * libURL = [NSURL fileURLWithPath:libPath];
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
+ ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+ } else {
+ GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
+
+ NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
+ NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
+ if (error) {
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ return NULL;
+ }

- //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
- NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
- NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
- GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);
-
- NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
- if (error) {
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
- return NULL;
- }
-
+ MTLCompileOptions* options = nil;
  #ifdef GGML_QKK_64
- MTLCompileOptions* options = [MTLCompileOptions new];
- options.preprocessorMacros = @{ @"QK_K" : @(64) };
- ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
- #else
- ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+ options = [MTLCompileOptions new];
+ options.preprocessorMacros = @{ @"QK_K" : @(64) };
  #endif
+ ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+ }
+
  if (error) {
  GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }
  }
- #endif

  // load kernels
  {
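
Note: the net effect of the hunk above is that the compile-time GGML_SWIFT split is replaced by a single runtime path: prefer a prebuilt default.metallib from the bundle, fall back to compiling ggml-metal.metal on the fly, with the GGML_QKK_64 build flag now feeding one shared newLibraryWithSource: call. Condensed from the added lines (error handling elided):

    NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
    if (libPath != nil) {
        // fast path: load the precompiled Metal library
        ctx->library = [ctx->device newLibraryWithURL:[NSURL fileURLWithPath:libPath] error:&error];
    } else {
        // fallback: compile the shader source at runtime, optionally with QK_K=64
        ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
    }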
@@ -264,6 +260,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(get_rows_f16);
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_1);
  GGML_METAL_ADD_KERNEL(get_rows_q8_0);
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -272,40 +270,61 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
  GGML_METAL_ADD_KERNEL(rms_norm);
  GGML_METAL_ADD_KERNEL(norm);
- GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+ GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ }
  GGML_METAL_ADD_KERNEL(rope_f32);
  GGML_METAL_ADD_KERNEL(rope_f16);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+ GGML_METAL_ADD_KERNEL(concat);
+ GGML_METAL_ADD_KERNEL(sqr);

  #undef GGML_METAL_ADD_KERNEL
  }

- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
  #if TARGET_OS_OSX
+ // print MTL GPU family:
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+ // determine max supported GPU family
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
  GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
  if (ctx->device.maxTransferRate != 0) {
  GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
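
Note: two details in the hunk above. First, the mul_mm_* pipelines are now created only when the device reports MTLGPUFamilyApple7 (A14/M1 and newer); as the comment in the GGML_OP_MUL_MAT hunk below puts it, older A-chips and AMD GPUs fall back to the matrix-vector kernels. The family gate is presumably needed because the matrix-matrix kernels build on Metal's simdgroup matrix intrinsics, which require Apple7; an illustrative MSL fragment in that style (a sketch, not the shipped shader):

    simdgroup_float8x8 ma, mb;
    simdgroup_float8x8 mc = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_load(ma, a_ptr, lda);            // 8x8 tile of src0
    simdgroup_load(mb, b_ptr, ldb);            // 8x8 tile of src1
    simdgroup_multiply_accumulate(mc, ma, mb, mc);

Second, the descending family probe works because the MTLGPUFamilyApple* constants are consecutive integers (MTLGPUFamilyApple1 is 1001 in the Metal headers), so the first supportsFamily: hit while counting down from MTLGPUFamilyApple1 + 20 is the highest family the device supports; on an M1 this logs MTLGPUFamilyApple7 (1007).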
@@ -339,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(get_rows_f16);
  GGML_METAL_DEL_KERNEL(get_rows_q4_0);
  GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DEL_KERNEL(get_rows_q5_0);
+ GGML_METAL_DEL_KERNEL(get_rows_q5_1);
  GGML_METAL_DEL_KERNEL(get_rows_q8_0);
  GGML_METAL_DEL_KERNEL(get_rows_q2_K);
  GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -347,34 +368,42 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
  GGML_METAL_DEL_KERNEL(rms_norm);
  GGML_METAL_DEL_KERNEL(norm);
- GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
- GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+ GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ }
  GGML_METAL_DEL_KERNEL(rope_f32);
  GGML_METAL_DEL_KERNEL(rope_f16);
  GGML_METAL_DEL_KERNEL(alibi_f32);
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DEL_KERNEL(concat);
+ GGML_METAL_DEL_KERNEL(sqr);

  #undef GGML_METAL_DEL_KERNEL

@@ -431,7 +460,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
  for (int i = 0; i < ctx->n_buffers; ++i) {
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;

- //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
+ //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
  *offs = (size_t) ioffs;

@@ -766,6 +795,44 @@ void ggml_metal_graph_compute(
  {
  // noop
  } break;
+ case GGML_OP_CONCAT:
+ {
+ const int64_t nb = ne00;
+
+ [encoder setComputePipelineState:ctx->pipeline_concat];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
  case GGML_OP_ADD:
  {
  GGML_ASSERT(ggml_is_contiguous(src0));
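
Note: the dispatch geometry for the new concat op is a grid of ne1 x ne2 x ne3 threadgroups with MIN(1024, ne0) threads each, i.e. one threadgroup per output row along dimension 0, threads striding across the row. A plausible shape of the per-threadgroup loop (a sketch; the shipped kernel lives in ggml-metal.metal):

    // threadgroup (tgpig.x, tgpig.y, tgpig.z) owns one dst row along dim 0
    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
        // copy from src0 or src1, depending on which side of the
        // concatenation boundary this output element falls on
    }

The extra scalar nb = ne00 passed at index 27 rides along with the usual shape/stride arguments.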
@@ -861,9 +928,10 @@ void ggml_metal_graph_compute(
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];

- const int64_t n = ggml_nelements(dst)/4;
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);

- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(gf->nodes[i])) {
@@ -873,9 +941,10 @@ void ggml_metal_graph_compute(
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

- const int64_t n = ggml_nelements(dst)/4;
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);

- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case GGML_UNARY_OP_RELU:
  {
@@ -893,9 +962,10 @@ void ggml_metal_graph_compute(
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

- const int64_t n = ggml_nelements(dst)/4;
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);

- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  default:
  {
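
Note: the three hunks above (scale plus the GELU/SiLU unary cases) all make the same change: 0.7.0 computed n = ggml_nelements(dst)/4 silently, while 0.8.0 asserts n % 4 == 0 first and then dispatches n/4 threadgroups. The division by 4 exists because these element-wise kernels consume float4 vectors, roughly in this pattern (a sketch of the idiom, not the shipped source):

    kernel void kernel_scale(
            device const float4 * src0,
            device       float4 * dst,
            constant     float  & scale,
            uint tpig[[thread_position_in_grid]]) {
        dst[tpig] = src0[tpig] * scale;   // one thread covers 4 elements
    }

Previously a tensor whose element count was not a multiple of 4 would have had its tail elements silently skipped by the integer division; now it trips the assert instead.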
@@ -903,6 +973,17 @@ void ggml_metal_graph_compute(
  GGML_ASSERT(false);
  }
  } break;
+ case GGML_OP_SQR:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ [encoder setComputePipelineState:ctx->pipeline_sqr];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  case GGML_OP_SOFT_MAX:
  {
  const int nth = MIN(32, ne00);
@@ -944,26 +1025,53 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_MUL_MAT:
  {
- // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
-
  GGML_ASSERT(ne00 == ne10);
- // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
- uint gqa = ne12/ne02;
  GGML_ASSERT(ne03 == ne13);

+ const uint gqa = ne12/ne02;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = 1;
+
+ #if 0
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
+ // these numbers do not translate to other devices or model sizes
+ // TODO: need to find a better approach
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+ switch (src0t) {
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+ case GGML_TYPE_Q5_0: // not tested yet
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
+ default: ne11_mm_min = 1; break;
+ }
+ }
+ #endif
+
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- if (!ggml_is_transposed(src0) &&
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ !ggml_is_transposed(src0) &&
  !ggml_is_transposed(src1) &&
  src1t == GGML_TYPE_F32 &&
- [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
- ne00%32 == 0 &&
- ne11 > 2) {
+ ne00 % 32 == 0 && ne00 >= 64 &&
+ ne11 > ne11_mm_min) {
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
  switch (src0->type) {
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
  case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
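
Note: this hunk replaces the fixed ne11 > 2 switch-over with a per-type break-even point and a stricter shape requirement (ne00 % 32 == 0 && ne00 >= 64) for the matrix-matrix path. The gate, condensed into one expression (same logic as the added lines above):

    const bool use_mul_mm =
        [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) &&
        src1t == GGML_TYPE_F32 &&
        ne00 % 32 == 0 && ne00 >= 64 &&
        ne11 > ne11_mm_min;

The #if 0 table records crossover points measured on an M2 Ultra (for Q4_0/Q4_1, for example, the matrix-vector kernel stayed faster up to ne11 = 15), but it ships disabled because those numbers do not transfer to other devices or model sizes, so ne11_mm_min is effectively 1 and any ne11 > 1 takes the matrix-matrix path on Apple7+ hardware.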
@@ -987,17 +1095,18 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
  int nth0 = 32;
  int nth1 = 1;
  int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

  // use custom matrix x vector kernel
  switch (src0t) {
  case GGML_TYPE_F32:
  {
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
  nrows = 4;
  } break;
  case GGML_TYPE_F16:
@@ -1005,12 +1114,12 @@ void ggml_metal_graph_compute(
  nth0 = 32;
  nth1 = 1;
  if (ne11 * ne12 < 4) {
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
  nrows = ne11;
  } else {
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
  nrows = 4;
  }
  } break;
@@ -1021,7 +1130,7 @@ void ggml_metal_graph_compute(

  nth0 = 8;
  nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
  } break;
  case GGML_TYPE_Q4_1:
  {
@@ -1030,7 +1139,25 @@ void ggml_metal_graph_compute(

  nth0 = 8;
  nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
  } break;
  case GGML_TYPE_Q8_0:
  {
@@ -1039,7 +1166,7 @@ void ggml_metal_graph_compute(

  nth0 = 8;
  nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
  } break;
  case GGML_TYPE_Q2_K:
  {
  case GGML_TYPE_Q2_K:
1045
1172
  {
@@ -1048,7 +1175,7 @@ void ggml_metal_graph_compute(
1048
1175
 
1049
1176
  nth0 = 2;
1050
1177
  nth1 = 32;
1051
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
1178
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
1052
1179
  } break;
1053
1180
  case GGML_TYPE_Q3_K:
1054
1181
  {
@@ -1057,7 +1184,7 @@ void ggml_metal_graph_compute(
1057
1184
 
1058
1185
  nth0 = 2;
1059
1186
  nth1 = 32;
1060
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
1187
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
1061
1188
  } break;
1062
1189
  case GGML_TYPE_Q4_K:
1063
1190
  {
@@ -1066,7 +1193,7 @@ void ggml_metal_graph_compute(
1066
1193
 
1067
1194
  nth0 = 4; //1;
1068
1195
  nth1 = 8; //32;
1069
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
1196
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
1070
1197
  } break;
1071
1198
  case GGML_TYPE_Q5_K:
1072
1199
  {
@@ -1075,7 +1202,7 @@ void ggml_metal_graph_compute(
1075
1202
 
1076
1203
  nth0 = 2;
1077
1204
  nth1 = 32;
1078
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
1205
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
1079
1206
  } break;
1080
1207
  case GGML_TYPE_Q6_K:
1081
1208
  {
@@ -1084,7 +1211,7 @@ void ggml_metal_graph_compute(
1084
1211
 
1085
1212
  nth0 = 2;
1086
1213
  nth1 = 32;
1087
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
1214
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
1088
1215
  } break;
1089
1216
  default:
1090
1217
  {
@@ -1112,8 +1239,9 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
- src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q4_K) {
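
Note: q5_0 and q5_1 join the quantization types whose matrix-vector kernels are dispatched with a (ne01 + 7)/8 threadgroup count along dimension 0, implying 8 rows of src0 per threadgroup:

    // ceil(ne01 / 8) threadgroups of nth0 x nth1 threads, one group per 8 output rows
    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12)
            threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];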
@@ -1144,6 +1272,8 @@ void ggml_metal_graph_compute(
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
  case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
@@ -1166,6 +1296,8 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_RMS_NORM:
  {
+ GGML_ASSERT(ne00 % 4 == 0);
+
  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));

@@ -1208,7 +1340,7 @@ void ggml_metal_graph_compute(

  const int nth = MIN(1024, ne00);

- const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+ //const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_head = ((int32_t *) dst->op_params)[1];
  float max_bias;
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -1371,3 +1503,140 @@ void ggml_metal_graph_compute(

  }
  }
+
+ ////////////////////////////////////////////////////////////////////////////////
+
+ // backend interface
+
+ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+ return "Metal";
+
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_free(ggml_backend_t backend) {
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ ggml_metal_free(ctx);
+ free(backend);
+ }
+
+ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
+ return (void *)buffer->context;
+ }
+
+ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ free(buffer->context);
+ UNUSED(buffer);
+ }
+
+ static struct ggml_backend_buffer_i metal_backend_buffer_i = {
+ /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_metal_buffer_get_base,
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+ /* .init_tensor = */ NULL, // no initialization required
+ /* .free_tensor = */ NULL, // no cleanup required
+ };
+
+ static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+ void * data = ggml_metal_host_malloc(size);
+
+ // TODO: set proper name of the buffers
+ ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+
+ return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+ }
+
+ static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
+ return 32;
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy((char *)tensor->data + offset, data, size);
+
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
+
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
+
+ ggml_metal_graph_compute(metal_ctx, cgraph);
+ }
+
+ static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ return true;
+ UNUSED(backend);
+ UNUSED(op);
+ }
+
+ static struct ggml_backend_i metal_backend_i = {
+ /* .get_name = */ ggml_backend_metal_name,
+ /* .free = */ ggml_backend_metal_free,
+ /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_metal_get_alignment,
+ /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
+ /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
+ /* .synchronize = */ ggml_backend_metal_synchronize,
+ /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from,
+ /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to,
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_metal_graph_compute,
+ /* .supports_op = */ ggml_backend_metal_supports_op,
+ };
+
+ ggml_backend_t ggml_backend_metal_init(void) {
+ struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+
+ ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+
+ ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
+
+ *metal_backend = (struct ggml_backend) {
+ /* .interface = */ metal_backend_i,
+ /* .context = */ ctx,
+ };
+
+ return metal_backend;
+ }
+
+ bool ggml_backend_is_metal(ggml_backend_t backend) {
+ return backend->iface.get_name == ggml_backend_metal_name;
+ }
+
+ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+ ggml_metal_set_n_cb(ctx, n_cb);
+ }