llama_cpp 0.9.4 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -62,6 +62,8 @@ struct ggml_metal_context {
62
62
  GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
63
63
  GGML_METAL_DECL_KERNEL(mul);
64
64
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
65
+ GGML_METAL_DECL_KERNEL(div);
66
+ GGML_METAL_DECL_KERNEL(div_row);
65
67
  GGML_METAL_DECL_KERNEL(scale);
66
68
  GGML_METAL_DECL_KERNEL(scale_4);
67
69
  GGML_METAL_DECL_KERNEL(silu);
@@ -112,15 +114,35 @@ struct ggml_metal_context {
112
114
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
113
115
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
114
116
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
117
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
118
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
119
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
120
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
121
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
122
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
123
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
124
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
125
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
126
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
127
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
128
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
115
129
  GGML_METAL_DECL_KERNEL(rope_f32);
116
130
  GGML_METAL_DECL_KERNEL(rope_f16);
117
131
  GGML_METAL_DECL_KERNEL(alibi_f32);
118
132
  GGML_METAL_DECL_KERNEL(im2col_f16);
133
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
134
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
119
135
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
120
136
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
137
+ GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
138
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
139
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
140
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
141
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
121
142
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
122
143
  GGML_METAL_DECL_KERNEL(concat);
123
144
  GGML_METAL_DECL_KERNEL(sqr);
145
+ GGML_METAL_DECL_KERNEL(sum_rows);
124
146
 
125
147
  #undef GGML_METAL_DECL_KERNEL
126
148
  };
@@ -164,12 +186,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
164
186
  }
165
187
  }
166
188
 
167
-
168
-
169
189
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
170
190
  GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
171
191
 
172
- id <MTLDevice> device;
192
+ id<MTLDevice> device;
173
193
  NSString * s;
174
194
 
175
195
  #if TARGET_OS_OSX
@@ -215,6 +235,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
215
235
 
216
236
  NSString * sourcePath;
217
237
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
238
+
239
+ GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
240
+
218
241
  if (ggmlMetalPathResources) {
219
242
  sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
220
243
  } else {
@@ -245,6 +268,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
245
268
  }
246
269
  }
247
270
 
271
+ #if TARGET_OS_OSX
272
+ // print MTL GPU family:
273
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
274
+
275
+ // determine max supported GPU family
276
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
277
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
278
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
279
+ if ([ctx->device supportsFamily:i]) {
280
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
281
+ break;
282
+ }
283
+ }
284
+
285
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
286
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
287
+ if (ctx->device.maxTransferRate != 0) {
288
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
289
+ } else {
290
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
291
+ }
292
+ #endif
293
+
248
294
  // load kernels
249
295
  {
250
296
  NSError * error = nil;
@@ -266,6 +312,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
266
312
  GGML_METAL_ADD_KERNEL(add_row);
267
313
  GGML_METAL_ADD_KERNEL(mul);
268
314
  GGML_METAL_ADD_KERNEL(mul_row);
315
+ GGML_METAL_ADD_KERNEL(div);
316
+ GGML_METAL_ADD_KERNEL(div_row);
269
317
  GGML_METAL_ADD_KERNEL(scale);
270
318
  GGML_METAL_ADD_KERNEL(scale_4);
271
319
  GGML_METAL_ADD_KERNEL(silu);
@@ -317,43 +365,40 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
317
365
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
318
366
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
319
367
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
368
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
369
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
370
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
371
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
372
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
373
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
374
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
375
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
376
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
377
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
378
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
379
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
320
380
  }
321
381
  GGML_METAL_ADD_KERNEL(rope_f32);
322
382
  GGML_METAL_ADD_KERNEL(rope_f16);
323
383
  GGML_METAL_ADD_KERNEL(alibi_f32);
324
384
  GGML_METAL_ADD_KERNEL(im2col_f16);
385
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
386
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
325
387
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
326
388
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
389
+ GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
390
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
391
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
392
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
393
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
327
394
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
328
395
  GGML_METAL_ADD_KERNEL(concat);
329
396
  GGML_METAL_ADD_KERNEL(sqr);
397
+ GGML_METAL_ADD_KERNEL(sum_rows);
330
398
 
331
399
  #undef GGML_METAL_ADD_KERNEL
332
400
  }
333
401
 
