llama_cpp 0.9.5 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -62,6 +62,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(mul);
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+ GGML_METAL_DECL_KERNEL(div);
+ GGML_METAL_DECL_KERNEL(div_row);
  GGML_METAL_DECL_KERNEL(scale);
  GGML_METAL_DECL_KERNEL(scale_4);
  GGML_METAL_DECL_KERNEL(silu);
@@ -112,15 +114,35 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
  GGML_METAL_DECL_KERNEL(rope_f32);
  GGML_METAL_DECL_KERNEL(rope_f16);
  GGML_METAL_DECL_KERNEL(alibi_f32);
  GGML_METAL_DECL_KERNEL(im2col_f16);
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
  GGML_METAL_DECL_KERNEL(concat);
  GGML_METAL_DECL_KERNEL(sqr);
+ GGML_METAL_DECL_KERNEL(sum_rows);

  #undef GGML_METAL_DECL_KERNEL
  };
@@ -164,12 +186,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  }
  }

-
-
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_LOG_INFO("%s: allocating\n", __func__);

- id <MTLDevice> device;
+ id<MTLDevice> device;
  NSString * s;

  #if TARGET_OS_OSX
@@ -215,6 +235,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

  NSString * sourcePath;
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+
+ GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
+
  if (ggmlMetalPathResources) {
  sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
  } else {
@@ -245,6 +268,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  }
  }

+ #if TARGET_OS_OSX
+ // print MTL GPU family:
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+ // determine max supported GPU family
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ if (ctx->device.maxTransferRate != 0) {
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+ } else {
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
+ }
+ #endif
+
  // load kernels
  {
  NSError * error = nil;
@@ -266,6 +312,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(add_row);
  GGML_METAL_ADD_KERNEL(mul);
  GGML_METAL_ADD_KERNEL(mul_row);
+ GGML_METAL_ADD_KERNEL(div);
+ GGML_METAL_ADD_KERNEL(div_row);
  GGML_METAL_ADD_KERNEL(scale);
  GGML_METAL_ADD_KERNEL(scale_4);
  GGML_METAL_ADD_KERNEL(silu);
@@ -317,43 +365,40 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
  }
  GGML_METAL_ADD_KERNEL(rope_f32);
  GGML_METAL_ADD_KERNEL(rope_f16);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(im2col_f16);
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
  GGML_METAL_ADD_KERNEL(concat);
  GGML_METAL_ADD_KERNEL(sqr);
+ GGML_METAL_ADD_KERNEL(sum_rows);

  #undef GGML_METAL_ADD_KERNEL
  }

- #if TARGET_OS_OSX
- // print MTL GPU family:
- GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
-
- // determine max supported GPU family
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
- break;
- }
- }
-
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
- if (ctx->device.maxTransferRate != 0) {
- GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
- } else {
- GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
- }
- #endif
-
  return ctx;
  }

@@ -367,6 +412,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(add_row);
  GGML_METAL_DEL_KERNEL(mul);
  GGML_METAL_DEL_KERNEL(mul_row);
+ GGML_METAL_DEL_KERNEL(div);
+ GGML_METAL_DEL_KERNEL(div_row);
  GGML_METAL_DEL_KERNEL(scale);
  GGML_METAL_DEL_KERNEL(scale_4);
  GGML_METAL_DEL_KERNEL(silu);
@@ -418,16 +465,36 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
  }
  GGML_METAL_DEL_KERNEL(rope_f32);
  GGML_METAL_DEL_KERNEL(rope_f16);
  GGML_METAL_DEL_KERNEL(alibi_f32);
  GGML_METAL_DEL_KERNEL(im2col_f16);
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
  GGML_METAL_DEL_KERNEL(concat);
  GGML_METAL_DEL_KERNEL(sqr);
+ GGML_METAL_DEL_KERNEL(sum_rows);

  #undef GGML_METAL_DEL_KERNEL

@@ -471,6 +538,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
  return ctx->concur_list;
  }

+ // temporarily defined here for compatibility between ggml-backend and the old API
+ struct ggml_backend_metal_buffer_context {
+ void * data;
+
+ id<MTLBuffer> metal;
+ };
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -480,8 +554,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru

  const int64_t tsize = ggml_nbytes(t);

- if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
- ctx = t->buffer->backend->context;
+ // compatibility with ggml-backend
+ if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
+ struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
+
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+
+ GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+
+ *offs = (size_t) ioffs;
+
+ return buf_ctx->metal;
  }

  // find the view that contains the tensor fully
@@ -706,6 +789,51 @@ void ggml_metal_graph_find_concurrency(
  }
  }

+ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_GELU:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_CONCAT:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_NORM:
+ case GGML_OP_ALIBI:
+ case GGML_OP_ROPE:
+ case GGML_OP_IM2COL:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ return true;
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_GET_ROWS:
+ {
+ return op->ne[0] % 4 == 0;
+ }
+ default:
+ return false;
+ }
+ }
  void ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
@@ -776,6 +904,8 @@ void ggml_metal_graph_compute(
  } break;
  }

