llama_cpp 0.10.0 → 0.10.2

@@ -66,9 +66,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(div_row);
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(scale_4);
-    GGML_METAL_DECL_KERNEL(silu);
+    GGML_METAL_DECL_KERNEL(tanh);
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
+    GGML_METAL_DECL_KERNEL(gelu_quick);
+    GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(soft_max_4);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -86,6 +88,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q5_K);
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(group_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@@ -102,6 +105,21 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -130,8 +148,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(im2col_f16);
+    GGML_METAL_DECL_KERNEL(upscale_f32);
+    GGML_METAL_DECL_KERNEL(pad_f32);
     GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
     GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_DECL_KERNEL(leaky_relu_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
@@ -140,6 +161,7 @@ struct ggml_metal_context {
     //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
     //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f32);
     GGML_METAL_DECL_KERNEL(concat);
     GGML_METAL_DECL_KERNEL(sqr);
     GGML_METAL_DECL_KERNEL(sum_rows);
@@ -158,7 +180,15 @@ struct ggml_metal_context {
 @implementation GGMLMetalClass
 @end
 
-ggml_log_callback ggml_metal_log_callback = NULL;
+
+static void ggml_metal_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+    fprintf(stderr, "%s", msg);
+
+    UNUSED(level);
+    UNUSED(user_data);
+}
+
+ggml_log_callback ggml_metal_log_callback = ggml_metal_default_log_callback;
 
 void * ggml_metal_log_user_data = NULL;
 
 void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
@@ -177,6 +207,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
     } else {
         char* buffer2 = malloc(len+1);
+        va_end(args);
+        va_start(args, format);
         vsnprintf(buffer2, len+1, format, args);
         buffer2[len] = 0;
         ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
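
Note: the added va_end/va_start pair fixes a varargs bug. Earlier in ggml_metal_log the same args list has already been consumed by a first vsnprintf into a fixed-size stack buffer, and a va_list is indeterminate after use, so it must be restarted before the second formatting pass. A minimal C sketch of the corrected measure-then-format pattern; format_alloc is a hypothetical helper, not part of ggml:

    #include <stdarg.h>
    #include <stdio.h>
    #include <stdlib.h>

    static char * format_alloc(const char * fmt, ...) {
        va_list args;

        va_start(args, fmt);
        const int len = vsnprintf(NULL, 0, fmt, args); // first pass: length only
        va_end(args);                                  // args is now spent

        char * buf = malloc(len + 1);

        va_start(args, fmt);                           // restart before the second pass
        vsnprintf(buf, len + 1, fmt, args);
        va_end(args);

        return buf;
    }
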
@@ -316,9 +348,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(div_row);
     GGML_METAL_ADD_KERNEL(scale);
     GGML_METAL_ADD_KERNEL(scale_4);
-    GGML_METAL_ADD_KERNEL(silu);
+    GGML_METAL_ADD_KERNEL(tanh);
     GGML_METAL_ADD_KERNEL(relu);
     GGML_METAL_ADD_KERNEL(gelu);
+    GGML_METAL_ADD_KERNEL(gelu_quick);
+    GGML_METAL_ADD_KERNEL(silu);
     GGML_METAL_ADD_KERNEL(soft_max);
     GGML_METAL_ADD_KERNEL(soft_max_4);
     GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -336,6 +370,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(get_rows_q5_K);
     GGML_METAL_ADD_KERNEL(get_rows_q6_K);
     GGML_METAL_ADD_KERNEL(rms_norm);
+    GGML_METAL_ADD_KERNEL(group_norm);
     GGML_METAL_ADD_KERNEL(norm);
     GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
     GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@@ -352,6 +387,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
     GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
     GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
     if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
         GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -382,8 +432,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(rope_f16);
     GGML_METAL_ADD_KERNEL(alibi_f32);
     GGML_METAL_ADD_KERNEL(im2col_f16);
+    GGML_METAL_ADD_KERNEL(upscale_f32);
+    GGML_METAL_ADD_KERNEL(pad_f32);
     GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
     GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_ADD_KERNEL(leaky_relu_f32);
     GGML_METAL_ADD_KERNEL(cpy_f32_f16);
     GGML_METAL_ADD_KERNEL(cpy_f32_f32);
     GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
@@ -392,6 +445,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
     //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
     GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+    GGML_METAL_ADD_KERNEL(cpy_f16_f32);
     GGML_METAL_ADD_KERNEL(concat);
     GGML_METAL_ADD_KERNEL(sqr);
     GGML_METAL_ADD_KERNEL(sum_rows);
@@ -416,9 +470,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(div_row);
     GGML_METAL_DEL_KERNEL(scale);
     GGML_METAL_DEL_KERNEL(scale_4);
-    GGML_METAL_DEL_KERNEL(silu);
+    GGML_METAL_DEL_KERNEL(tanh);
     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
+    GGML_METAL_DEL_KERNEL(gelu_quick);
+    GGML_METAL_DEL_KERNEL(silu);
     GGML_METAL_DEL_KERNEL(soft_max);
     GGML_METAL_DEL_KERNEL(soft_max_4);
     GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -436,6 +492,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_q5_K);
     GGML_METAL_DEL_KERNEL(get_rows_q6_K);
     GGML_METAL_DEL_KERNEL(rms_norm);
+    GGML_METAL_DEL_KERNEL(group_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@@ -452,6 +509,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
     if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
         GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -482,8 +554,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
     GGML_METAL_DEL_KERNEL(im2col_f16);
+    GGML_METAL_DEL_KERNEL(upscale_f32);
+    GGML_METAL_DEL_KERNEL(pad_f32);
     GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
     GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_DEL_KERNEL(leaky_relu_f32);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
@@ -492,6 +567,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
     //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f32);
     GGML_METAL_DEL_KERNEL(concat);
     GGML_METAL_DEL_KERNEL(sqr);
     GGML_METAL_DEL_KERNEL(sum_rows);
@@ -539,12 +615,24 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
 }
 
 // temporarily defined here for compatibility between ggml-backend and the old API
-struct ggml_backend_metal_buffer_context {
-    void * data;
+
+struct ggml_backend_metal_buffer {
+    void * data;
+    size_t size;
 
     id<MTLBuffer> metal;
 };
 
+struct ggml_backend_metal_buffer_context {
+    void * all_data;
+    size_t all_size;
+    bool owned;
+
+    // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
+    int n_buffers;
+    struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
+};
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -554,17 +642,29 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
 
     const int64_t tsize = ggml_nbytes(t);
 
+    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
+
     // compatibility with ggml-backend
-    if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
-        struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
+    if (buffer && buffer->buft == ggml_backend_metal_buffer_type()) {
+        struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
 
-        const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+        // find the view that contains the tensor fully
+        for (int i = 0; i < buf_ctx->n_buffers; ++i) {
+            const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
 
-        GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+            //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
+            if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
+                *offs = (size_t) ioffs;
 
-        *offs = (size_t) ioffs;
+                //GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
 
-        return buf_ctx->metal;
+                return buf_ctx->buffers[i].metal;
+            }
+        }
+
+        GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
+
+        return nil;
     }
 
 // find the view that contains the tensor fully
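
Note: with the reworked ggml_backend_metal_buffer_context holding up to GGML_METAL_MAX_BUFFERS views, the lookup above walks the views and returns the first one whose byte range fully contains the tensor, instead of assuming a single backing buffer. A simplified C sketch of that containment test (plain pointers standing in for id<MTLBuffer>; the names are illustrative):

    #include <stdint.h>
    #include <stddef.h>

    struct view { uint8_t * data; size_t size; };

    // returns the index of the view that fully contains [ptr, ptr + nbytes),
    // writing the offset within that view; -1 if no view contains it
    static int find_view(const struct view * views, int n,
                         const uint8_t * ptr, size_t nbytes, size_t * offs) {
        for (int i = 0; i < n; ++i) {
            const int64_t ioffs = (intptr_t) ptr - (intptr_t) views[i].data;
            if (ioffs >= 0 && (uint64_t) ioffs + nbytes <= views[i].size) {
                *offs = (size_t) ioffs;
                return i;
            }
        }
        return -1;
    }
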
@@ -793,9 +893,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
    switch (op->op) {
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
-               case GGML_UNARY_OP_SILU:
+               case GGML_UNARY_OP_TANH:
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_GELU:
+               case GGML_UNARY_OP_GELU_QUICK:
+               case GGML_UNARY_OP_SILU:
                    return true;
                default:
                    return false;
@@ -807,6 +909,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
        case GGML_OP_PERMUTE:
        case GGML_OP_CONCAT:
        case GGML_OP_ADD:
+       case GGML_OP_ACC:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_SCALE:
@@ -814,21 +917,50 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
        case GGML_OP_SUM_ROWS:
        case GGML_OP_SOFT_MAX:
        case GGML_OP_RMS_NORM:
+       case GGML_OP_GROUP_NORM:
        case GGML_OP_NORM:
        case GGML_OP_ALIBI:
        case GGML_OP_ROPE:
        case GGML_OP_IM2COL:
+       case GGML_OP_UPSCALE:
+       case GGML_OP_PAD:
        case GGML_OP_ARGSORT:
-       case GGML_OP_DUP:
-       case GGML_OP_CPY:
-       case GGML_OP_CONT:
+       case GGML_OP_LEAKY_RELU:
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            return true;
+       case GGML_OP_CPY:
+       case GGML_OP_DUP:
+       case GGML_OP_CONT:
+           {
+               switch (op->src[0]->type) {
+                   case GGML_TYPE_F32:
+                       switch (op->type) {
+                           case GGML_TYPE_F16:
+                           case GGML_TYPE_F32:
+                           case GGML_TYPE_Q8_0:
+                           case GGML_TYPE_Q4_0:
+                           case GGML_TYPE_Q4_1:
+                               return true;
+                           default:
+                               return false;
+                       }
+                   case GGML_TYPE_F16:
+                       switch (op->type) {
+                           case GGML_TYPE_F16:
+                           case GGML_TYPE_F32:
+                               return true;
+                           default:
+                               return false;
+                       }
+                   default:
+                       return false;
+               };
+           }
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_GET_ROWS:
            {
-               return op->ne[0] % 4 == 0;
+               return op->ne[3] == 1;
            }
        default:
            return false;
@@ -904,7 +1036,10 @@ void ggml_metal_graph_compute(
                } break;
            }
 
-           GGML_ASSERT(ggml_metal_supports_op(dst));
+           if (!ggml_metal_supports_op(dst)) {
+               GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+               GGML_ASSERT(!"unsupported op");
+           }
 
            const int64_t ne00 = src0 ? src0->ne[0] : 0;
            const int64_t ne01 = src0 ? src0->ne[1] : 0;
@@ -1001,34 +1136,39 @@ void ggml_metal_graph_compute(
            case GGML_OP_MUL:
            case GGML_OP_DIV:
                {
-                   GGML_ASSERT(ggml_is_contiguous(src0));
-                   GGML_ASSERT(ggml_is_contiguous(src1));
+                   const size_t offs = 0;
 
                    bool bcast_row = false;
 
                    int64_t nb = ne00;
 
-                   if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
+                   id<MTLComputePipelineState> pipeline = nil;
+
+                   if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                       GGML_ASSERT(ggml_is_contiguous(src0));
+
                        // src1 is a row
                        GGML_ASSERT(ne11 == 1);
 
                        nb = ne00 / 4;
                        switch (dst->op) {
-                           case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
-                           case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
-                           case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+                           case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+                           case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+                           case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
                            default: GGML_ASSERT(false);
                        }
 
                        bcast_row = true;
                    } else {
                        switch (dst->op) {
-                           case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
-                           case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
-                           case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+                           case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+                           case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+                           case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
                            default: GGML_ASSERT(false);
                        }
                    }
+
+                   [encoder setComputePipelineState:pipeline];
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
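
Note: the rewritten condition makes the *_row fast path stricter: src1 must be a single contiguous row and both ne00 and ne10 must be divisible by 4 so the kernel can step in float4 units; selecting the pipeline into a local variable first also lets the dispatch below size threadgroups from pipeline.maxTotalThreadsPerThreadgroup instead of a hardcoded 1024. One reading of the row-broadcast case as a plain C reference loop (a sketch of the semantics, not the Metal kernel itself):

    // src1 is one row of length ne10, broadcast-added to every row of src0
    static void add_row_ref(const float * src0, const float * src1, float * dst,
                            long nrows, long ne10) {
        for (long r = 0; r < nrows; ++r) {
            for (long c = 0; c < ne10; ++c) {
                dst[r*ne10 + c] = src0[r*ne10 + c] + src1[c];
            }
        }
    }
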
@@ -1056,23 +1196,104 @@ void ggml_metal_graph_compute(
                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
                    [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
                    [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
-                   [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+                   [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+                   [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
 
                    if (bcast_row) {
                        const int64_t n = ggml_nelements(dst)/4;
 
                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                    } else {
-                       const int nth = MIN(1024, ne0);
+                       const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                    }
                } break;
+           case GGML_OP_ACC:
+               {
+                   GGML_ASSERT(src0t == GGML_TYPE_F32);
+                   GGML_ASSERT(src1t == GGML_TYPE_F32);
+                   GGML_ASSERT(dstt == GGML_TYPE_F32);
+
+                   GGML_ASSERT(ggml_is_contiguous(src0));
+                   GGML_ASSERT(ggml_is_contiguous(src1));
+
+                   const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+                   const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+                   const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+                   const size_t offs = ((int32_t *) dst->op_params)[3];
+
+                   const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+                   if (!inplace) {
+                       // run a separate kernel to cpy src->dst
+                       // not sure how to avoid this
+                       // TODO: make a simpler cpy_bytes kernel
+
+                       const int nth = MIN(1024, ne00);
+
+                       [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                       [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                       [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                       [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                       [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                       [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                       [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                       [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                       [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                       [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                       [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+                       [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+                       [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+                       [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+                       [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+                       [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+                       [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+                       [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+                       [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                   }
+
+                   [encoder setComputePipelineState:ctx->pipeline_add];
+                   [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                   [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                   [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                   [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                   [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                   [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                   [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                   [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                   [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+                   [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+                   [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+                   [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                   [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                   [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                   [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                   [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                   [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                   [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                   [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                   [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+                   [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+                   [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+                   [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+                   [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+                   [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+                   [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+                   [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+                   [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+
+                   const int nth = MIN(1024, ne0);
+
+                   [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+               } break;
            case GGML_OP_SCALE:
                {
                    GGML_ASSERT(ggml_is_contiguous(src0));
 
-                   const float scale = *(const float *) src1->data;
+                   const float scale = *(const float *) dst->op_params;
 
                    int64_t n = ggml_nelements(dst);
 
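
Note: the new GGML_OP_ACC case accumulates src1 into a strided view of dst: unless the op is marked in-place, it first copies src0 into dst with the cpy_f32_f32 kernel, then reuses the add pipeline with the view strides pnb1..pnb3 and byte offset offs taken from op_params. A reference sketch in C, reduced to two dimensions under the usual ggml convention of byte strides (an illustration of the semantics, not the kernel):

    #include <string.h>
    #include <stdint.h>

    // dst = src0, then the view of dst described by (offs, nb1) += src1
    static void acc_ref(float * dst, const float * src0, size_t dst_bytes,
                        const float * src1, long ne10, long ne11,
                        size_t nb1, size_t offs) {
        memcpy(dst, src0, dst_bytes); // the separate copy pass for the non-inplace case

        for (long i1 = 0; i1 < ne11; ++i1) {
            float * row = (float *)((uint8_t *) dst + offs + (size_t) i1*nb1);
            for (long i0 = 0; i0 < ne10; ++i0) {
                row[i0] += src1[i1*ne10 + i0];
            }
        }
    }
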
@@ -1083,24 +1304,23 @@ void ggml_metal_graph_compute(
                        [encoder setComputePipelineState:ctx->pipeline_scale];
                    }
 
-                   [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                   [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                   [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                   [encoder setBuffer:id_dst   offset:offs_dst  atIndex:1];
                    [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                } break;
            case GGML_OP_UNARY:
                switch (ggml_get_unary_op(gf->nodes[i])) {
-                   case GGML_UNARY_OP_SILU:
+                   case GGML_UNARY_OP_TANH:
                        {
-                           [encoder setComputePipelineState:ctx->pipeline_silu];
+                           [encoder setComputePipelineState:ctx->pipeline_tanh];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
                            const int64_t n = ggml_nelements(dst);
-                           GGML_ASSERT(n % 4 == 0);
 
-                           [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                           [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    case GGML_UNARY_OP_RELU:
                        {
@@ -1121,6 +1341,28 @@ void ggml_metal_graph_compute(
                            const int64_t n = ggml_nelements(dst);
                            GGML_ASSERT(n % 4 == 0);
 
+                           [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                       } break;
+                   case GGML_UNARY_OP_GELU_QUICK:
+                       {
+                           [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
+                           [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                           [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                           const int64_t n = ggml_nelements(dst);
+                           GGML_ASSERT(n % 4 == 0);
+
+                           [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                       } break;
+                   case GGML_UNARY_OP_SILU:
+                       {
+                           [encoder setComputePipelineState:ctx->pipeline_silu];
+                           [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                           [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                           const int64_t n = ggml_nelements(dst);
+                           GGML_ASSERT(n % 4 == 0);
+
                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    default:
@@ -1193,7 +1435,11 @@ void ggml_metal_graph_compute(
                    const float scale = ((float *) dst->op_params)[0];
 
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                   [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                   if (id_src1) {
+                       [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                   } else {
+                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                   }
                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
@@ -1444,7 +1690,7 @@ void ggml_metal_graph_compute(
                        else if (src0t == GGML_TYPE_Q6_K) {
                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                        } else {
-                           int64_t ny = (ne11 + nrows - 1)/nrows;
+                           const int64_t ny = (ne11 + nrows - 1)/nrows;
                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                        }
                    }
@@ -1456,7 +1702,7 @@ void ggml_metal_graph_compute(
 
                    GGML_ASSERT(src0t == GGML_TYPE_I32);
 
-                   const int n_as = ne00;
+                   const int n_as = ((int32_t *) dst->op_params)[1];
 
                    // TODO: make this more general
                    GGML_ASSERT(n_as <= 8);
@@ -1488,14 +1734,22 @@ void ggml_metal_graph_compute(
 
                    // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                    // to the matrix-vector kernel
-                   int ne11_mm_min = 0;
+                   int ne11_mm_min = 1;
 
                    const int idx = ((int32_t *) dst->op_params)[0];
 
+                   // batch size
+                   GGML_ASSERT(ne01 == ne11);
+
+                   const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
                    // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                    // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                   if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                       ne11 > ne11_mm_min) {
+                   // !!!
+                   // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+                   //       indirect matrix multiplication
+                   // !!!
+                   if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
                        switch (src2->type) {
                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@@ -1514,19 +1768,22 @@ void ggml_metal_graph_compute(
                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                        [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-                       [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
-                       [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
-                       [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
-                       [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
-                       [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
-                       [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
-                       [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
-                       [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
-                       [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
-                       [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
-                       [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
-                       [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
-                       [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
+                       [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+                       [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                       [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
+                       [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+                       [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
+                       [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+                       [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
+                       [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+                       [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+                       [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+                       [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+                       [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+                       [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+                       [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
+                       [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
+                       [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
                        // TODO: how to make this an array? read Metal docs
                        for (int j = 0; j < n_as; ++j) {
                            struct ggml_tensor * src_cur = dst->src[2 + j];
@@ -1534,11 +1791,157 @@ void ggml_metal_graph_compute(
                            size_t offs_src_cur = 0;
                            id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
 
-                           [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
+                           [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
                        }
 
                        [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                       [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+
+                       // TODO: processing one row at a time (ne11 -> 1) is not efficient
+                       [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                   } else {
+                       int nth0 = 32;
+                       int nth1 = 1;
+                       int nrows = 1;
+                       //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+                       // use custom matrix x vector kernel
+                       switch (src2t) {
+                           case GGML_TYPE_F32:
+                               {
+                                   GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                   [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+                               } break;
+                           case GGML_TYPE_F16:
+                               {
+                                   GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                   nth0 = 32;
+                                   nth1 = 1;
+                                   [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+                               } break;
+                           case GGML_TYPE_Q4_0:
+                               {
+                                   nth0 = 8;
+                                   nth1 = 8;
+                                   [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+                               } break;
+                           case GGML_TYPE_Q4_1:
+                               {
+                                   nth0 = 8;
+                                   nth1 = 8;
+                                   [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+                               } break;
+                           case GGML_TYPE_Q5_0:
+                               {
+                                   nth0 = 8;
+                                   nth1 = 8;
+                                   [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+                               } break;
+                           case GGML_TYPE_Q5_1:
+                               {
+                                   nth0 = 8;
+                                   nth1 = 8;
+                                   [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+                               } break;
+                           case GGML_TYPE_Q8_0:
+                               {
+                                   nth0 = 8;
+                                   nth1 = 8;
+                                   [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+                               } break;
+                           case GGML_TYPE_Q2_K:
+                               {
+                                   nth0 = 2;
+                                   nth1 = 32;
+                                   [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+                               } break;
+                           case GGML_TYPE_Q3_K:
+                               {
+                                   nth0 = 2;
+                                   nth1 = 32;
+                                   [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+                               } break;
+                           case GGML_TYPE_Q4_K:
+                               {
+                                   nth0 = 4; //1;
+                                   nth1 = 8; //32;
+                                   [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+                               } break;
+                           case GGML_TYPE_Q5_K:
+                               {
+                                   nth0 = 2;
+                                   nth1 = 32;
+                                   [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+                               } break;
+                           case GGML_TYPE_Q6_K:
+                               {
+                                   nth0 = 2;
+                                   nth1 = 32;
+                                   [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+                               } break;
+                           default:
+                               {
+                                   GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+                                   GGML_ASSERT(false && "not implemented");
+                               }
+                       };
+
+                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                       [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                       [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                       [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+                       [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                       [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                       [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+                       [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+                       [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+                       [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+                       [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                       [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+                       [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                       [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                       [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                       [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                       [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                       [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+                       [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+                       [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+                       [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
+                       [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
+                       [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
+                       // TODO: how to make this an array? read Metal docs
+                       for (int j = 0; j < n_as; ++j) {
+                           struct ggml_tensor * src_cur = dst->src[2 + j];
+
+                           size_t offs_src_cur = 0;
+                           id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+                           [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+                       }
+
+                       if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+                           src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+                           src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+                           [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                       }
+                       else if (src2t == GGML_TYPE_Q4_K) {
+                           [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                       }
+                       else if (src2t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                           [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                           [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+                       }
+                       else if (src2t == GGML_TYPE_Q5_K) {
+                           [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                       }
+                       else if (src2t == GGML_TYPE_Q6_K) {
+                           [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                       } else {
+                           const int64_t ny = (_ne1 + nrows - 1)/nrows;
+                           [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                       }
                    }
                } break;
            case GGML_OP_GET_ROWS:
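
Note: the expanded GGML_OP_MUL_MAT_ID path adds a full set of mul_mv_id_* matrix-vector kernels; the candidate matrices arrive as extra buffers (dst->src[2 + j]) and, per the TODO above, the matrix-matrix route is effectively disabled (_ne1 = 1), so rows go through the mat-vec kernels one at a time. A rough reference of what the op computes; this reading of the id-indexed matmul is an assumption based on the bindings above, with quantization ignored:

    #include <stdint.h>

    // for each row r, pick matrix e = ids[r] and compute dst_row = mats[e] * src1_row
    static void mul_mat_id_ref(const int32_t * ids, const float * const * mats,
                               const float * src1, float * dst,
                               long nrows, long ne00, long ne01) {
        for (long r = 0; r < nrows; ++r) {
            const float * w = mats[ids[r]];  // ne01 x ne00, row-major
            const float * x = src1 + r*ne00;
            float       * y = dst  + r*ne01;
            for (long o = 0; o < ne01; ++o) {
                float sum = 0.0f;
                for (long k = 0; k < ne00; ++k) {
                    sum += w[o*ne00 + k] * x[k];
                }
                y[o] = sum;
            }
        }
    }
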
@@ -1559,16 +1962,19 @@ void ggml_metal_graph_compute(
                        default: GGML_ASSERT(false && "not implemented");
                    }
 
-                   [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                   [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                   [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                   [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                   [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                   [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-                   [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
-
-                   const int64_t n = ggml_nelements(src1);
-
-                   [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                   [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+                   [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+                   [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+                   [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+                   [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+                   [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
+                   [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
                } break;
            case GGML_OP_RMS_NORM:
                {
@@ -1595,6 +2001,38 @@ void ggml_metal_graph_compute(
 
                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                } break;
+           case GGML_OP_GROUP_NORM:
+               {
+                   GGML_ASSERT(ne00 % 4 == 0);
+
+                   //float eps;
+                   //memcpy(&eps, dst->op_params, sizeof(float));
+
+                   const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+                   const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+                   int nth = 32; // SIMD width
+
+                   //while (nth < ne00/4 && nth < 1024) {
+                   //    nth *= 2;
+                   //}
+
+                   [encoder setComputePipelineState:ctx->pipeline_group_norm];
+                   [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                   [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                   [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                   [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                   [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                   [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+                   [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+                   [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+                   [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+                   [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+                   [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                   [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+               } break;
            case GGML_OP_NORM:
                {
                    float eps;
@@ -1764,6 +2202,65 @@ void ggml_metal_graph_compute(
 
                    [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
                } break;
+           case GGML_OP_UPSCALE:
+               {
+                   GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                   const int sf = dst->op_params[0];
+
+                   [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
+                   [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                   [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                   [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                   [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                   [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                   [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                   [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                   [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                   [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                   [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                   [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+                   [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+                   [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+                   [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+                   [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+                   [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+                   [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+                   [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+                   [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
+
+                   const int nth = MIN(1024, ne0);
+
+                   [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+               } break;
+           case GGML_OP_PAD:
+               {
+                   GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                   [encoder setComputePipelineState:ctx->pipeline_pad_f32];
+                   [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                   [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                   [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                   [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                   [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                   [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                   [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                   [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                   [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                   [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                   [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+                   [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+                   [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+                   [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+                   [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+                   [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+                   [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+                   [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+                   const int nth = MIN(1024, ne0);
+
+                   [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+               } break;
            case GGML_OP_ARGSORT:
                {
                    GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -1785,6 +2282,22 @@ void ggml_metal_graph_compute(
 
                    [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
                } break;
+           case GGML_OP_LEAKY_RELU:
+               {
+                   GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                   float slope;
+                   memcpy(&slope, dst->op_params, sizeof(float));
+
+                   [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
+                   [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                   [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                   [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+                   const int64_t n = ggml_nelements(dst);
+
+                   [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+               } break;
            case GGML_OP_DUP:
            case GGML_OP_CPY:
            case GGML_OP_CONT:
@@ -1813,7 +2326,7 @@ void ggml_metal_graph_compute(
                        {
                            switch (dstt) {
                                case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
-                               case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+                               case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
                                default: GGML_ASSERT(false && "not implemented");
                            };
                        } break;
@@ -1880,6 +2393,7 @@ void ggml_metal_graph_compute(
 
 // backend interface
 
+// default buffer
 static id<MTLDevice> g_backend_device = nil;
 static int g_backend_device_ref_count = 0;
 
@@ -1907,34 +2421,31 @@ static void ggml_backend_metal_free_device(void) {
 static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
     struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
 
-    return ctx->data;
+    return ctx->all_data;
 }
 
 static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
 
-    [ctx->metal release];
+    for (int i = 0; i < ctx->n_buffers; i++) {
+        [ctx->buffers[i].metal release];
+    }
     ggml_backend_metal_free_device();
 
-    free(ctx->data);
-    free(ctx);
+    if (ctx->owned) {
+        free(ctx->all_data);
+    }
 
-    UNUSED(buffer);
+    free(ctx);
 }
 
 static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
    memcpy((char *)tensor->data + offset, data, size);
 
    UNUSED(buffer);
 }
 
 static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
    memcpy(data, (const char *)tensor->data + offset, size);
 
    UNUSED(buffer);
@@ -1952,7 +2463,13 @@ static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer
    UNUSED(buffer);
 }
 
-static struct ggml_backend_buffer_i metal_backend_buffer_i = {
+static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+    memset(ctx->all_data, value, ctx->all_size);
+}
+
+static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
    /* .free_buffer     = */ ggml_backend_metal_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_metal_buffer_get_base,
    /* .init_tensor     = */ NULL,
@@ -1960,8 +2477,11 @@ static struct ggml_backend_buffer_i metal_backend_buffer_i = {
    /* .get_tensor      = */ ggml_backend_metal_buffer_get_tensor,
    /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
    /* .cpy_tensor_to   = */ ggml_backend_metal_buffer_cpy_tensor_to,
+   /* .clear           = */ ggml_backend_metal_buffer_clear,
 };
 
+// default buffer type
+
 static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
 
@@ -1972,13 +2492,46 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
        size_aligned += (size_page - (size_aligned % size_page));
    }
 
-   ctx->data = ggml_metal_host_malloc(size);
-   ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
+   id<MTLDevice> device = ggml_backend_metal_get_device();
+
+   ctx->all_data = ggml_metal_host_malloc(size_aligned);
+   ctx->all_size = size_aligned;
+   ctx->owned = true;
+   ctx->n_buffers = 1;
+
+   ctx->buffers[0].data = ctx->all_data;
+   ctx->buffers[0].size = size;
+   ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
                    length:size_aligned
                    options:MTLResourceStorageModeShared
                    deallocator:nil];
 
-   return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
+   if (ctx->buffers[0].metal == nil) {
+       GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+       free(ctx);
+       ggml_backend_metal_free_device();
+       return NULL;
+   }
+
+   GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
+
+#if TARGET_OS_OSX
+   GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
+       device.currentAllocatedSize / 1024.0 / 1024.0,
+       device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+
+   if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
+       GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
+   } else {
+       GGML_METAL_LOG_INFO("\n");
+   }
+#else
+   GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
+#endif
+
+   return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
 }
 
 static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
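
Note: newBufferWithBytesNoCopy requires page-aligned storage whose length is a multiple of the page size, which is why the allocation above now uses size_aligned (the requested size rounded up to _SC_PAGESIZE) for both the host allocation and the Metal buffer length, while the logical buffer keeps the original size. The rounding used by alloc_buffer, as a standalone C snippet:

    #include <unistd.h>
    #include <stddef.h>

    // round a byte count up to a whole number of pages
    static size_t round_up_to_page(size_t size) {
        const size_t page = (size_t) sysconf(_SC_PAGESIZE);
        if (size % page != 0) {
            size += page - (size % page);
        }
        return size;
    }
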
@@ -1989,7 +2542,13 @@ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_t
 static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
     return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
 
-    GGML_UNUSED(buft);
+    UNUSED(buft);
+}
+
+static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return true;
+
+    UNUSED(buft);
 }
 
 ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
@@ -1999,6 +2558,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
        /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
        /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
        /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
+       /* .is_host          = */ ggml_backend_metal_buffer_type_is_host,
    },
    /* .context = */ NULL,
 };
@@ -2006,6 +2566,87 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
    return &ggml_backend_buffer_type_metal;
 }
 
+// buffer from ptr
+
+ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
+    struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
+
+    ctx->all_data = data;
+    ctx->all_size = size;
+    ctx->owned = false;
+    ctx->n_buffers = 0;
+
+    const size_t size_page = sysconf(_SC_PAGESIZE);
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
+    }
+
+    id<MTLDevice> device = ggml_backend_metal_get_device();
+
+    // the buffer fits into the max buffer size allowed by the device
+    if (size_aligned <= device.maxBufferLength) {
+        ctx->buffers[ctx->n_buffers].data = data;
+        ctx->buffers[ctx->n_buffers].size = size;
+
+        ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+        if (ctx->buffers[ctx->n_buffers].metal == nil) {
+            GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+            return false;
+        }
+
+        GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
+
+        ++ctx->n_buffers;
+    } else {
+        // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+        // one of the views
+        const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
+        const size_t size_step = device.maxBufferLength - size_ovlp;
+        const size_t size_view = device.maxBufferLength;
+
+        for (size_t i = 0; i < size; i += size_step) {
+            const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+            ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
+            ctx->buffers[ctx->n_buffers].size = size_step_aligned;
+
+            ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+            if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
+                return false;
+            }
+
+            GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i);
+            if (i + size_step < size) {
+                GGML_METAL_LOG_INFO("\n");
+            }
+
+            ++ctx->n_buffers;
+        }
+    }
+
+#if TARGET_OS_OSX
+    GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
+        device.currentAllocatedSize / 1024.0 / 1024.0,
+        device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+
+    if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
+        GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
+    } else {
+        GGML_METAL_LOG_INFO("\n");
+    }
+#else
+    GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
+#endif
+
+    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
+}
+
+// backend
+
 static const char * ggml_backend_metal_name(ggml_backend_t backend) {
    return "Metal";
 
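
Note: when a mapped model exceeds device.maxBufferLength, ggml_backend_metal_buffer_from_ptr splits it into overlapping views: each view has the maximum length, consecutive views advance by size_step = maxBufferLength - size_ovlp, and size_ovlp is max_size (the largest tensor) rounded up by roughly two pages. Any tensor of at most size_ovlp bytes therefore lies entirely inside the view starting at the last step boundary at or before it. A small C check of that invariant (illustrative, with byte-granular offsets):

    #include <assert.h>
    #include <stddef.h>

    static void check_views(size_t size, size_t size_view, size_t size_ovlp) {
        const size_t size_step = size_view - size_ovlp;
        for (size_t t = 0; t + size_ovlp <= size; ++t) {
            const size_t view_start = (t / size_step) * size_step; // view containing offset t
            assert(t + size_ovlp <= view_start + size_view);       // tensor end stays in view
        }
    }

    // e.g. check_views(10u << 20, 4u << 20, 1u << 20) walks a 10 MiB buffer with
    // 4 MiB views overlapping by 1 MiB and confirms containment at every offset
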
@@ -2018,10 +2659,6 @@ static void ggml_backend_metal_free(ggml_backend_t backend) {
    free(backend);
 }
 
-static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
-    UNUSED(backend);
-}
-
 static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
    return ggml_backend_metal_buffer_type();
 
@@ -2048,25 +2685,15 @@ static struct ggml_backend_i metal_backend_i = {
    /* .get_tensor_async      = */ NULL,
    /* .cpy_tensor_from_async = */ NULL,
    /* .cpy_tensor_to_async   = */ NULL,
-   /* .synchronize           = */ ggml_backend_metal_synchronize,
-   /* .graph_plan_create     = */ NULL, // the metal implementation does not require creating graph plans atm
+   /* .synchronize           = */ NULL,
+   /* .graph_plan_create     = */ NULL,
    /* .graph_plan_free       = */ NULL,
    /* .graph_plan_compute    = */ NULL,
    /* .graph_compute         = */ ggml_backend_metal_graph_compute,
    /* .supports_op           = */ ggml_backend_metal_supports_op,
 };
 
-// TODO: make a common log callback for all backends in ggml-backend
-static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
-    fprintf(stderr, "%s", msg);
-
-    UNUSED(level);
-    UNUSED(user_data);
-}
-
 ggml_backend_t ggml_backend_metal_init(void) {
-    ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
-
    struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
 
    if (ctx == NULL) {