334
- #if TARGET_OS_OSX
335
- // print MTL GPU family:
336
- GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
337
-
338
- // determine max supported GPU family
339
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
340
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
341
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
342
- if ([ctx->device supportsFamily:i]) {
343
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
344
- break;
345
- }
346
- }
347
-
348
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
349
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
350
- if (ctx->device.maxTransferRate != 0) {
351
- GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
352
- } else {
353
- GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
354
- }
355
- #endif
356
-
357
402
  return ctx;
358
403
  }
359
404
 
@@ -367,6 +412,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
367
412
  GGML_METAL_DEL_KERNEL(add_row);
368
413
  GGML_METAL_DEL_KERNEL(mul);
369
414
  GGML_METAL_DEL_KERNEL(mul_row);
415
+ GGML_METAL_DEL_KERNEL(div);
416
+ GGML_METAL_DEL_KERNEL(div_row);
370
417
  GGML_METAL_DEL_KERNEL(scale);
371
418
  GGML_METAL_DEL_KERNEL(scale_4);
372
419
  GGML_METAL_DEL_KERNEL(silu);
@@ -418,16 +465,36 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
418
465
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
419
466
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
420
467
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
468
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
469
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
470
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
471
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
472
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
473
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
474
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
475
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
476
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
477
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
478
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
479
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
421
480
  }
422
481
  GGML_METAL_DEL_KERNEL(rope_f32);
423
482
  GGML_METAL_DEL_KERNEL(rope_f16);
424
483
  GGML_METAL_DEL_KERNEL(alibi_f32);
425
484
  GGML_METAL_DEL_KERNEL(im2col_f16);
485
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
486
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
426
487
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
427
488
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
489
+ GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
490
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
491
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
492
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
493
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
428
494
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
429
495
  GGML_METAL_DEL_KERNEL(concat);
430
496
  GGML_METAL_DEL_KERNEL(sqr);
497
+ GGML_METAL_DEL_KERNEL(sum_rows);
431
498
 
432
499
  #undef GGML_METAL_DEL_KERNEL
433
500
 
@@ -471,6 +538,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
471
538
  return ctx->concur_list;
472
539
  }
473
540
 
541
+ // temporarily defined here for compatibility between ggml-backend and the old API
542
+ struct ggml_backend_metal_buffer_context {
543
+ void * data;
544
+
545
+ id<MTLBuffer> metal;
546
+ };
547
+
474
548
  // finds the Metal buffer that contains the tensor data on the GPU device
475
549
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
476
550
  // Metal buffer based on the host memory pointer
@@ -480,8 +554,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
480
554
 
481
555
  const int64_t tsize = ggml_nbytes(t);
482
556
 
483
- if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
484
- ctx = t->buffer->backend->context;
557
+ // compatibility with ggml-backend
558
+ if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
559
+ struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
560
+
561
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
562
+
563
+ GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
564
+
565
+ *offs = (size_t) ioffs;
566
+
567
+ return buf_ctx->metal;
485
568
  }
486
569
 
487
570
  // find the view that contains the tensor fully
@@ -706,6 +789,51 @@ void ggml_metal_graph_find_concurrency(
706
789
  }
707
790
  }
708
791
 
792
+ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
793
+ switch (op->op) {
794
+ case GGML_OP_UNARY:
795
+ switch (ggml_get_unary_op(op)) {
796
+ case GGML_UNARY_OP_SILU:
797
+ case GGML_UNARY_OP_RELU:
798
+ case GGML_UNARY_OP_GELU:
799
+ return true;
800
+ default:
801
+ return false;
802
+ }
803
+ case GGML_OP_NONE:
804
+ case GGML_OP_RESHAPE:
805
+ case GGML_OP_VIEW:
806
+ case GGML_OP_TRANSPOSE:
807
+ case GGML_OP_PERMUTE:
808
+ case GGML_OP_CONCAT:
809
+ case GGML_OP_ADD:
810
+ case GGML_OP_MUL:
811
+ case GGML_OP_DIV:
812
+ case GGML_OP_SCALE:
813
+ case GGML_OP_SQR:
814
+ case GGML_OP_SUM_ROWS:
815
+ case GGML_OP_SOFT_MAX:
816
+ case GGML_OP_RMS_NORM:
817
+ case GGML_OP_NORM:
818
+ case GGML_OP_ALIBI:
819
+ case GGML_OP_ROPE:
820
+ case GGML_OP_IM2COL:
821
+ case GGML_OP_ARGSORT:
822
+ case GGML_OP_DUP:
823
+ case GGML_OP_CPY:
824
+ case GGML_OP_CONT:
825
+ case GGML_OP_MUL_MAT:
826
+ case GGML_OP_MUL_MAT_ID:
827
+ return true;
828
+ case GGML_OP_DIAG_MASK_INF:
829
+ case GGML_OP_GET_ROWS:
830
+ {
831
+ return op->ne[0] % 4 == 0;
832
+ }
833
+ default:
834
+ return false;
835
+ }
836
+ }
709
837
  void ggml_metal_graph_compute(
710
838
  struct ggml_metal_context * ctx,
711
839
  struct ggml_cgraph * gf) {
@@ -776,6 +904,8 @@ void ggml_metal_graph_compute(
776
904
  } break;
777
905
  }