+ GGML_ASSERT(ggml_metal_supports_op(dst));
+
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -868,6 +998,8 @@ void ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
  {
  GGML_ASSERT(ggml_is_contiguous(src0));
  GGML_ASSERT(ggml_is_contiguous(src1));
@@ -881,11 +1013,21 @@ void ggml_metal_graph_compute(
  GGML_ASSERT(ne11 == 1);

  nb = ne00 / 4;
- [encoder setComputePipelineState:ctx->pipeline_add_row];
+ switch (dst->op) {
+ case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
+ case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
+ case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+ default: GGML_ASSERT(false);
+ }

  bcast_row = true;
  } else {
- [encoder setComputePipelineState:ctx->pipeline_add];
+ switch (dst->op) {
+ case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
+ case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
+ case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+ default: GGML_ASSERT(false);
+ }
  }
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -926,31 +1068,6 @@ void ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  }
  } break;
- case GGML_OP_MUL:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
-
- // utilize float4
- GGML_ASSERT(ne00 % 4 == 0);
- const int64_t nb = ne00/4;
-
- if (ggml_nelements(src1) == ne10) {
- // src1 is a row
- GGML_ASSERT(ne11 == 1);
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
- } else {
- [encoder setComputePipelineState:ctx->pipeline_mul];
- }
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
-
- const int64_t n = ggml_nelements(dst)/4;
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
  case GGML_OP_SCALE:
  {
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1023,6 +1140,40 @@ void ggml_metal_graph_compute(
  const int64_t n = ggml_nelements(dst);
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case GGML_OP_SUM_ROWS:
+ {
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+ [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  case GGML_OP_SOFT_MAX:
  {
  int nth = 32; // SIMD width
@@ -1077,9 +1228,13 @@ void ggml_metal_graph_compute(
  case GGML_OP_MUL_MAT:
  {
  GGML_ASSERT(ne00 == ne10);
- GGML_ASSERT(ne03 == ne13);

- const uint gqa = ne12/ne02;
+ // TODO: assert that dim2 and dim3 are contiguous
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const uint r2 = ne12/ne02;
+ const uint r3 = ne13/ne03;

  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
  // to the matrix-vector kernel
@@ -1114,7 +1269,7 @@ void ggml_metal_graph_compute(
  !ggml_is_transposed(src1) &&
  src1t == GGML_TYPE_F32 &&
  ne00 % 32 == 0 && ne00 >= 64 &&
- ne11 > ne11_mm_min) {
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
  switch (src0->type) {
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
@@ -1144,9 +1299,10 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
  int nth0 = 32;
  int nth1 = 1;
@@ -1182,90 +1338,60 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_TYPE_Q4_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
  } break;
  case GGML_TYPE_Q4_1:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
  } break;
  case GGML_TYPE_Q5_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
  } break;
  case GGML_TYPE_Q5_1:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
  } break;
  case GGML_TYPE_Q8_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
  } break;
  case GGML_TYPE_Q2_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
  } break;
  case GGML_TYPE_Q3_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
  } break;
  case GGML_TYPE_Q4_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 4; //1;
  nth1 = 8; //32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
  } break;
  case GGML_TYPE_Q5_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
  } break;
  case GGML_TYPE_Q6_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1294,32 +1420,125 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];

  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
  src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q3_K) {
  #ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
  int64_t ny = (ne11 + nrows - 1)/nrows;
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case GGML_OP_MUL_MAT_ID:
+ {
+ //GGML_ASSERT(ne00 == ne10);
+ //GGML_ASSERT(ne03 == ne13);
+
+ GGML_ASSERT(src0t == GGML_TYPE_I32);
+
+ const int n_as = ne00;
+
+ // TODO: make this more general
+ GGML_ASSERT(n_as <= 8);
+
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+ GGML_ASSERT(!ggml_is_transposed(src2));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+
+ GGML_ASSERT(ne20 % 32 == 0);
+ // !!!!!!!!! TODO: this assert is probably required but not sure!
+ //GGML_ASSERT(ne20 >= 64);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+ const uint r2 = ne12/ne22;
+ const uint r3 = ne13/ne23;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = 0;
+
+ const int idx = ((int32_t *) dst->op_params)[0];
+
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ ne11 > ne11_mm_min) {
+ switch (src2->type) {
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+ default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
  }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < n_as; ++j) {
+ struct ggml_tensor * src_cur = dst->src[2 + j];
+
+ size_t offs_src_cur = 0;
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
+ }
+
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  }
  } break;
  case GGML_OP_GET_ROWS:
@@ -1545,18 +1764,48 @@ void ggml_metal_graph_compute(

  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
  } break;
