llama_cpp 0.7.0 → 0.8.0

ggml-metal.m CHANGED
@@ -73,6 +73,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_f16);
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_1);
  GGML_METAL_DECL_KERNEL(get_rows_q8_0);
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -81,22 +83,26 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
  GGML_METAL_DECL_KERNEL(rms_norm);
  GGML_METAL_DECL_KERNEL(norm);
- GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
@@ -109,6 +115,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DECL_KERNEL(concat);
+ GGML_METAL_DECL_KERNEL(sqr);

  #undef GGML_METAL_DECL_KERNEL
  };
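
Note: the GGML_METAL_DECL_KERNEL macro used throughout this struct is defined earlier in ggml-metal.m, outside the diff context. As a hedged sketch (quoted from the surrounding file rather than this diff), each declaration expands to two members per kernel, a function handle plus its compiled pipeline state:

    // assumed expansion of the macro defined above the struct in ggml-metal.m
    #define GGML_METAL_DECL_KERNEL(name) \
        id<MTLFunction>             function_##name; \
        id<MTLComputePipelineState> pipeline_##name

    // e.g. GGML_METAL_DECL_KERNEL(get_rows_q5_0) adds:
    //   id<MTLFunction>             function_get_rows_q5_0;
    //   id<MTLComputePipelineState> pipeline_get_rows_q5_0;

This is why every kernel added in this release (the q5_0/q5_1 variants, concat, sqr) appears three times in the diff: declared here, created via GGML_METAL_ADD_KERNEL in ggml_metal_init, and released via GGML_METAL_DEL_KERNEL in ggml_metal_free.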
@@ -183,56 +191,44 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

- #ifdef GGML_SWIFT
- // load the default.metallib file
+ // load library
  {
- NSError * error = nil;
-
- NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
- NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
- NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
- NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
- NSURL * libURL = [NSURL fileURLWithPath:libPath];
-
- // Load the metallib file into a Metal library
- ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
-
- if (error) {
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
- return NULL;
- }
- }
+ NSBundle * bundle = nil;
+ #ifdef SWIFT_PACKAGE
+ bundle = SWIFTPM_MODULE_BUNDLE;
  #else
- UNUSED(msl_library_source);
-
- // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
- {
+ bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+ #endif
  NSError * error = nil;
+ NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
+ if (libPath != nil) {
+ NSURL * libURL = [NSURL fileURLWithPath:libPath];
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
+ ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+ } else {
+ GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
+
+ NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
+ NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
+ if (error) {
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ return NULL;
+ }

- //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
- NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
- NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
- GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);
-
- NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
- if (error) {
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
- return NULL;
- }
-
+ MTLCompileOptions* options = nil;
  #ifdef GGML_QKK_64
- MTLCompileOptions* options = [MTLCompileOptions new];
- options.preprocessorMacros = @{ @"QK_K" : @(64) };
- ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
- #else
- ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+ options = [MTLCompileOptions new];
+ options.preprocessorMacros = @{ @"QK_K" : @(64) };
  #endif
+ ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+ }
+
  if (error) {
  GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }
  }
- #endif

  // load kernels
  {
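
The hunk above replaces the compile-time GGML_SWIFT split with a single runtime fallback: look for a precompiled default.metallib in the bundle first, and only compile ggml-metal.metal from source when it is missing. A minimal standalone sketch of that pattern (device/bundle setup assumed; the function name is hypothetical and this is a distillation, not the verbatim upstream code):

    static id<MTLLibrary> load_ggml_library(id<MTLDevice> device, NSBundle * bundle, NSError ** error) {
        NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
        if (libPath != nil) {
            // fast path: a library compiled at ship time (e.g. by a SwiftPM build)
            return [device newLibraryWithURL:[NSURL fileURLWithPath:libPath] error:error];
        }
        // slow path: compile the shader source at startup
        NSString * srcPath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
        NSString * src = [NSString stringWithContentsOfFile:srcPath encoding:NSUTF8StringEncoding error:error];
        if (src == nil) {
            return nil;
        }
        return [device newLibraryWithSource:src options:nil error:error];
    }

Note also that the QK_K=64 preprocessor define (GGML_QKK_64) now applies only on the source-compilation path, via MTLCompileOptions.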
@@ -264,6 +260,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(get_rows_f16);
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_1);
  GGML_METAL_ADD_KERNEL(get_rows_q8_0);
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -272,40 +270,61 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
  GGML_METAL_ADD_KERNEL(rms_norm);
  GGML_METAL_ADD_KERNEL(norm);
- GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+ GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ }
  GGML_METAL_ADD_KERNEL(rope_f32);
  GGML_METAL_ADD_KERNEL(rope_f16);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+ GGML_METAL_ADD_KERNEL(concat);
+ GGML_METAL_ADD_KERNEL(sqr);

  #undef GGML_METAL_ADD_KERNEL
  }

- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
  #if TARGET_OS_OSX
+ // print MTL GPU family:
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+ // determine max supported GPU family
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
  GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
  if (ctx->device.maxTransferRate != 0) {
  GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
@@ -339,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(get_rows_f16);
  GGML_METAL_DEL_KERNEL(get_rows_q4_0);
  GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DEL_KERNEL(get_rows_q5_0);
+ GGML_METAL_DEL_KERNEL(get_rows_q5_1);
  GGML_METAL_DEL_KERNEL(get_rows_q8_0);
  GGML_METAL_DEL_KERNEL(get_rows_q2_K);
  GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -347,34 +368,42 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
  GGML_METAL_DEL_KERNEL(rms_norm);
  GGML_METAL_DEL_KERNEL(norm);
- GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
- GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+ GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ }
  GGML_METAL_DEL_KERNEL(rope_f32);
  GGML_METAL_DEL_KERNEL(rope_f16);
  GGML_METAL_DEL_KERNEL(alibi_f32);
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DEL_KERNEL(concat);
+ GGML_METAL_DEL_KERNEL(sqr);

  #undef GGML_METAL_DEL_KERNEL

@@ -431,7 +460,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
  for (int i = 0; i < ctx->n_buffers; ++i) {
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;

- //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
+ //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
  *offs = (size_t) ioffs;

@@ -766,6 +795,44 @@ void ggml_metal_graph_compute(
  {
  // noop
  } break;
+ case GGML_OP_CONCAT:
+ {
+ const int64_t nb = ne00;
+
+ [encoder setComputePipelineState:ctx->pipeline_concat];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
  case GGML_OP_ADD:
  {
  GGML_ASSERT(ggml_is_contiguous(src0));
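
GGML_OP_CONCAT is new in this release. In ggml, concat joins src0 and src1 along dimension 2 (ne02), which is why the encoder passes both tensors' full shape and stride sets and dispatches one threadgroup per (ne1, ne2, ne3) slice. A plain-C sketch of the intended semantics, under the assumption of contiguous float tensors (the Metal kernel parallelizes this loop nest):

    // reference semantics only, not the Metal kernel itself
    for (int64_t i3 = 0; i3 < ne3; ++i3)
    for (int64_t i2 = 0; i2 < ne2; ++i2)
    for (int64_t i1 = 0; i1 < ne1; ++i1)
    for (int64_t i0 = 0; i0 < ne0; ++i0) {
        dst[((i3*ne2 + i2)*ne1 + i1)*ne0 + i0] = (i2 < ne02)
            ? src0[((i3*ne02 +  i2        )*ne01 + i1)*ne00 + i0]   // first ne02 slices from src0
            : src1[((i3*ne12 + (i2 - ne02))*ne11 + i1)*ne10 + i0];  // remaining slices from src1
    }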
@@ -861,9 +928,10 @@ void ggml_metal_graph_compute(
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];

- const int64_t n = ggml_nelements(dst)/4;
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);

- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(gf->nodes[i])) {
@@ -873,9 +941,10 @@ void ggml_metal_graph_compute(
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

- const int64_t n = ggml_nelements(dst)/4;
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);

- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case GGML_UNARY_OP_RELU:
  {
@@ -893,9 +962,10 @@ void ggml_metal_graph_compute(
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

- const int64_t n = ggml_nelements(dst)/4;
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);

- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  default:
  {
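
The three dispatch changes above (the scale kernel and two unary kernels) compute the same launch size either way: these kernels consume one float4 per thread, so n/4 single-thread threadgroups are launched. The difference is that the old form divided first with truncating division, silently skipping up to three trailing elements for a tensor whose size is not a multiple of 4. The new form makes the contract explicit:

    const int64_t n = ggml_nelements(dst); // total scalar elements
    GGML_ASSERT(n % 4 == 0);               // the float4 kernels have no tail handling
    // dispatch n/4 threadgroups of one thread; each thread covers four elements

The same reasoning explains the new GGML_ASSERT(ne00 % 4 == 0) added to the GGML_OP_RMS_NORM case further down.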
@@ -903,6 +973,17 @@ void ggml_metal_graph_compute(
  GGML_ASSERT(false);
  }
  } break;
+ case GGML_OP_SQR:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ [encoder setComputePipelineState:ctx->pipeline_sqr];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  case GGML_OP_SOFT_MAX:
  {
  const int nth = MIN(32, ne00);
@@ -944,26 +1025,53 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_MUL_MAT:
  {
- // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
-
  GGML_ASSERT(ne00 == ne10);
- // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
- uint gqa = ne12/ne02;
  GGML_ASSERT(ne03 == ne13);

+ const uint gqa = ne12/ne02;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = 1;
+
+ #if 0
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
+ // these numbers do not translate to other devices or model sizes
+ // TODO: need to find a better approach
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+ switch (src0t) {
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+ case GGML_TYPE_Q5_0: // not tested yet
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
+ default: ne11_mm_min = 1; break;
+ }
+ }
+ #endif
+
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- if (!ggml_is_transposed(src0) &&
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ !ggml_is_transposed(src0) &&
  !ggml_is_transposed(src1) &&
  src1t == GGML_TYPE_F32 &&
- [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
- ne00%32 == 0 &&
- ne11 > 2) {
+ ne00 % 32 == 0 && ne00 >= 64 &&
+ ne11 > ne11_mm_min) {
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
  switch (src0->type) {
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
  case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
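
This hunk is the heart of the 0.8.0 Metal changes: the matrix-matrix (mul_mm) path is now gated on MTLGPUFamilyApple7 support, a minimum row width of 64, and a per-type break-even batch size ne11_mm_min (currently 1, with the measured M2 Ultra table kept under #if 0). Distilled into a standalone predicate for readability, a sketch mirroring the diff with the hypothetical name use_mul_mm:

    static bool use_mul_mm(id<MTLDevice> device,
                           const struct ggml_tensor * src0,
                           const struct ggml_tensor * src1,
                           int64_t ne00, int64_t ne11, int ne11_mm_min) {
        return [device supportsFamily:MTLGPUFamilyApple7] && // A14+/M1+ SoCs only
               !ggml_is_transposed(src0)                  &&
               !ggml_is_transposed(src1)                  &&
               src1->type == GGML_TYPE_F32                &&
               ne00 % 32 == 0 && ne00 >= 64               && // ne00 >= 64 is new in 0.8.0
               ne11 > ne11_mm_min;                           // batch above the break-even point
    }

Everything else falls through to the per-type matrix-vector kernels, renamed from mul_mat to mul_mv in this release.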
@@ -987,17 +1095,18 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
  int nth0 = 32;
  int nth1 = 1;
  int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

  // use custom matrix x vector kernel
  switch (src0t) {
  case GGML_TYPE_F32:
  {
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
  nrows = 4;
  } break;
  case GGML_TYPE_F16:
@@ -1005,12 +1114,12 @@ void ggml_metal_graph_compute(
  nth0 = 32;
  nth1 = 1;
  if (ne11 * ne12 < 4) {
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
  nrows = ne11;
  } else {
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
  nrows = 4;
  }
  } break;
@@ -1021,7 +1130,7 @@ void ggml_metal_graph_compute(

  nth0 = 8;
  nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
  } break;
  case GGML_TYPE_Q4_1:
  {
@@ -1030,7 +1139,25 @@ void ggml_metal_graph_compute(

  nth0 = 8;
  nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
  } break;
  case GGML_TYPE_Q8_0:
  {
@@ -1039,7 +1166,7 @@ void ggml_metal_graph_compute(

  nth0 = 8;
  nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
  } break;
  case GGML_TYPE_Q2_K:
  {
@@ -1048,7 +1175,7 @@ void ggml_metal_graph_compute(

  nth0 = 2;
  nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
  } break;
  case GGML_TYPE_Q3_K:
  {
@@ -1057,7 +1184,7 @@ void ggml_metal_graph_compute(

  nth0 = 2;
  nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
  } break;
  case GGML_TYPE_Q4_K:
  {
@@ -1066,7 +1193,7 @@ void ggml_metal_graph_compute(

  nth0 = 4; //1;
  nth1 = 8; //32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
  } break;
  case GGML_TYPE_Q5_K:
  {
@@ -1075,7 +1202,7 @@ void ggml_metal_graph_compute(

  nth0 = 2;
  nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
  } break;
  case GGML_TYPE_Q6_K:
  {
@@ -1084,7 +1211,7 @@ void ggml_metal_graph_compute(

  nth0 = 2;
  nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
  } break;
  default:
  {
@@ -1112,8 +1239,9 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
- src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q4_K) {
@@ -1144,6 +1272,8 @@ void ggml_metal_graph_compute(
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
  case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
@@ -1166,6 +1296,8 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_RMS_NORM:
  {
+ GGML_ASSERT(ne00 % 4 == 0);
+
  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));

@@ -1208,7 +1340,7 @@ void ggml_metal_graph_compute(

  const int nth = MIN(1024, ne00);

- const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+ //const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_head = ((int32_t *) dst->op_params)[1];
  float max_bias;
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -1371,3 +1503,140 @@ void ggml_metal_graph_compute(

  }
  }
+
+ ////////////////////////////////////////////////////////////////////////////////
+
+ // backend interface
+
+ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+ return "Metal";
+
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_free(ggml_backend_t backend) {
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ ggml_metal_free(ctx);
+ free(backend);
+ }
+
+ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
+ return (void *)buffer->context;
+ }
+
+ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ free(buffer->context);
+ UNUSED(buffer);
+ }
+
+ static struct ggml_backend_buffer_i metal_backend_buffer_i = {
+ /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_metal_buffer_get_base,
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+ /* .init_tensor = */ NULL, // no initialization required
+ /* .free_tensor = */ NULL, // no cleanup required
+ };
+
+ static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+ void * data = ggml_metal_host_malloc(size);
+
+ // TODO: set proper name of the buffers
+ ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+
+ return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+ }
+
+ static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
+ return 32;
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy((char *)tensor->data + offset, data, size);
+
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
+
+ UNUSED(backend);
+ }
+
+ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
+
+ ggml_metal_graph_compute(metal_ctx, cgraph);
+ }
+
+ static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ return true;
+ UNUSED(backend);
+ UNUSED(op);
+ }
+
+ static struct ggml_backend_i metal_backend_i = {
+ /* .get_name = */ ggml_backend_metal_name,
+ /* .free = */ ggml_backend_metal_free,
+ /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_metal_get_alignment,
+ /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
+ /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
+ /* .synchronize = */ ggml_backend_metal_synchronize,
+ /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from,
+ /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to,
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_metal_graph_compute,
+ /* .supports_op = */ ggml_backend_metal_supports_op,
+ };
+
+ ggml_backend_t ggml_backend_metal_init(void) {
+ struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+
+ ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+
+ ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
+
+ *metal_backend = (struct ggml_backend) {
+ /* .interface = */ metal_backend_i,
+ /* .context = */ ctx,
+ };
+
+ return metal_backend;
+ }
+
+ bool ggml_backend_is_metal(ggml_backend_t backend) {
+ return backend->iface.get_name == ggml_backend_metal_name;
+ }
+
+ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+ ggml_metal_set_n_cb(ctx, n_cb);
+ }