778
906
 
907
+ GGML_ASSERT(ggml_metal_supports_op(dst));
908
+
779
909
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
780
910
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
781
911
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -868,6 +998,8 @@ void ggml_metal_graph_compute(
868
998
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
869
999
  } break;
870
1000
  case GGML_OP_ADD:
1001
+ case GGML_OP_MUL:
1002
+ case GGML_OP_DIV:
871
1003
  {
872
1004
  GGML_ASSERT(ggml_is_contiguous(src0));
873
1005
  GGML_ASSERT(ggml_is_contiguous(src1));
@@ -881,11 +1013,21 @@ void ggml_metal_graph_compute(
881
1013
  GGML_ASSERT(ne11 == 1);
882
1014
 
883
1015
  nb = ne00 / 4;
884
- [encoder setComputePipelineState:ctx->pipeline_add_row];
1016
+ switch (dst->op) {
1017
+ case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
1018
+ case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
1019
+ case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
1020
+ default: GGML_ASSERT(false);
1021
+ }
885
1022
 
886
1023
  bcast_row = true;
887
1024
  } else {
888
- [encoder setComputePipelineState:ctx->pipeline_add];
1025
+ switch (dst->op) {
1026
+ case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
1027
+ case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
1028
+ case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
1029
+ default: GGML_ASSERT(false);
1030
+ }
889
1031
  }
890
1032
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
891
1033
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -926,31 +1068,6 @@ void ggml_metal_graph_compute(
926
1068
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
927
1069
  }
928
1070
  } break;
929
- case GGML_OP_MUL:
930
- {
931
- GGML_ASSERT(ggml_is_contiguous(src0));
932
- GGML_ASSERT(ggml_is_contiguous(src1));
933
-
934
- // utilize float4
935
- GGML_ASSERT(ne00 % 4 == 0);
936
- const int64_t nb = ne00/4;
937
-
938
- if (ggml_nelements(src1) == ne10) {
939
- // src1 is a row
940
- GGML_ASSERT(ne11 == 1);
941
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
942
- } else {
943
- [encoder setComputePipelineState:ctx->pipeline_mul];
944
- }
945
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
946
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
947
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
948
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
949
-
950
- const int64_t n = ggml_nelements(dst)/4;
951
-
952
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
953
- } break;
954
1071
  case GGML_OP_SCALE:
955
1072
  {
956
1073
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1023,25 +1140,66 @@ void ggml_metal_graph_compute(
1023
1140
  const int64_t n = ggml_nelements(dst);
1024
1141
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1025
1142
  } break;
1143
+ case GGML_OP_SUM_ROWS:
1144
+ {
1145
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
1146
+
1147
+ [encoder setComputePipelineState:ctx->pipeline_sum_rows];
1148
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1149
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1150
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1151
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1152
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1153
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1154
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1155
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1156
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1157
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1158
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1159
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1160
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1161
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1162
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1163
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1164
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1165
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1166
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1167
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1168
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1169
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1170
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1171
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1172
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1173
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1174
+
1175
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1176
+ } break;
1026
1177
  case GGML_OP_SOFT_MAX:
1027
1178
  {
1028
1179
  int nth = 32; // SIMD width
1029
1180
 
1030
1181
  if (ne00%4 == 0) {
1182
+ while (nth < ne00/4 && nth < 256) {
1183
+ nth *= 2;
1184
+ }
1031
1185
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
1032
1186
  } else {
1033
- do {
1187
+ while (nth < ne00 && nth < 1024) {
1034
1188
  nth *= 2;
1035
- } while (nth <= ne00 && nth <= 1024);
1036
- nth /= 2;
1189
+ }
1037
1190
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
1038
1191
  }
1039
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1040
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1041
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1042
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1043
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1044
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1192
+
1193
+ const float scale = ((float *) dst->op_params)[0];
1194
+
1195
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1196
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1197
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1198
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1199
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1200
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1201
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1202
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1045
1203
 
1046
1204
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1047
1205
  } break;
@@ -1070,9 +1228,13 @@ void ggml_metal_graph_compute(
1070
1228
  case GGML_OP_MUL_MAT:
1071
1229
  {
1072
1230
  GGML_ASSERT(ne00 == ne10);
1073
- GGML_ASSERT(ne03 == ne13);
1074
1231
 
1075
- const uint gqa = ne12/ne02;
1232
+ // TODO: assert that dim2 and dim3 are contiguous
1233
+ GGML_ASSERT(ne12 % ne02 == 0);
1234
+ GGML_ASSERT(ne13 % ne03 == 0);
1235
+
1236
+ const uint r2 = ne12/ne02;
1237
+ const uint r3 = ne13/ne03;
1076
1238
 
1077
1239
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1078
1240
  // to the matrix-vector kernel
@@ -1107,7 +1269,7 @@ void ggml_metal_graph_compute(
1107
1269
  !ggml_is_transposed(src1) &&
1108
1270
  src1t == GGML_TYPE_F32 &&
1109
1271
  ne00 % 32 == 0 && ne00 >= 64 &&
1110
- ne11 > ne11_mm_min) {
1272
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
1111
1273
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1112
1274
  switch (src0->type) {
1113
1275
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
@@ -1137,9 +1299,10 @@ void ggml_metal_graph_compute(
1137
1299
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1138
1300
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1139
1301
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1140
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
1302
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1303
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1141
1304
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1142
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1305
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1143
1306
  } else {
1144
1307
  int nth0 = 32;
1145
1308
  int nth1 = 1;
@@ -1175,90 +1338,60 @@ void ggml_metal_graph_compute(
1175
1338
  } break;
1176
1339
  case GGML_TYPE_Q4_0:
1177
1340
  {
1178
- GGML_ASSERT(ne02 == 1);
1179
- GGML_ASSERT(ne12 == 1);
1180
-
1181
1341
  nth0 = 8;
1182
1342
  nth1 = 8;
1183
1343
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
1184
1344
  } break;
1185
1345
  case GGML_TYPE_Q4_1:
1186
1346
  {
1187
- GGML_ASSERT(ne02 == 1);
1188
- GGML_ASSERT(ne12 == 1);
1189
-
1190
1347
  nth0 = 8;
1191
1348
  nth1 = 8;
1192
1349
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1193
1350
  } break;
1194
1351
  case GGML_TYPE_Q5_0:
1195
1352
  {
1196
- GGML_ASSERT(ne02 == 1);
1197
- GGML_ASSERT(ne12 == 1);
1198
-
1199
1353
  nth0 = 8;
1200
1354
  nth1 = 8;
1201
1355
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
1202
1356
  } break;
1203
1357
  case GGML_TYPE_Q5_1:
1204
1358
  {
1205
- GGML_ASSERT(ne02 == 1);
1206
- GGML_ASSERT(ne12 == 1);
1207
-
1208
1359
  nth0 = 8;
1209
1360
  nth1 = 8;
1210
1361
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
1211
1362
  } break;
1212
1363
  case GGML_TYPE_Q8_0:
1213
1364
  {
1214
- GGML_ASSERT(ne02 == 1);
1215
- GGML_ASSERT(ne12 == 1);
1216
-
1217
1365
  nth0 = 8;
1218
1366
  nth1 = 8;
1219
1367
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
1220
1368
  } break;
1221
1369
  case GGML_TYPE_Q2_K:
1222
1370
  {
1223
- GGML_ASSERT(ne02 == 1);
1224
- GGML_ASSERT(ne12 == 1);
1225
-
1226
1371
  nth0 = 2;
1227
1372
  nth1 = 32;
1228
1373
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
1229
1374
  } break;
1230
1375
  case GGML_TYPE_Q3_K:
1231
1376
  {
1232
- GGML_ASSERT(ne02 == 1);
1233
- GGML_ASSERT(ne12 == 1);
1234
-
1235
1377
  nth0 = 2;
1236
1378
  nth1 = 32;
1237
1379
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
1238
1380
  } break;
1239
1381
  case GGML_TYPE_Q4_K:
1240
1382
  {
1241
- GGML_ASSERT(ne02 == 1);
1242
- GGML_ASSERT(ne12 == 1);
1243
-
1244
1383
  nth0 = 4; //1;
1245
1384
  nth1 = 8; //32;
1246
1385
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
1247
1386
  } break;
1248
1387
  case GGML_TYPE_Q5_K:
1249
1388
  {
1250
- GGML_ASSERT(ne02 == 1);
1251
- GGML_ASSERT(ne12 == 1);
1252
-
1253
1389
  nth0 = 2;
1254
1390
  nth1 = 32;
1255
1391
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
1256
1392
  } break;
1257
1393
  case GGML_TYPE_Q6_K:
1258
1394
  {
1259
- GGML_ASSERT(ne02 == 1);
1260
- GGML_ASSERT(ne12 == 1);
1261
-
1262
1395
  nth0 = 2;
1263
1396
  nth1 = 32;
1264
1397
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1287,32 +1420,125 @@ void ggml_metal_graph_compute(
1287
1420
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1288
1421
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1289
1422
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1290
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
1423
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1424
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1291
1425
 
1292
1426
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1293
1427
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1294
1428
  src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1295
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1429
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1296
1430
  }
1297
1431
  else if (src0t == GGML_TYPE_Q4_K) {
1298
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1432
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1299
1433
  }
1300
1434
  else if (src0t == GGML_TYPE_Q3_K) {
1301
1435
  #ifdef GGML_QKK_64
1302
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1436
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1303
1437
  #else
1304
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1438
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1305
1439
  #endif
1306
1440
  }
1307
1441
  else if (src0t == GGML_TYPE_Q5_K) {
1308
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1442
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1309
1443
  }
1310
1444
  else if (src0t == GGML_TYPE_Q6_K) {
1311
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1445
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1312
1446
  } else {
1313
1447
  int64_t ny = (ne11 + nrows - 1)/nrows;
1314
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1448
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1449
+ }
1450
+ }
1451
+ } break;
1452
+ case GGML_OP_MUL_MAT_ID:
1453
+ {
1454
+ //GGML_ASSERT(ne00 == ne10);
1455
+ //GGML_ASSERT(ne03 == ne13);
1456
+
1457
+ GGML_ASSERT(src0t == GGML_TYPE_I32);
1458
+
1459
+ const int n_as = ne00;
1460
+
1461
+ // TODO: make this more general
1462
+ GGML_ASSERT(n_as <= 8);
1463
+
1464
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
1465
+
1466
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
1467
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
1468
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
1469
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
1470
+
1471
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
1472
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1473
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1474
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
1475
+
1476
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
1477
+
1478
+ GGML_ASSERT(!ggml_is_transposed(src2));
1479
+ GGML_ASSERT(!ggml_is_transposed(src1));
1480
+
1481
+ GGML_ASSERT(ne20 % 32 == 0);
1482
+ // !!!!!!!!! TODO: this assert is probably required but not sure!
1483
+ //GGML_ASSERT(ne20 >= 64);
1484
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1485
+
1486
+ const uint r2 = ne12/ne22;
1487
+ const uint r3 = ne13/ne23;
1488
+
1489
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1490
+ // to the matrix-vector kernel
1491
+ int ne11_mm_min = 0;
1492
+
1493
+ const int idx = ((int32_t *) dst->op_params)[0];
1494
+
1495
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1496
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1497
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1498
+ ne11 > ne11_mm_min) {
1499
+ switch (src2->type) {
1500
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1501
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
1502
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
1503
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
1504
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
1505
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
1506
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
1507
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
1508
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
1509
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
1510
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
1511
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
1512
+ default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1315
1513
  }
1514
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1515
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1516
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1517
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
1518
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
1519
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
1520
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
1521
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1522
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1523
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1524
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1525
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1526
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1527
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1528
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1529
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
1530
+ // TODO: how to make this an array? read Metal docs
1531
+ for (int j = 0; j < n_as; ++j) {
1532
+ struct ggml_tensor * src_cur = dst->src[2 + j];
1533
+
1534
+ size_t offs_src_cur = 0;
1535
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1536
+
1537
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
1538
+ }
1539
+
1540
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1541
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1316
1542
  }
1317
1543
  } break;
1318
1544
  case GGML_OP_GET_ROWS:
@@ -1351,15 +1577,19 @@ void ggml_metal_graph_compute(
1351
1577
  float eps;
1352
1578
  memcpy(&eps, dst->op_params, sizeof(float));
1353
1579
 
1354
- const int nth = MIN(512, ne00);
1580
+ int nth = 32; // SIMD width
1581
+
1582
+ while (nth < ne00/4 && nth < 1024) {
1583
+ nth *= 2;
1584
+ }
1355
1585
 
1356
1586
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1357
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1358
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1359
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1360
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1361
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1362
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1587
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1588
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1589
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1590
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1591
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1592
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1363
1593
 
1364
1594
  const int64_t nrows = ggml_nrows(src0);
1365
1595
 
@@ -1433,7 +1663,8 @@ void ggml_metal_graph_compute(
1433
1663
  const int n_past = ((int32_t *) dst->op_params)[0];
1434
1664
  const int n_dims = ((int32_t *) dst->op_params)[1];
1435
1665
  const int mode = ((int32_t *) dst->op_params)[2];
1436
- const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
1666
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
1667
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
1437
1668
 
1438
1669
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1439
1670
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@@ -1533,18 +1764,48 @@ void ggml_metal_graph_compute(
1533
1764
 
1534
1765
  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1535
1766
  } break;
1767
+ case GGML_OP_ARGSORT:
1768
+ {
1769
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
1770
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
1771
+
1772
+ const int nrows = ggml_nrows(src0);
1773
+
1774
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
1775
+
1776
+ switch (order) {
1777
+ case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
1778
+ case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
1779
+ default: GGML_ASSERT(false);
1780
+ };
1781
+
1782
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1783
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1784
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1785
+
1786
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
1787
+ } break;
1536
1788
  case GGML_OP_DUP:
1537
1789
  case GGML_OP_CPY:
1538
1790
  case GGML_OP_CONT:
1539
1791
  {
1540
- const int nth = MIN(1024, ne00);
1792
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
1793
+
1794
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
1541
1795
 
1542
1796
  switch (src0t) {
1543
1797
  case GGML_TYPE_F32:
1544
1798
  {
1799
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
1800
+
1545
1801
  switch (dstt) {
1546
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
1547
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
1802
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
1803
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
1804
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
1805
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
1806
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
1807
+ //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
1808
+ //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
1548
1809
  default: GGML_ASSERT(false && "not implemented");
1549
1810
  };
1550
1811
  } break;
@@ -1619,81 +1880,150 @@ void ggml_metal_graph_compute(
1619
1880
 
1620
1881
  // backend interface
1621
1882
 
1622
- static const char * ggml_backend_metal_name(ggml_backend_t backend) {
1623
- return "Metal";
1883
// Metal device shared by the backend buffer code below; stays nil until
// ggml_backend_metal_get_device() creates it.
static id<MTLDevice> g_backend_device = nil;
// Number of ggml_backend_metal_get_device() calls that have not yet been
// balanced by ggml_backend_metal_free_device().
static int g_backend_device_ref_count = 0;
1624
1885
 
1625
- UNUSED(backend);
1886
// Hands out the shared Metal device and counts one more user of it.
// The device is created on the first call; every call should later be
// balanced by ggml_backend_metal_free_device().
static id<MTLDevice> ggml_backend_metal_get_device(void) {
    // create the device lazily, the first time it is requested
    if (!g_backend_device) {
        g_backend_device = MTLCreateSystemDefaultDevice();
    }

    g_backend_device_ref_count += 1;

    return g_backend_device;
}
1627
1895
 
1628
- static void ggml_backend_metal_free(ggml_backend_t backend) {
1629
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1630
- ggml_metal_free(ctx);
1631
- free(backend);
1896
// Drops one reference to the shared Metal device. When the last reference
// goes away the device is released and the global handle is reset to nil.
static void ggml_backend_metal_free_device(void) {
    assert(g_backend_device_ref_count > 0);

    g_backend_device_ref_count -= 1;
    if (g_backend_device_ref_count != 0) {
        // other users are still holding the device
        return;
    }

    [g_backend_device release];
    g_backend_device = nil;
}
1633
1906
 
1634
1907
// Returns the `data` pointer stored in the buffer's Metal buffer context.
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
    return ((struct ggml_backend_metal_buffer_context *) buffer->context)->data;
}
1637
1912
 
1638
1913
// Tears down a buffer created by this buffer type: releases the MTLBuffer,
// gives back the device reference, then frees the host memory and the
// context struct itself.
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    struct ggml_backend_metal_buffer_context * buf_ctx =
        (struct ggml_backend_metal_buffer_context *) buffer->context;

    // Metal-side objects first ...
    [buf_ctx->metal release];
    ggml_backend_metal_free_device();

    // ... then the host allocations
    free(buf_ctx->data);
    free(buf_ctx);
}
1924
+
1925
// Copies `size` bytes from `data` into the tensor's memory, starting `offset`
// bytes into the tensor. Asserts that the range fits inside the tensor and
// that the tensor has been allocated. `buffer` is not used.
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    UNUSED(buffer);

    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");

    char * write_ptr = (char *) tensor->data + offset;
    memcpy(write_ptr, data, size);
}
1933
+
1934
// Copies `size` bytes out of the tensor's memory, starting `offset` bytes into
// the tensor, into `data`. Asserts that the range fits inside the tensor and
// that the tensor has been allocated. `buffer` is not used.
static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    UNUSED(buffer);

    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");

    const char * read_ptr = (const char *) tensor->data + offset;
    memcpy(data, read_ptr, size);
}
1942
+
1943
// Reads the whole of `src` (ggml_nbytes(src) bytes, from offset 0) through
// ggml_backend_tensor_get() directly into dst->data. `buffer` is not used.
static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
    UNUSED(buffer);

    const size_t n_bytes = ggml_nbytes(src);

    ggml_backend_tensor_get(src, dst->data, 0, n_bytes);
}
1948
+
1949
// Writes the whole of `src` (ggml_nbytes(src) bytes, read from src->data) into
// `dst` at offset 0 through ggml_backend_tensor_set(). `buffer` is not used.
static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
    UNUSED(buffer);

    const size_t n_bytes = ggml_nbytes(src);

    ggml_backend_tensor_set(dst, src->data, 0, n_bytes);
}
1642
1954
 
1643
1955
// Buffer interface table for Metal backend buffers. Entries are positional;
// the inline comments name the slot each function fills.
static struct ggml_backend_buffer_i metal_backend_buffer_i = {
    /* .free_buffer     = */ ggml_backend_metal_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_metal_buffer_get_base,
    /* .init_tensor     = */ NULL, // no per-tensor hook is installed
    /* .set_tensor      = */ ggml_backend_metal_buffer_set_tensor,
    /* .get_tensor      = */ ggml_backend_metal_buffer_get_tensor,
    /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
    /* .cpy_tensor_to   = */ ggml_backend_metal_buffer_cpy_tensor_to,
};
1650
1964
 
1651
- static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
1652
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1965
// Allocates host memory for a backend buffer and wraps it, without copying,
// in a shared-storage MTLBuffer.
//
//   buft : buffer type handle, forwarded to ggml_backend_buffer_init()
//   size : requested buffer size in bytes
//
// The MTLBuffer is created with the size rounded up to a whole number of
// pages, so the host allocation is made with that same rounded size.
// Allocating only `size` bytes (as before) let the Metal buffer extend past
// the end of the memory that was actually allocated.
//
// Returns the new backend buffer, or NULL if a host allocation fails.
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
    if (ctx == NULL) {
        return NULL;
    }

    const size_t size_page = sysconf(_SC_PAGESIZE);

    // round the length up to the next page boundary
    size_t size_aligned = size;
    if ((size_aligned % size_page) != 0) {
        size_aligned += (size_page - (size_aligned % size_page));
    }

    // allocate exactly what the MTLBuffer below will cover
    // NOTE(review): assumes ggml_metal_host_malloc() reports failure as NULL - confirm
    ctx->data = ggml_metal_host_malloc(size_aligned);
    if (ctx->data == NULL) {
        free(ctx);
        return NULL;
    }

    // takes one device reference; ggml_backend_metal_buffer_free_buffer() gives it back
    ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
                                                                   length:size_aligned
                                                                  options:MTLResourceStorageModeShared
                                                              deallocator:nil];

    return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
}
1661
1983
 
1662
- static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
1984
// Reports the alignment for buffers of this type: always 32
// (presumably bytes - confirm against the buffer-type interface).
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    UNUSED(buft);

    return 32;
}
1666
1988
 
1667
- static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1668
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
1669
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1670
-
1671
- memcpy((char *)tensor->data + offset, data, size);
1989
// Tells whether `backend` can use buffers of this type: true for the Metal
// backend and for the CPU backend, false otherwise. `buft` is not used.
static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
    GGML_UNUSED(buft);

    if (ggml_backend_is_metal(backend)) {
        return true;
    }

    return ggml_backend_is_cpu(backend);
}
1675
1994
 
1676
- static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1677
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
1678
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1679
-
1680
- memcpy(data, (const char *)tensor->data + offset, size);
1995
// Returns the Metal buffer type descriptor. The descriptor is a function-local
// static, so every call returns the same address.
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
    static struct ggml_backend_buffer_type metal_buffer_type = {
        /* .iface = */ {
            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_alloc_buffer,
            /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
            /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
        },
        /* .context = */ NULL,
    };

    return &metal_buffer_type;
}
1684
2008
 