+ case GGML_OP_ARGSORT:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+ const int nrows = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+ switch (order) {
+ case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
+ case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+ } break;
  case GGML_OP_DUP:
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  {
- const int nth = MIN(1024, ne00);
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));

  switch (src0t) {
  case GGML_TYPE_F32:
  {
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
  switch (dstt) {
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
+ //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
+ //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
  default: GGML_ASSERT(false && "not implemented");
  };
  } break;
@@ -1631,81 +1880,150 @@ void ggml_metal_graph_compute(

  // backend interface

- static const char * ggml_backend_metal_name(ggml_backend_t backend) {
- return "Metal";
+ static id<MTLDevice> g_backend_device = nil;
+ static int g_backend_device_ref_count = 0;

- UNUSED(backend);
+ static id<MTLDevice> ggml_backend_metal_get_device(void) {
+ if (g_backend_device == nil) {
+ g_backend_device = MTLCreateSystemDefaultDevice();
+ }
+
+ g_backend_device_ref_count++;
+
+ return g_backend_device;
  }

- static void ggml_backend_metal_free(ggml_backend_t backend) {
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
- ggml_metal_free(ctx);
- free(backend);
+ static void ggml_backend_metal_free_device(void) {
+ assert(g_backend_device_ref_count > 0);
+
+ g_backend_device_ref_count--;
+
+ if (g_backend_device_ref_count == 0) {
+ [g_backend_device release];
+ g_backend_device = nil;
+ }
  }

  static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
- return (void *)buffer->context;
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ return ctx->data;
  }

  static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- free(buffer->context);
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ [ctx->metal release];
+ ggml_backend_metal_free_device();
+
+ free(ctx->data);
+ free(ctx);
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy((char *)tensor->data + offset, data, size);
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+
  UNUSED(buffer);
  }

  static struct ggml_backend_buffer_i metal_backend_buffer_i = {
- /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
- /* .get_base = */ ggml_backend_metal_buffer_get_base,
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
- /* .init_tensor = */ NULL, // no initialization required
- /* .free_tensor = */ NULL, // no cleanup required
+ /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_metal_buffer_get_base,
+ /* .init_tensor = */ NULL,
+ /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
+ /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
+ /* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
  };

- static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));

- void * data = ggml_metal_host_malloc(size);
+ const size_t size_page = sysconf(_SC_PAGESIZE);

- // TODO: set proper name of the buffers
- ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+ size_t size_aligned = size;
+ if ((size_aligned % size_page) != 0) {
+ size_aligned += (size_page - (size_aligned % size_page));
+ }
+
+ ctx->data = ggml_metal_host_malloc(size);
+ ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
+ length:size_aligned
+ options:MTLResourceStorageModeShared
+ deallocator:nil];

- return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+ return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
  }

- static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
+ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
  return 32;
- UNUSED(backend);
+ UNUSED(buft);
  }

- static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
- memcpy((char *)tensor->data + offset, data, size);
+ static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+ return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);

- UNUSED(backend);
+ GGML_UNUSED(buft);
  }

- static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
- memcpy(data, (const char *)tensor->data + offset, size);
+ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
+ /* .iface = */ {
+ /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+ /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
+ },
+ /* .context = */ NULL,
+ };

- UNUSED(backend);
+ return &ggml_backend_buffer_type_metal;
  }

- static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+ return "Metal";
+
  UNUSED(backend);
  }

- static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+ static void ggml_backend_metal_free(ggml_backend_t backend) {
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ ggml_metal_free(ctx);
+ free(backend);
+ }

+ static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
  UNUSED(backend);
  }

- static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
+ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
+ return ggml_backend_metal_buffer_type();

  UNUSED(backend);
  }
@@ -1717,32 +2035,43 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
  }

  static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- return true;
+ return ggml_metal_supports_op(op);
+
  UNUSED(backend);
- UNUSED(op);
  }

  static struct ggml_backend_i metal_backend_i = {
- /* .get_name = */ ggml_backend_metal_name,
- /* .free = */ ggml_backend_metal_free,
- /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
- /* .get_alignment = */ ggml_backend_metal_get_alignment,
- /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
- /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
- /* .synchronize = */ ggml_backend_metal_synchronize,
- /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from,
- /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to,
- /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_metal_graph_compute,
- /* .supports_op = */ ggml_backend_metal_supports_op,
+ /* .get_name = */ ggml_backend_metal_name,
+ /* .free = */ ggml_backend_metal_free,
+ /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_from_async = */ NULL,
+ /* .cpy_tensor_to_async = */ NULL,
+ /* .synchronize = */ ggml_backend_metal_synchronize,
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_metal_graph_compute,
+ /* .supports_op = */ ggml_backend_metal_supports_op,
  };

+ // TODO: make a common log callback for all backends in ggml-backend
+ static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+ fprintf(stderr, "%s", msg);
+
+ UNUSED(level);
+ UNUSED(user_data);
+ }
+
  ggml_backend_t ggml_backend_metal_init(void) {
- struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+ ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
+
+ struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);

- ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+ if (ctx == NULL) {
+ return NULL;
+ }

  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));

@@ -1759,7 +2088,26 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
  }

  void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;

  ggml_metal_set_n_cb(ctx, n_cb);
  }
+
+ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+ return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+ }
+
+ ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+
+ ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
+ return ggml_backend_metal_init();
+
+ GGML_UNUSED(params);
+ GGML_UNUSED(user_data);
+ }