llama_cpp 0.9.4 → 0.10.0

@@ -62,6 +62,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(mul);
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+ GGML_METAL_DECL_KERNEL(div);
+ GGML_METAL_DECL_KERNEL(div_row);
  GGML_METAL_DECL_KERNEL(scale);
  GGML_METAL_DECL_KERNEL(scale_4);
  GGML_METAL_DECL_KERNEL(silu);
@@ -112,15 +114,35 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
  GGML_METAL_DECL_KERNEL(rope_f32);
  GGML_METAL_DECL_KERNEL(rope_f16);
  GGML_METAL_DECL_KERNEL(alibi_f32);
  GGML_METAL_DECL_KERNEL(im2col_f16);
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
  GGML_METAL_DECL_KERNEL(concat);
  GGML_METAL_DECL_KERNEL(sqr);
+ GGML_METAL_DECL_KERNEL(sum_rows);

  #undef GGML_METAL_DECL_KERNEL
  };
@@ -164,12 +186,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  }
  }

-
-
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_LOG_INFO("%s: allocating\n", __func__);

- id <MTLDevice> device;
+ id<MTLDevice> device;
  NSString * s;

  #if TARGET_OS_OSX
@@ -215,6 +235,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

  NSString * sourcePath;
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+
+ GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
+
  if (ggmlMetalPathResources) {
  sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
  } else {
@@ -245,6 +268,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  }
  }

+ #if TARGET_OS_OSX
+ // print MTL GPU family:
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+ // determine max supported GPU family
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ if (ctx->device.maxTransferRate != 0) {
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+ } else {
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
+ }
+ #endif
+
  // load kernels
  {
  NSError * error = nil;
@@ -266,6 +312,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(add_row);
  GGML_METAL_ADD_KERNEL(mul);
  GGML_METAL_ADD_KERNEL(mul_row);
+ GGML_METAL_ADD_KERNEL(div);
+ GGML_METAL_ADD_KERNEL(div_row);
  GGML_METAL_ADD_KERNEL(scale);
  GGML_METAL_ADD_KERNEL(scale_4);
  GGML_METAL_ADD_KERNEL(silu);
@@ -317,43 +365,40 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
  }
  GGML_METAL_ADD_KERNEL(rope_f32);
  GGML_METAL_ADD_KERNEL(rope_f16);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(im2col_f16);
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
  GGML_METAL_ADD_KERNEL(concat);
  GGML_METAL_ADD_KERNEL(sqr);
+ GGML_METAL_ADD_KERNEL(sum_rows);

  #undef GGML_METAL_ADD_KERNEL
  }

- #if TARGET_OS_OSX
- // print MTL GPU family:
- GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
-
- // determine max supported GPU family
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
- break;
- }
- }
-
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
- if (ctx->device.maxTransferRate != 0) {
- GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
- } else {
- GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
- }
- #endif
-
  return ctx;
  }

@@ -367,6 +412,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(add_row);
  GGML_METAL_DEL_KERNEL(mul);
  GGML_METAL_DEL_KERNEL(mul_row);
+ GGML_METAL_DEL_KERNEL(div);
+ GGML_METAL_DEL_KERNEL(div_row);
  GGML_METAL_DEL_KERNEL(scale);
  GGML_METAL_DEL_KERNEL(scale_4);
  GGML_METAL_DEL_KERNEL(silu);
@@ -418,16 +465,36 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
  }
  GGML_METAL_DEL_KERNEL(rope_f32);
  GGML_METAL_DEL_KERNEL(rope_f16);
  GGML_METAL_DEL_KERNEL(alibi_f32);
  GGML_METAL_DEL_KERNEL(im2col_f16);
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
  GGML_METAL_DEL_KERNEL(concat);
  GGML_METAL_DEL_KERNEL(sqr);
+ GGML_METAL_DEL_KERNEL(sum_rows);

  #undef GGML_METAL_DEL_KERNEL

@@ -471,6 +538,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
  return ctx->concur_list;
  }

+ // temporarily defined here for compatibility between ggml-backend and the old API
+ struct ggml_backend_metal_buffer_context {
+ void * data;
+
+ id<MTLBuffer> metal;
+ };
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -480,8 +554,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru

  const int64_t tsize = ggml_nbytes(t);

- if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
- ctx = t->buffer->backend->context;
+ // compatibility with ggml-backend
+ if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
+ struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
+
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+
+ GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+
+ *offs = (size_t) ioffs;
+
+ return buf_ctx->metal;
  }

  // find the view that contains the tensor fully
@@ -706,6 +789,51 @@ void ggml_metal_graph_find_concurrency(
  }
  }

+ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_GELU:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_CONCAT:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_NORM:
+ case GGML_OP_ALIBI:
+ case GGML_OP_ROPE:
+ case GGML_OP_IM2COL:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ return true;
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_GET_ROWS:
+ {
+ return op->ne[0] % 4 == 0;
+ }
+ default:
+ return false;
+ }
+ }
  void ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
@@ -776,6 +904,8 @@ void ggml_metal_graph_compute(
  } break;
  }

+ GGML_ASSERT(ggml_metal_supports_op(dst));
+
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -868,6 +998,8 @@ void ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
  {
  GGML_ASSERT(ggml_is_contiguous(src0));
  GGML_ASSERT(ggml_is_contiguous(src1));
@@ -881,11 +1013,21 @@ void ggml_metal_graph_compute(
  GGML_ASSERT(ne11 == 1);

  nb = ne00 / 4;
- [encoder setComputePipelineState:ctx->pipeline_add_row];
+ switch (dst->op) {
+ case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
+ case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
+ case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+ default: GGML_ASSERT(false);
+ }

  bcast_row = true;
  } else {
- [encoder setComputePipelineState:ctx->pipeline_add];
+ switch (dst->op) {
+ case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
+ case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
+ case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+ default: GGML_ASSERT(false);
+ }
  }
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -926,31 +1068,6 @@ void ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  }
  } break;
- case GGML_OP_MUL:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
-
- // utilize float4
- GGML_ASSERT(ne00 % 4 == 0);
- const int64_t nb = ne00/4;
-
- if (ggml_nelements(src1) == ne10) {
- // src1 is a row
- GGML_ASSERT(ne11 == 1);
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
- } else {
- [encoder setComputePipelineState:ctx->pipeline_mul];
- }
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
-
- const int64_t n = ggml_nelements(dst)/4;
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
  case GGML_OP_SCALE:
  {
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1023,25 +1140,66 @@ void ggml_metal_graph_compute(
  const int64_t n = ggml_nelements(dst);
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case GGML_OP_SUM_ROWS:
+ {
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+ [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  case GGML_OP_SOFT_MAX:
  {
  int nth = 32; // SIMD width

  if (ne00%4 == 0) {
+ while (nth < ne00/4 && nth < 256) {
+ nth *= 2;
+ }
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
  } else {
- do {
+ while (nth < ne00 && nth < 1024) {
  nth *= 2;
- } while (nth <= ne00 && nth <= 1024);
- nth /= 2;
+ }
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
  }
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+
+ const float scale = ((float *) dst->op_params)[0];
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
@@ -1070,9 +1228,13 @@ void ggml_metal_graph_compute(
  case GGML_OP_MUL_MAT:
  {
  GGML_ASSERT(ne00 == ne10);
- GGML_ASSERT(ne03 == ne13);

- const uint gqa = ne12/ne02;
+ // TODO: assert that dim2 and dim3 are contiguous
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const uint r2 = ne12/ne02;
+ const uint r3 = ne13/ne03;

  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
  // to the matrix-vector kernel
@@ -1107,7 +1269,7 @@ void ggml_metal_graph_compute(
  !ggml_is_transposed(src1) &&
  src1t == GGML_TYPE_F32 &&
  ne00 % 32 == 0 && ne00 >= 64 &&
- ne11 > ne11_mm_min) {
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
  switch (src0->type) {
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
@@ -1137,9 +1299,10 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
  int nth0 = 32;
  int nth1 = 1;
@@ -1175,90 +1338,60 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_TYPE_Q4_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
  } break;
  case GGML_TYPE_Q4_1:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
  } break;
  case GGML_TYPE_Q5_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
  } break;
  case GGML_TYPE_Q5_1:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
  } break;
  case GGML_TYPE_Q8_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
  } break;
  case GGML_TYPE_Q2_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
  } break;
  case GGML_TYPE_Q3_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
  } break;
  case GGML_TYPE_Q4_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 4; //1;
  nth1 = 8; //32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
  } break;
  case GGML_TYPE_Q5_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
  } break;
  case GGML_TYPE_Q6_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1287,32 +1420,125 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];

  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
  src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q3_K) {
  #ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
  int64_t ny = (ne11 + nrows - 1)/nrows;
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case GGML_OP_MUL_MAT_ID:
+ {
+ //GGML_ASSERT(ne00 == ne10);
+ //GGML_ASSERT(ne03 == ne13);
+
+ GGML_ASSERT(src0t == GGML_TYPE_I32);
+
+ const int n_as = ne00;
+
+ // TODO: make this more general
+ GGML_ASSERT(n_as <= 8);
+
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+ GGML_ASSERT(!ggml_is_transposed(src2));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+
+ GGML_ASSERT(ne20 % 32 == 0);
+ // !!!!!!!!! TODO: this assert is probably required but not sure!
+ //GGML_ASSERT(ne20 >= 64);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+ const uint r2 = ne12/ne22;
+ const uint r3 = ne13/ne23;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = 0;
+
+ const int idx = ((int32_t *) dst->op_params)[0];
+
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ ne11 > ne11_mm_min) {
+ switch (src2->type) {
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+ default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
  }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < n_as; ++j) {
+ struct ggml_tensor * src_cur = dst->src[2 + j];
+
+ size_t offs_src_cur = 0;
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
+ }
+
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  }
  } break;
  case GGML_OP_GET_ROWS:
@@ -1351,15 +1577,19 @@ void ggml_metal_graph_compute(
  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));

- const int nth = MIN(512, ne00);
+ int nth = 32; // SIMD width
+
+ while (nth < ne00/4 && nth < 1024) {
+ nth *= 2;
+ }

  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

  const int64_t nrows = ggml_nrows(src0);

@@ -1433,7 +1663,8 @@ void ggml_metal_graph_compute(
  const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
- const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];

  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@@ -1533,18 +1764,48 @@ void ggml_metal_graph_compute(

  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
  } break;
+ case GGML_OP_ARGSORT:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+ const int nrows = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+ switch (order) {
+ case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
+ case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+ } break;
  case GGML_OP_DUP:
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  {
- const int nth = MIN(1024, ne00);
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));

  switch (src0t) {
  case GGML_TYPE_F32:
  {
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
  switch (dstt) {
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
+ //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
+ //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
  default: GGML_ASSERT(false && "not implemented");
  };
  } break;
@@ -1619,81 +1880,150 @@ void ggml_metal_graph_compute(

  // backend interface

- static const char * ggml_backend_metal_name(ggml_backend_t backend) {
- return "Metal";
+ static id<MTLDevice> g_backend_device = nil;
+ static int g_backend_device_ref_count = 0;

- UNUSED(backend);
+ static id<MTLDevice> ggml_backend_metal_get_device(void) {
+ if (g_backend_device == nil) {
+ g_backend_device = MTLCreateSystemDefaultDevice();
+ }
+
+ g_backend_device_ref_count++;
+
+ return g_backend_device;
  }

- static void ggml_backend_metal_free(ggml_backend_t backend) {
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
- ggml_metal_free(ctx);
- free(backend);
+ static void ggml_backend_metal_free_device(void) {
+ assert(g_backend_device_ref_count > 0);
+
+ g_backend_device_ref_count--;
+
+ if (g_backend_device_ref_count == 0) {
+ [g_backend_device release];
+ g_backend_device = nil;
+ }
  }

  static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
- return (void *)buffer->context;
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ return ctx->data;
  }

  static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- free(buffer->context);
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ [ctx->metal release];
+ ggml_backend_metal_free_device();
+
+ free(ctx->data);
+ free(ctx);
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy((char *)tensor->data + offset, data, size);
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+
  UNUSED(buffer);
  }

  static struct ggml_backend_buffer_i metal_backend_buffer_i = {
- /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
- /* .get_base = */ ggml_backend_metal_buffer_get_base,
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
- /* .init_tensor = */ NULL, // no initialization required
- /* .free_tensor = */ NULL, // no cleanup required
+ /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_metal_buffer_get_base,
+ /* .init_tensor = */ NULL,
+ /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
+ /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
+ /* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
  };

- static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));

- void * data = ggml_metal_host_malloc(size);
+ const size_t size_page = sysconf(_SC_PAGESIZE);

- // TODO: set proper name of the buffers
- ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+ size_t size_aligned = size;
+ if ((size_aligned % size_page) != 0) {
+ size_aligned += (size_page - (size_aligned % size_page));
+ }
+
+ ctx->data = ggml_metal_host_malloc(size);
+ ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
+ length:size_aligned
+ options:MTLResourceStorageModeShared
+ deallocator:nil];

- return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+ return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
  }

- static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
+ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
  return 32;
- UNUSED(backend);
+ UNUSED(buft);
  }

- static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
- memcpy((char *)tensor->data + offset, data, size);
+ static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+ return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);

- UNUSED(backend);
+ GGML_UNUSED(buft);
  }

- static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
- memcpy(data, (const char *)tensor->data + offset, size);
+ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
+ /* .iface = */ {
+ /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+ /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
+ },
+ /* .context = */ NULL,
+ };

- UNUSED(backend);
+ return &ggml_backend_buffer_type_metal;
  }

- static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+ return "Metal";
+
  UNUSED(backend);
  }

- static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+ static void ggml_backend_metal_free(ggml_backend_t backend) {
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ ggml_metal_free(ctx);
+ free(backend);
+ }

+ static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
  UNUSED(backend);
  }

- static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
+ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
+ return ggml_backend_metal_buffer_type();

  UNUSED(backend);
  }
@@ -1705,32 +2035,43 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
  }

  static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- return true;
+ return ggml_metal_supports_op(op);
+
  UNUSED(backend);
- UNUSED(op);
  }

  static struct ggml_backend_i metal_backend_i = {
- /* .get_name = */ ggml_backend_metal_name,
- /* .free = */ ggml_backend_metal_free,
- /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
- /* .get_alignment = */ ggml_backend_metal_get_alignment,
- /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
- /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
- /* .synchronize = */ ggml_backend_metal_synchronize,
- /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from,
- /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to,
- /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_metal_graph_compute,
- /* .supports_op = */ ggml_backend_metal_supports_op,
+ /* .get_name = */ ggml_backend_metal_name,
+ /* .free = */ ggml_backend_metal_free,
+ /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_from_async = */ NULL,
+ /* .cpy_tensor_to_async = */ NULL,
+ /* .synchronize = */ ggml_backend_metal_synchronize,
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_metal_graph_compute,
+ /* .supports_op = */ ggml_backend_metal_supports_op,
  };

+ // TODO: make a common log callback for all backends in ggml-backend
+ static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+ fprintf(stderr, "%s", msg);
+
+ UNUSED(level);
+ UNUSED(user_data);
+ }
+
  ggml_backend_t ggml_backend_metal_init(void) {
- struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+ ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
+
+ struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);

- ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+ if (ctx == NULL) {
+ return NULL;
+ }

  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));

@@ -1747,7 +2088,26 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
  }

  void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;

  ggml_metal_set_n_cb(ctx, n_cb);
  }
+
+ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+ return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+ }
+
+ ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+
+ ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
+ return ggml_backend_metal_init();
+
+ GGML_UNUSED(params);
+ GGML_UNUSED(user_data);
+ }
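
The last hunks rework the Metal backend's public entry points (NULL-checked init, per-backend buffer type, family and n_cb queries). A minimal sketch of how they might be driven from C follows; it is illustrative only and assumes the ggml-backend.h / ggml-metal.h headers shipped with this release, with ggml_backend_free() coming from the generic ggml-backend API rather than from this diff.

    // Sketch: exercising the Metal backend entry points touched by this diff.
    #include <stdio.h>

    #include "ggml-backend.h"
    #include "ggml-metal.h"

    int main(void) {
        // ggml_backend_metal_init() now installs a log callback and returns NULL
        // when the Metal context cannot be created.
        ggml_backend_t backend = ggml_backend_metal_init();
        if (backend == NULL) {
            fprintf(stderr, "Metal backend unavailable\n");
            return 1;
        }

        // new in this version: query the supported Apple GPU family; family 7
        // (MTLGPUFamilyApple7) is what the matrix-matrix kernels require.
        if (!ggml_backend_metal_supports_family(backend, 7)) {
            fprintf(stderr, "falling back to matrix-vector kernels\n");
        }

        // number of command buffers used by the Metal context
        ggml_backend_metal_set_n_cb(backend, 4);

        // device buffers are now allocated through the buffer type returned by
        // ggml_backend_metal_buffer_type() (see the hunks above).

        ggml_backend_free(backend); // assumed from ggml-backend.h
        return 0;
    }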