1685
- static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
2009
// Name callback of the backend: always the constant string "Metal".
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
    UNUSED(backend);

    return "Metal";
}
1688
2014
 
1689
- static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
1690
- ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
2015
// Destroys a Metal backend: frees the ggml_metal context stored in
// backend->context via ggml_metal_free(), then frees the backend object.
static void ggml_backend_metal_free(ggml_backend_t backend) {
    ggml_metal_free((struct ggml_metal_context *) backend->context);

    free(backend);
}
1691
2020
 
2021
// Synchronize callback of the backend. Intentionally empty: the function
// only marks its argument as unused.
static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
    UNUSED(backend);
}
1694
2024
 
1695
- static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
1696
- ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
2025
// Default buffer type of the backend: whatever ggml_backend_metal_buffer_type()
// returns, independent of which backend instance asks.
static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
    UNUSED(backend);

    return ggml_backend_metal_buffer_type();
}
@@ -1705,32 +2035,43 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
1705
2035
  }
1706
2036
 
1707
2037
// Op-support callback of the backend: defers entirely to
// ggml_metal_supports_op(op). `backend` is not used.
static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
    UNUSED(backend);

    const bool is_supported = ggml_metal_supports_op(op);

    return is_supported;
}
1712
2042
 
1713
2043
// Backend interface table for the Metal backend. Entries are positional; the
// inline comments name the slot each entry fills. NULL slots have no
// Metal-specific implementation in this table.
static struct ggml_backend_i metal_backend_i = {
    /* .get_name                = */ ggml_backend_metal_name,
    /* .free                    = */ ggml_backend_metal_free,
    /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
    /* .set_tensor_async        = */ NULL,
    /* .get_tensor_async        = */ NULL,
    /* .cpy_tensor_from_async   = */ NULL,
    /* .cpy_tensor_to_async     = */ NULL,
    /* .synchronize             = */ ggml_backend_metal_synchronize,
    /* .graph_plan_create       = */ NULL, // the metal implementation does not require creating graph plans atm
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_metal_graph_compute,
    /* .supports_op             = */ ggml_backend_metal_supports_op,
};
1729
2058
 
