llama_cpp 0.9.5 → 0.10.0

ggml-metal.m
@@ -62,6 +62,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(mul);
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+ GGML_METAL_DECL_KERNEL(div);
+ GGML_METAL_DECL_KERNEL(div_row);
  GGML_METAL_DECL_KERNEL(scale);
  GGML_METAL_DECL_KERNEL(scale_4);
  GGML_METAL_DECL_KERNEL(silu);
@@ -112,15 +114,35 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
  GGML_METAL_DECL_KERNEL(rope_f32);
  GGML_METAL_DECL_KERNEL(rope_f16);
  GGML_METAL_DECL_KERNEL(alibi_f32);
  GGML_METAL_DECL_KERNEL(im2col_f16);
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
  GGML_METAL_DECL_KERNEL(concat);
  GGML_METAL_DECL_KERNEL(sqr);
+ GGML_METAL_DECL_KERNEL(sum_rows);

  #undef GGML_METAL_DECL_KERNEL
  };
@@ -164,12 +186,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  }
  }

-
-
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_LOG_INFO("%s: allocating\n", __func__);

- id <MTLDevice> device;
+ id<MTLDevice> device;
  NSString * s;

  #if TARGET_OS_OSX
@@ -215,6 +235,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  NSString * sourcePath;
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+
+ GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
+
  if (ggmlMetalPathResources) {
  sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
  } else {
@@ -245,6 +268,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  }
  }

+ #if TARGET_OS_OSX
+ // print MTL GPU family:
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+ // determine max supported GPU family
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ if (ctx->device.maxTransferRate != 0) {
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+ } else {
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
+ }
+ #endif
+
  // load kernels
  {
  NSError * error = nil;
@@ -266,6 +312,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(add_row);
  GGML_METAL_ADD_KERNEL(mul);
  GGML_METAL_ADD_KERNEL(mul_row);
+ GGML_METAL_ADD_KERNEL(div);
+ GGML_METAL_ADD_KERNEL(div_row);
  GGML_METAL_ADD_KERNEL(scale);
  GGML_METAL_ADD_KERNEL(scale_4);
  GGML_METAL_ADD_KERNEL(silu);
@@ -317,43 +365,40 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
  }
  GGML_METAL_ADD_KERNEL(rope_f32);
  GGML_METAL_ADD_KERNEL(rope_f16);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(im2col_f16);
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
  GGML_METAL_ADD_KERNEL(concat);
  GGML_METAL_ADD_KERNEL(sqr);
+ GGML_METAL_ADD_KERNEL(sum_rows);

  #undef GGML_METAL_ADD_KERNEL
  }

- #if TARGET_OS_OSX
- // print MTL GPU family:
- GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
-
- // determine max supported GPU family
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
- break;
- }
- }
-
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
- if (ctx->device.maxTransferRate != 0) {
- GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
- } else {
- GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
- }
- #endif
-
  return ctx;
  }

@@ -367,6 +412,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(add_row);
  GGML_METAL_DEL_KERNEL(mul);
  GGML_METAL_DEL_KERNEL(mul_row);
+ GGML_METAL_DEL_KERNEL(div);
+ GGML_METAL_DEL_KERNEL(div_row);
  GGML_METAL_DEL_KERNEL(scale);
  GGML_METAL_DEL_KERNEL(scale_4);
  GGML_METAL_DEL_KERNEL(silu);
@@ -418,16 +465,36 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
  }
  GGML_METAL_DEL_KERNEL(rope_f32);
  GGML_METAL_DEL_KERNEL(rope_f16);
  GGML_METAL_DEL_KERNEL(alibi_f32);
  GGML_METAL_DEL_KERNEL(im2col_f16);
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
  GGML_METAL_DEL_KERNEL(concat);
  GGML_METAL_DEL_KERNEL(sqr);
+ GGML_METAL_DEL_KERNEL(sum_rows);

  #undef GGML_METAL_DEL_KERNEL

@@ -471,6 +538,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
  return ctx->concur_list;
  }

+ // temporarily defined here for compatibility between ggml-backend and the old API
+ struct ggml_backend_metal_buffer_context {
+ void * data;
+
+ id<MTLBuffer> metal;
+ };
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -480,8 +554,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru

  const int64_t tsize = ggml_nbytes(t);

- if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
- ctx = t->buffer->backend->context;
+ // compatibility with ggml-backend
+ if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
+ struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
+
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+
+ GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+
+ *offs = (size_t) ioffs;
+
+ return buf_ctx->metal;
  }

  // find the view that contains the tensor fully
@@ -706,6 +789,51 @@ void ggml_metal_graph_find_concurrency(
  }
  }

+ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_GELU:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_CONCAT:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_NORM:
+ case GGML_OP_ALIBI:
+ case GGML_OP_ROPE:
+ case GGML_OP_IM2COL:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ return true;
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_GET_ROWS:
+ {
+ return op->ne[0] % 4 == 0;
+ }
+ default:
+ return false;
+ }
+ }
  void ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
@@ -776,6 +904,8 @@ void ggml_metal_graph_compute(
  } break;
  }

+ GGML_ASSERT(ggml_metal_supports_op(dst));
+
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -868,6 +998,8 @@ void ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
  {
  GGML_ASSERT(ggml_is_contiguous(src0));
  GGML_ASSERT(ggml_is_contiguous(src1));
@@ -881,11 +1013,21 @@ void ggml_metal_graph_compute(
  GGML_ASSERT(ne11 == 1);

  nb = ne00 / 4;
- [encoder setComputePipelineState:ctx->pipeline_add_row];
+ switch (dst->op) {
+ case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
+ case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
+ case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+ default: GGML_ASSERT(false);
+ }

  bcast_row = true;
  } else {
- [encoder setComputePipelineState:ctx->pipeline_add];
+ switch (dst->op) {
+ case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
+ case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
+ case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+ default: GGML_ASSERT(false);
+ }
  }
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -926,31 +1068,6 @@ void ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  }
  } break;
- case GGML_OP_MUL:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
-
- // utilize float4
- GGML_ASSERT(ne00 % 4 == 0);
- const int64_t nb = ne00/4;
-
- if (ggml_nelements(src1) == ne10) {
- // src1 is a row
- GGML_ASSERT(ne11 == 1);
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
- } else {
- [encoder setComputePipelineState:ctx->pipeline_mul];
- }
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
-
- const int64_t n = ggml_nelements(dst)/4;
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
  case GGML_OP_SCALE:
  {
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1023,6 +1140,40 @@ void ggml_metal_graph_compute(
  const int64_t n = ggml_nelements(dst);
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case GGML_OP_SUM_ROWS:
+ {
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+ [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  case GGML_OP_SOFT_MAX:
  {
  int nth = 32; // SIMD width
@@ -1077,9 +1228,13 @@ void ggml_metal_graph_compute(
  case GGML_OP_MUL_MAT:
  {
  GGML_ASSERT(ne00 == ne10);
- GGML_ASSERT(ne03 == ne13);

- const uint gqa = ne12/ne02;
+ // TODO: assert that dim2 and dim3 are contiguous
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const uint r2 = ne12/ne02;
+ const uint r3 = ne13/ne03;

  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
  // to the matrix-vector kernel
@@ -1114,7 +1269,7 @@ void ggml_metal_graph_compute(
  !ggml_is_transposed(src1) &&
  src1t == GGML_TYPE_F32 &&
  ne00 % 32 == 0 && ne00 >= 64 &&
- ne11 > ne11_mm_min) {
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
  switch (src0->type) {
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
@@ -1144,9 +1299,10 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
  int nth0 = 32;
  int nth1 = 1;
@@ -1182,90 +1338,60 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_TYPE_Q4_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
  } break;
  case GGML_TYPE_Q4_1:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
  } break;
  case GGML_TYPE_Q5_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
  } break;
  case GGML_TYPE_Q5_1:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
  } break;
  case GGML_TYPE_Q8_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
  } break;
  case GGML_TYPE_Q2_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
  } break;
  case GGML_TYPE_Q3_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
  } break;
  case GGML_TYPE_Q4_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 4; //1;
  nth1 = 8; //32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
  } break;
  case GGML_TYPE_Q5_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
  } break;
  case GGML_TYPE_Q6_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1294,32 +1420,125 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];

  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
  src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q3_K) {
  #ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
  int64_t ny = (ne11 + nrows - 1)/nrows;
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case GGML_OP_MUL_MAT_ID:
+ {
+ //GGML_ASSERT(ne00 == ne10);
+ //GGML_ASSERT(ne03 == ne13);
+
+ GGML_ASSERT(src0t == GGML_TYPE_I32);
+
+ const int n_as = ne00;
+
+ // TODO: make this more general
+ GGML_ASSERT(n_as <= 8);
+
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+ GGML_ASSERT(!ggml_is_transposed(src2));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+
+ GGML_ASSERT(ne20 % 32 == 0);
+ // !!!!!!!!! TODO: this assert is probably required but not sure!
+ //GGML_ASSERT(ne20 >= 64);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+ const uint r2 = ne12/ne22;
+ const uint r3 = ne13/ne23;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = 0;
+
+ const int idx = ((int32_t *) dst->op_params)[0];
+
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ ne11 > ne11_mm_min) {
+ switch (src2->type) {
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+ default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
  }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < n_as; ++j) {
+ struct ggml_tensor * src_cur = dst->src[2 + j];
+
+ size_t offs_src_cur = 0;
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
+ }
+
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  }
  } break;
  case GGML_OP_GET_ROWS:
@@ -1545,18 +1764,48 @@ void ggml_metal_graph_compute(

  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
  } break;
+ case GGML_OP_ARGSORT:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+ const int nrows = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+ switch (order) {
+ case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
+ case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+ } break;
  case GGML_OP_DUP:
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  {
- const int nth = MIN(1024, ne00);
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));

  switch (src0t) {
  case GGML_TYPE_F32:
  {
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
  switch (dstt) {
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
+ //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
+ //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
  default: GGML_ASSERT(false && "not implemented");
  };
  } break;
@@ -1631,81 +1880,150 @@ void ggml_metal_graph_compute(

  // backend interface

- static const char * ggml_backend_metal_name(ggml_backend_t backend) {
- return "Metal";
+ static id<MTLDevice> g_backend_device = nil;
+ static int g_backend_device_ref_count = 0;

- UNUSED(backend);
+ static id<MTLDevice> ggml_backend_metal_get_device(void) {
+ if (g_backend_device == nil) {
+ g_backend_device = MTLCreateSystemDefaultDevice();
+ }
+
+ g_backend_device_ref_count++;
+
+ return g_backend_device;
  }

- static void ggml_backend_metal_free(ggml_backend_t backend) {
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
- ggml_metal_free(ctx);
- free(backend);
+ static void ggml_backend_metal_free_device(void) {
+ assert(g_backend_device_ref_count > 0);
+
+ g_backend_device_ref_count--;
+
+ if (g_backend_device_ref_count == 0) {
+ [g_backend_device release];
+ g_backend_device = nil;
+ }
  }

  static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
- return (void *)buffer->context;
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ return ctx->data;
  }

  static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- free(buffer->context);
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ [ctx->metal release];
+ ggml_backend_metal_free_device();
+
+ free(ctx->data);
+ free(ctx);
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy((char *)tensor->data + offset, data, size);
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+
  UNUSED(buffer);
  }

  static struct ggml_backend_buffer_i metal_backend_buffer_i = {
- /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
- /* .get_base = */ ggml_backend_metal_buffer_get_base,
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
- /* .init_tensor = */ NULL, // no initialization required
- /* .free_tensor = */ NULL, // no cleanup required
+ /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_metal_buffer_get_base,
+ /* .init_tensor = */ NULL,
+ /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
+ /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
+ /* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
  };

- static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));

- void * data = ggml_metal_host_malloc(size);
+ const size_t size_page = sysconf(_SC_PAGESIZE);

- // TODO: set proper name of the buffers
- ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+ size_t size_aligned = size;
+ if ((size_aligned % size_page) != 0) {
+ size_aligned += (size_page - (size_aligned % size_page));
+ }
+
+ ctx->data = ggml_metal_host_malloc(size);
+ ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
+ length:size_aligned
+ options:MTLResourceStorageModeShared
+ deallocator:nil];

- return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+ return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
  }

- static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
+ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
  return 32;
- UNUSED(backend);
+ UNUSED(buft);
  }

- static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
- memcpy((char *)tensor->data + offset, data, size);
+ static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+ return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);

- UNUSED(backend);
+ GGML_UNUSED(buft);
  }

- static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
- memcpy(data, (const char *)tensor->data + offset, size);
+ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
+ /* .iface = */ {
+ /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+ /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
+ },
+ /* .context = */ NULL,
+ };

- UNUSED(backend);
+ return &ggml_backend_buffer_type_metal;
  }

- static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+ return "Metal";
+
  UNUSED(backend);
  }

- static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+ static void ggml_backend_metal_free(ggml_backend_t backend) {
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ ggml_metal_free(ctx);
+ free(backend);
+ }

+ static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
  UNUSED(backend);
  }

- static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
+ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
+ return ggml_backend_metal_buffer_type();

  UNUSED(backend);
  }
@@ -1717,32 +2035,43 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
  }

  static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- return true;
+ return ggml_metal_supports_op(op);
+
  UNUSED(backend);
- UNUSED(op);
  }

  static struct ggml_backend_i metal_backend_i = {
- /* .get_name = */ ggml_backend_metal_name,
- /* .free = */ ggml_backend_metal_free,
- /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
- /* .get_alignment = */ ggml_backend_metal_get_alignment,
- /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
- /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
- /* .synchronize = */ ggml_backend_metal_synchronize,
- /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from,
- /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to,
- /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_metal_graph_compute,
- /* .supports_op = */ ggml_backend_metal_supports_op,
+ /* .get_name = */ ggml_backend_metal_name,
+ /* .free = */ ggml_backend_metal_free,
+ /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_from_async = */ NULL,
+ /* .cpy_tensor_to_async = */ NULL,
+ /* .synchronize = */ ggml_backend_metal_synchronize,
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_metal_graph_compute,
+ /* .supports_op = */ ggml_backend_metal_supports_op,
  };

+ // TODO: make a common log callback for all backends in ggml-backend
+ static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+ fprintf(stderr, "%s", msg);
+
+ UNUSED(level);
+ UNUSED(user_data);
+ }
+
  ggml_backend_t ggml_backend_metal_init(void) {
- struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+ ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
+
+ struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);

- ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+ if (ctx == NULL) {
+ return NULL;
+ }

  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));

@@ -1759,7 +2088,26 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
  }

  void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;

  ggml_metal_set_n_cb(ctx, n_cb);
  }
+
+ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+ return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+ }
+
+ ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+
+ ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
+ return ggml_backend_metal_init();
+
+ GGML_UNUSED(params);
+ GGML_UNUSED(user_data);
+ }
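
For orientation, the backend-facing API in this diff moves buffer allocation from the backend object to a ggml-backend buffer type and adds an explicit op-support query and a GPU-family query. The following is a minimal sketch, not part of the gem, of how the entry points added or changed above might be driven from C; it assumes the usual ggml.h / ggml-backend.h / ggml-metal.h headers and the generic ggml_backend_buft_alloc_buffer helper from ggml-backend, and it omits real error handling.

#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-metal.h"

int main(void) {
    // ggml_backend_metal_init() now returns NULL if ggml_metal_init() fails
    ggml_backend_t backend = ggml_backend_metal_init();
    if (backend == NULL) {
        return 1; // Metal is not available on this machine
    }

    // optional: limit the number of command buffers used per graph
    ggml_backend_metal_set_n_cb(backend, 1);

    // new in this diff: query the GPU family; family 7 corresponds to
    // MTLGPUFamilyApple7, which the matrix-matrix kernels above require
    const bool has_mm_kernels = ggml_backend_metal_supports_family(backend, 7);
    (void) has_mm_kernels; // without it the backend falls back to matrix-vector kernels

    // buffers are now allocated through the buffer-type interface
    ggml_backend_buffer_type_t buft = ggml_backend_metal_buffer_type();
    ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, 16u*1024u*1024u);

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    return 0;
}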