2059
// TODO: make a common log callback for all backends in ggml-backend
//
// Log sink used by the Metal backend: writes `msg` to stderr unchanged.
// The log level and the user data pointer are ignored.
static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
    UNUSED(level);
    UNUSED(user_data);

    fprintf(stderr, "%s", msg);
}
2066
+
1730
2067
  ggml_backend_t ggml_backend_metal_init(void) {
1731
- struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
2068
+ ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
2069
+
2070
+ struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
1732
2071
 
1733
- ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
2072
+ if (ctx == NULL) {
2073
+ return NULL;
2074
+ }
1734
2075
 
1735
2076
  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
1736
2077
 
@@ -1747,7 +2088,26 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
1747
2088
  }
1748
2089
 
1749
2090
// Forwards `n_cb` to ggml_metal_set_n_cb() on the backend's Metal context.
// Asserts that `backend` really is a Metal backend before touching its context.
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
    GGML_ASSERT(ggml_backend_is_metal(backend));

    ggml_metal_set_n_cb((struct ggml_metal_context *) backend->context, n_cb);
}
2097
+
2098
// Asks the backend's Metal device whether it supports a GPU family.
// `family` is an offset counted from MTLGPUFamilyApple1, so 1 maps to
// MTLGPUFamilyApple1, 2 to the value after it, and so on.
// Asserts that `backend` is a Metal backend.
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
    GGML_ASSERT(ggml_backend_is_metal(backend));

    struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *) backend->context;
    const MTLGPUFamily gpu_family = (MTLGPUFamilyApple1 + family - 1);

    return [metal_ctx->device supportsFamily:gpu_family];
}
2105
+
2106
ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning

// Registry-style entry point: ignores both arguments and returns whatever
// ggml_backend_metal_init() returns.
ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
    GGML_UNUSED(params);
    GGML_UNUSED(user_data);

    return ggml_backend_metal_init();
}
+ }