llama_cpp 0.10.0 → 0.10.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -66,9 +66,11 @@ struct ggml_metal_context {
66
66
  GGML_METAL_DECL_KERNEL(div_row);
67
67
  GGML_METAL_DECL_KERNEL(scale);
68
68
  GGML_METAL_DECL_KERNEL(scale_4);
69
- GGML_METAL_DECL_KERNEL(silu);
69
+ GGML_METAL_DECL_KERNEL(tanh);
70
70
  GGML_METAL_DECL_KERNEL(relu);
71
71
  GGML_METAL_DECL_KERNEL(gelu);
72
+ GGML_METAL_DECL_KERNEL(gelu_quick);
73
+ GGML_METAL_DECL_KERNEL(silu);
72
74
  GGML_METAL_DECL_KERNEL(soft_max);
73
75
  GGML_METAL_DECL_KERNEL(soft_max_4);
74
76
  GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -86,6 +88,7 @@ struct ggml_metal_context {
86
88
  GGML_METAL_DECL_KERNEL(get_rows_q5_K);
87
89
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
88
90
  GGML_METAL_DECL_KERNEL(rms_norm);
91
+ GGML_METAL_DECL_KERNEL(group_norm);
89
92
  GGML_METAL_DECL_KERNEL(norm);
90
93
  GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
91
94
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@@ -102,6 +105,21 @@ struct ggml_metal_context {
102
105
  GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
103
106
  GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
104
107
  GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
108
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
109
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
110
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
111
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
112
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
113
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
114
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
115
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
116
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
117
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
118
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
119
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
120
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
121
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
122
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
105
123
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
106
124
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
107
125
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -130,8 +148,11 @@ struct ggml_metal_context {
130
148
  GGML_METAL_DECL_KERNEL(rope_f16);
131
149
  GGML_METAL_DECL_KERNEL(alibi_f32);
132
150
  GGML_METAL_DECL_KERNEL(im2col_f16);
151
+ GGML_METAL_DECL_KERNEL(upscale_f32);
152
+ GGML_METAL_DECL_KERNEL(pad_f32);
133
153
  GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
134
154
  GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
155
+ GGML_METAL_DECL_KERNEL(leaky_relu_f32);
135
156
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
136
157
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
137
158
  GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
@@ -140,6 +161,7 @@ struct ggml_metal_context {
140
161
  //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
141
162
  //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
142
163
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
164
+ GGML_METAL_DECL_KERNEL(cpy_f16_f32);
143
165
  GGML_METAL_DECL_KERNEL(concat);
144
166
  GGML_METAL_DECL_KERNEL(sqr);
145
167
  GGML_METAL_DECL_KERNEL(sum_rows);
@@ -177,6 +199,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
177
199
  ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
178
200
  } else {
179
201
  char* buffer2 = malloc(len+1);
202
+ va_end(args);
203
+ va_start(args, format);
180
204
  vsnprintf(buffer2, len+1, format, args);
181
205
  buffer2[len] = 0;
182
206
  ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
@@ -316,9 +340,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
316
340
  GGML_METAL_ADD_KERNEL(div_row);
317
341
  GGML_METAL_ADD_KERNEL(scale);
318
342
  GGML_METAL_ADD_KERNEL(scale_4);
319
- GGML_METAL_ADD_KERNEL(silu);
343
+ GGML_METAL_ADD_KERNEL(tanh);
320
344
  GGML_METAL_ADD_KERNEL(relu);
321
345
  GGML_METAL_ADD_KERNEL(gelu);
346
+ GGML_METAL_ADD_KERNEL(gelu_quick);
347
+ GGML_METAL_ADD_KERNEL(silu);
322
348
  GGML_METAL_ADD_KERNEL(soft_max);
323
349
  GGML_METAL_ADD_KERNEL(soft_max_4);
324
350
  GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -336,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
336
362
  GGML_METAL_ADD_KERNEL(get_rows_q5_K);
337
363
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
338
364
  GGML_METAL_ADD_KERNEL(rms_norm);
365
+ GGML_METAL_ADD_KERNEL(group_norm);
339
366
  GGML_METAL_ADD_KERNEL(norm);
340
367
  GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
341
368
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@@ -352,6 +379,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
352
379
  GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
353
380
  GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
354
381
  GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
382
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
383
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
384
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
385
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
386
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
387
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
388
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
389
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
390
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
391
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
392
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
393
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
394
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
395
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
396
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
355
397
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
356
398
  GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
357
399
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -382,8 +424,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
382
424
  GGML_METAL_ADD_KERNEL(rope_f16);
383
425
  GGML_METAL_ADD_KERNEL(alibi_f32);
384
426
  GGML_METAL_ADD_KERNEL(im2col_f16);
427
+ GGML_METAL_ADD_KERNEL(upscale_f32);
428
+ GGML_METAL_ADD_KERNEL(pad_f32);
385
429
  GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
386
430
  GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
431
+ GGML_METAL_ADD_KERNEL(leaky_relu_f32);
387
432
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
388
433
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
389
434
  GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
@@ -392,6 +437,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
392
437
  //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
393
438
  //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
394
439
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
440
+ GGML_METAL_ADD_KERNEL(cpy_f16_f32);
395
441
  GGML_METAL_ADD_KERNEL(concat);
396
442
  GGML_METAL_ADD_KERNEL(sqr);
397
443
  GGML_METAL_ADD_KERNEL(sum_rows);
@@ -416,9 +462,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
416
462
  GGML_METAL_DEL_KERNEL(div_row);
417
463
  GGML_METAL_DEL_KERNEL(scale);
418
464
  GGML_METAL_DEL_KERNEL(scale_4);
419
- GGML_METAL_DEL_KERNEL(silu);
465
+ GGML_METAL_DEL_KERNEL(tanh);
420
466
  GGML_METAL_DEL_KERNEL(relu);
421
467
  GGML_METAL_DEL_KERNEL(gelu);
468
+ GGML_METAL_DEL_KERNEL(gelu_quick);
469
+ GGML_METAL_DEL_KERNEL(silu);
422
470
  GGML_METAL_DEL_KERNEL(soft_max);
423
471
  GGML_METAL_DEL_KERNEL(soft_max_4);
424
472
  GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -436,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
436
484
  GGML_METAL_DEL_KERNEL(get_rows_q5_K);
437
485
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
438
486
  GGML_METAL_DEL_KERNEL(rms_norm);
487
+ GGML_METAL_DEL_KERNEL(group_norm);
439
488
  GGML_METAL_DEL_KERNEL(norm);
440
489
  GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
441
490
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@@ -452,6 +501,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
452
501
  GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
453
502
  GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
454
503
  GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
504
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
505
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
506
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
507
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
508
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
509
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
510
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
511
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
512
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
513
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
514
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
515
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
516
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
517
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
518
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
455
519
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
456
520
  GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
457
521
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -482,8 +546,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
482
546
  GGML_METAL_DEL_KERNEL(rope_f16);
483
547
  GGML_METAL_DEL_KERNEL(alibi_f32);
484
548
  GGML_METAL_DEL_KERNEL(im2col_f16);
549
+ GGML_METAL_DEL_KERNEL(upscale_f32);
550
+ GGML_METAL_DEL_KERNEL(pad_f32);
485
551
  GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
486
552
  GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
553
+ GGML_METAL_DEL_KERNEL(leaky_relu_f32);
487
554
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
488
555
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
489
556
  GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
@@ -492,6 +559,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
492
559
  //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
493
560
  //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
494
561
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
562
+ GGML_METAL_DEL_KERNEL(cpy_f16_f32);
495
563
  GGML_METAL_DEL_KERNEL(concat);
496
564
  GGML_METAL_DEL_KERNEL(sqr);
497
565
  GGML_METAL_DEL_KERNEL(sum_rows);
@@ -793,9 +861,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
793
861
  switch (op->op) {
794
862
  case GGML_OP_UNARY:
795
863
  switch (ggml_get_unary_op(op)) {
796
- case GGML_UNARY_OP_SILU:
864
+ case GGML_UNARY_OP_TANH:
797
865
  case GGML_UNARY_OP_RELU:
798
866
  case GGML_UNARY_OP_GELU:
867
+ case GGML_UNARY_OP_GELU_QUICK:
868
+ case GGML_UNARY_OP_SILU:
799
869
  return true;
800
870
  default:
801
871
  return false;
@@ -807,6 +877,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
807
877
  case GGML_OP_PERMUTE:
808
878
  case GGML_OP_CONCAT:
809
879
  case GGML_OP_ADD:
880
+ case GGML_OP_ACC:
810
881
  case GGML_OP_MUL:
811
882
  case GGML_OP_DIV:
812
883
  case GGML_OP_SCALE:
@@ -814,21 +885,50 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
814
885
  case GGML_OP_SUM_ROWS:
815
886
  case GGML_OP_SOFT_MAX:
816
887
  case GGML_OP_RMS_NORM:
888
+ case GGML_OP_GROUP_NORM:
817
889
  case GGML_OP_NORM:
818
890
  case GGML_OP_ALIBI:
819
891
  case GGML_OP_ROPE:
820
892
  case GGML_OP_IM2COL:
893
+ case GGML_OP_UPSCALE:
894
+ case GGML_OP_PAD:
821
895
  case GGML_OP_ARGSORT:
822
- case GGML_OP_DUP:
823
- case GGML_OP_CPY:
824
- case GGML_OP_CONT:
896
+ case GGML_OP_LEAKY_RELU:
825
897
  case GGML_OP_MUL_MAT:
826
898
  case GGML_OP_MUL_MAT_ID:
827
899
  return true;
900
+ case GGML_OP_CPY:
901
+ case GGML_OP_DUP:
902
+ case GGML_OP_CONT:
903
+ {
904
+ switch (op->src[0]->type) {
905
+ case GGML_TYPE_F32:
906
+ switch (op->type) {
907
+ case GGML_TYPE_F16:
908
+ case GGML_TYPE_F32:
909
+ case GGML_TYPE_Q8_0:
910
+ case GGML_TYPE_Q4_0:
911
+ case GGML_TYPE_Q4_1:
912
+ return true;
913
+ default:
914
+ return false;
915
+ }
916
+ case GGML_TYPE_F16:
917
+ switch (op->type) {
918
+ case GGML_TYPE_F16:
919
+ case GGML_TYPE_F32:
920
+ return true;
921
+ default:
922
+ return false;
923
+ }
924
+ default:
925
+ return false;
926
+ };
927
+ }
828
928
  case GGML_OP_DIAG_MASK_INF:
829
929
  case GGML_OP_GET_ROWS:
830
930
  {
831
- return op->ne[0] % 4 == 0;
931
+ return op->ne[3] == 1;
832
932
  }
833
933
  default:
834
934
  return false;
@@ -904,7 +1004,10 @@ void ggml_metal_graph_compute(
904
1004
  } break;
905
1005
  }
906
1006
 
907
- GGML_ASSERT(ggml_metal_supports_op(dst));
1007
+ if (!ggml_metal_supports_op(dst)) {
1008
+ GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
1009
+ GGML_ASSERT(!"unsupported op");
1010
+ }
908
1011
 
909
1012
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
910
1013
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
@@ -1001,34 +1104,39 @@ void ggml_metal_graph_compute(
1001
1104
  case GGML_OP_MUL:
1002
1105
  case GGML_OP_DIV:
1003
1106
  {
1004
- GGML_ASSERT(ggml_is_contiguous(src0));
1005
- GGML_ASSERT(ggml_is_contiguous(src1));
1107
+ const size_t offs = 0;
1006
1108
 
1007
1109
  bool bcast_row = false;
1008
1110
 
1009
1111
  int64_t nb = ne00;
1010
1112
 
1011
- if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
1113
+ id<MTLComputePipelineState> pipeline = nil;
1114
+
1115
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1116
+ GGML_ASSERT(ggml_is_contiguous(src0));
1117
+
1012
1118
  // src1 is a row
1013
1119
  GGML_ASSERT(ne11 == 1);
1014
1120
 
1015
1121
  nb = ne00 / 4;
1016
1122
  switch (dst->op) {
1017
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
1018
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
1019
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
1123
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
1124
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
1125
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
1020
1126
  default: GGML_ASSERT(false);
1021
1127
  }
1022
1128
 
1023
1129
  bcast_row = true;
1024
1130
  } else {
1025
1131
  switch (dst->op) {
1026
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
1027
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
1028
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
1132
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
1133
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
1134
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
1029
1135
  default: GGML_ASSERT(false);
1030
1136
  }
1031
1137
  }
1138
+
1139
+ [encoder setComputePipelineState:pipeline];
1032
1140
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1033
1141
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1034
1142
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1056,18 +1164,99 @@ void ggml_metal_graph_compute(
1056
1164
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1057
1165
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1058
1166
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1059
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1167
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1168
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1060
1169
 
1061
1170
  if (bcast_row) {
1062
1171
  const int64_t n = ggml_nelements(dst)/4;
1063
1172
 
1064
1173
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1065
1174
  } else {
1066
- const int nth = MIN(1024, ne0);
1175
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1067
1176
 
1068
1177
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1069
1178
  }
1070
1179
  } break;
1180
+ case GGML_OP_ACC:
1181
+ {
1182
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
1183
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1184
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
1185
+
1186
+ GGML_ASSERT(ggml_is_contiguous(src0));
1187
+ GGML_ASSERT(ggml_is_contiguous(src1));
1188
+
1189
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
1190
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
1191
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
1192
+ const size_t offs = ((int32_t *) dst->op_params)[3];
1193
+
1194
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1195
+
1196
+ if (!inplace) {
1197
+ // run a separate kernel to cpy src->dst
1198
+ // not sure how to avoid this
1199
+ // TODO: make a simpler cpy_bytes kernel
1200
+
1201
+ const int nth = MIN(1024, ne00);
1202
+
1203
+ [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
1204
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1205
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1206
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1207
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1208
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1209
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1210
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1211
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1212
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1213
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1214
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1215
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1216
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1217
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1218
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1219
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1220
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1221
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1222
+
1223
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1224
+ }
1225
+
1226
+ [encoder setComputePipelineState:ctx->pipeline_add];
1227
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1228
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1229
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1230
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1231
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1232
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1233
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1234
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1235
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1236
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1237
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1238
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1239
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1240
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1241
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1242
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1243
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1244
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1245
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1246
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1247
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1248
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1249
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1250
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1251
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1252
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1253
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1254
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1255
+
1256
+ const int nth = MIN(1024, ne0);
1257
+
1258
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1259
+ } break;
1071
1260
  case GGML_OP_SCALE:
1072
1261
  {
1073
1262
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1091,16 +1280,15 @@ void ggml_metal_graph_compute(
1091
1280
  } break;
1092
1281
  case GGML_OP_UNARY:
1093
1282
  switch (ggml_get_unary_op(gf->nodes[i])) {
1094
- case GGML_UNARY_OP_SILU:
1283
+ case GGML_UNARY_OP_TANH:
1095
1284
  {
1096
- [encoder setComputePipelineState:ctx->pipeline_silu];
1285
+ [encoder setComputePipelineState:ctx->pipeline_tanh];
1097
1286
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1098
1287
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1099
1288
 
1100
1289
  const int64_t n = ggml_nelements(dst);
1101
- GGML_ASSERT(n % 4 == 0);
1102
1290
 
1103
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1291
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1104
1292
  } break;
1105
1293
  case GGML_UNARY_OP_RELU:
1106
1294
  {
@@ -1121,6 +1309,28 @@ void ggml_metal_graph_compute(
1121
1309
  const int64_t n = ggml_nelements(dst);
1122
1310
  GGML_ASSERT(n % 4 == 0);
1123
1311
 
1312
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1313
+ } break;
1314
+ case GGML_UNARY_OP_GELU_QUICK:
1315
+ {
1316
+ [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
1317
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1318
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1319
+
1320
+ const int64_t n = ggml_nelements(dst);
1321
+ GGML_ASSERT(n % 4 == 0);
1322
+
1323
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1324
+ } break;
1325
+ case GGML_UNARY_OP_SILU:
1326
+ {
1327
+ [encoder setComputePipelineState:ctx->pipeline_silu];
1328
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1329
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1330
+
1331
+ const int64_t n = ggml_nelements(dst);
1332
+ GGML_ASSERT(n % 4 == 0);
1333
+
1124
1334
  [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1125
1335
  } break;
1126
1336
  default:
@@ -1193,7 +1403,11 @@ void ggml_metal_graph_compute(
1193
1403
  const float scale = ((float *) dst->op_params)[0];
1194
1404
 
1195
1405
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1196
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1406
+ if (id_src1) {
1407
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1408
+ } else {
1409
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1410
+ }
1197
1411
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1198
1412
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1199
1413
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
@@ -1444,7 +1658,7 @@ void ggml_metal_graph_compute(
1444
1658
  else if (src0t == GGML_TYPE_Q6_K) {
1445
1659
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1446
1660
  } else {
1447
- int64_t ny = (ne11 + nrows - 1)/nrows;
1661
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
1448
1662
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1449
1663
  }
1450
1664
  }
@@ -1456,7 +1670,7 @@ void ggml_metal_graph_compute(
1456
1670
 
1457
1671
  GGML_ASSERT(src0t == GGML_TYPE_I32);
1458
1672
 
1459
- const int n_as = ne00;
1673
+ const int n_as = ((int32_t *) dst->op_params)[1];
1460
1674
 
1461
1675
  // TODO: make this more general
1462
1676
  GGML_ASSERT(n_as <= 8);
@@ -1488,14 +1702,22 @@ void ggml_metal_graph_compute(
1488
1702
 
1489
1703
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1490
1704
  // to the matrix-vector kernel
1491
- int ne11_mm_min = 0;
1705
+ int ne11_mm_min = 1;
1492
1706
 
1493
1707
  const int idx = ((int32_t *) dst->op_params)[0];
1494
1708
 
1709
+ // batch size
1710
+ GGML_ASSERT(ne01 == ne11);
1711
+
1712
+ const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
1713
+
1495
1714
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1496
1715
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1497
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1498
- ne11 > ne11_mm_min) {
1716
+ // !!!
1717
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1718
+ // indirect matrix multiplication
1719
+ // !!!
1720
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
1499
1721
  switch (src2->type) {
1500
1722
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1501
1723
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@@ -1514,19 +1736,22 @@ void ggml_metal_graph_compute(
1514
1736
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1515
1737
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1516
1738
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1517
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
1518
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
1519
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
1520
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
1521
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1522
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1523
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1524
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1525
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1526
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1527
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1528
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1529
- [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
1739
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1740
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1741
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1742
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1743
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1744
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1745
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1746
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1747
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1748
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1749
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1750
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
1751
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1752
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1753
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1754
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1530
1755
  // TODO: how to make this an array? read Metal docs
1531
1756
  for (int j = 0; j < n_as; ++j) {
1532
1757
  struct ggml_tensor * src_cur = dst->src[2 + j];
@@ -1534,11 +1759,157 @@ void ggml_metal_graph_compute(
1534
1759
  size_t offs_src_cur = 0;
1535
1760
  id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1536
1761
 
1537
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
1762
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1538
1763
  }
1539
1764
 
1540
1765
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1541
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1766
+
1767
+ // TODO: processing one row at a time (ne11 -> 1) is not efficient
1768
+ [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1769
+ } else {
1770
+ int nth0 = 32;
1771
+ int nth1 = 1;
1772
+ int nrows = 1;
1773
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1774
+
1775
+ // use custom matrix x vector kernel
1776
+ switch (src2t) {
1777
+ case GGML_TYPE_F32:
1778
+ {
1779
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1780
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
1781
+ } break;
1782
+ case GGML_TYPE_F16:
1783
+ {
1784
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1785
+ nth0 = 32;
1786
+ nth1 = 1;
1787
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
1788
+ } break;
1789
+ case GGML_TYPE_Q4_0:
1790
+ {
1791
+ nth0 = 8;
1792
+ nth1 = 8;
1793
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
1794
+ } break;
1795
+ case GGML_TYPE_Q4_1:
1796
+ {
1797
+ nth0 = 8;
1798
+ nth1 = 8;
1799
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
1800
+ } break;
1801
+ case GGML_TYPE_Q5_0:
1802
+ {
1803
+ nth0 = 8;
1804
+ nth1 = 8;
1805
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
1806
+ } break;
1807
+ case GGML_TYPE_Q5_1:
1808
+ {
1809
+ nth0 = 8;
1810
+ nth1 = 8;
1811
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
1812
+ } break;
1813
+ case GGML_TYPE_Q8_0:
1814
+ {
1815
+ nth0 = 8;
1816
+ nth1 = 8;
1817
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
1818
+ } break;
1819
+ case GGML_TYPE_Q2_K:
1820
+ {
1821
+ nth0 = 2;
1822
+ nth1 = 32;
1823
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
1824
+ } break;
1825
+ case GGML_TYPE_Q3_K:
1826
+ {
1827
+ nth0 = 2;
1828
+ nth1 = 32;
1829
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
1830
+ } break;
1831
+ case GGML_TYPE_Q4_K:
1832
+ {
1833
+ nth0 = 4; //1;
1834
+ nth1 = 8; //32;
1835
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
1836
+ } break;
1837
+ case GGML_TYPE_Q5_K:
1838
+ {
1839
+ nth0 = 2;
1840
+ nth1 = 32;
1841
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
1842
+ } break;
1843
+ case GGML_TYPE_Q6_K:
1844
+ {
1845
+ nth0 = 2;
1846
+ nth1 = 32;
1847
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
1848
+ } break;
1849
+ default:
1850
+ {
1851
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1852
+ GGML_ASSERT(false && "not implemented");
1853
+ }
1854
+ };
1855
+
1856
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1857
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1858
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1859
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1860
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1861
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1862
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1863
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1864
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1865
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1866
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1867
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1868
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1869
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1870
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1871
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1872
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1873
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1874
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1875
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1876
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1877
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1878
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1879
+ // TODO: how to make this an array? read Metal docs
1880
+ for (int j = 0; j < n_as; ++j) {
1881
+ struct ggml_tensor * src_cur = dst->src[2 + j];
1882
+
1883
+ size_t offs_src_cur = 0;
1884
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1885
+
1886
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1887
+ }
1888
+
1889
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1890
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1891
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1892
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1893
+ }
1894
+ else if (src2t == GGML_TYPE_Q4_K) {
1895
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1896
+ }
1897
+ else if (src2t == GGML_TYPE_Q3_K) {
1898
+ #ifdef GGML_QKK_64
1899
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1900
+ #else
1901
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1902
+ #endif
1903
+ }
1904
+ else if (src2t == GGML_TYPE_Q5_K) {
1905
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1906
+ }
1907
+ else if (src2t == GGML_TYPE_Q6_K) {
1908
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1909
+ } else {
1910
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
1911
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1912
+ }
1542
1913
  }
1543
1914
  } break;
1544
1915
  case GGML_OP_GET_ROWS:
@@ -1559,16 +1930,19 @@ void ggml_metal_graph_compute(
1559
1930
  default: GGML_ASSERT(false && "not implemented");
1560
1931
  }
1561
1932
 
1562
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1563
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1564
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1933
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1934
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1935
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1565
1936
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1566
1937
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1567
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
1568
-
1569
- const int64_t n = ggml_nelements(src1);
1570
-
1571
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1938
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
1939
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
1940
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
1941
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
1942
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
1943
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
1944
+
1945
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1572
1946
  } break;
1573
1947
  case GGML_OP_RMS_NORM:
1574
1948
  {
@@ -1595,6 +1969,38 @@ void ggml_metal_graph_compute(
1595
1969
 
1596
1970
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1597
1971
  } break;
1972
+ case GGML_OP_GROUP_NORM:
1973
+ {
1974
+ GGML_ASSERT(ne00 % 4 == 0);
1975
+
1976
+ //float eps;
1977
+ //memcpy(&eps, dst->op_params, sizeof(float));
1978
+
1979
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
1980
+
1981
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
1982
+
1983
+ int nth = 32; // SIMD width
1984
+
1985
+ //while (nth < ne00/4 && nth < 1024) {
1986
+ // nth *= 2;
1987
+ //}
1988
+
1989
+ [encoder setComputePipelineState:ctx->pipeline_group_norm];
1990
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1991
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1992
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1993
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1994
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1995
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
1996
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
1997
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
1998
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
1999
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
2000
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2001
+
2002
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2003
+ } break;
1598
2004
  case GGML_OP_NORM:
1599
2005
  {
1600
2006
  float eps;
@@ -1764,6 +2170,65 @@ void ggml_metal_graph_compute(
1764
2170
 
1765
2171
  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1766
2172
  } break;
2173
+ case GGML_OP_UPSCALE:
2174
+ {
2175
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2176
+
2177
+ const int sf = dst->op_params[0];
2178
+
2179
+ [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
2180
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2181
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2182
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2183
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2184
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2185
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2186
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2187
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2188
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2189
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2190
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2191
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2192
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2193
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2194
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2195
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2196
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2197
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2198
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2199
+
2200
+ const int nth = MIN(1024, ne0);
2201
+
2202
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2203
+ } break;
2204
+ case GGML_OP_PAD:
2205
+ {
2206
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2207
+
2208
+ [encoder setComputePipelineState:ctx->pipeline_pad_f32];
2209
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2210
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2211
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2212
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2213
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2214
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2215
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2216
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2217
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2218
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2219
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2220
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2221
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2222
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2223
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2224
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2225
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2226
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2227
+
2228
+ const int nth = MIN(1024, ne0);
2229
+
2230
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2231
+ } break;
1767
2232
  case GGML_OP_ARGSORT:
1768
2233
  {
1769
2234
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -1785,6 +2250,22 @@ void ggml_metal_graph_compute(
1785
2250
 
1786
2251
  [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
1787
2252
  } break;
2253
+ case GGML_OP_LEAKY_RELU:
2254
+ {
2255
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2256
+
2257
+ float slope;
2258
+ memcpy(&slope, dst->op_params, sizeof(float));
2259
+
2260
+ [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
2261
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2262
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2263
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
2264
+
2265
+ const int64_t n = ggml_nelements(dst);
2266
+
2267
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2268
+ } break;
1788
2269
  case GGML_OP_DUP:
1789
2270
  case GGML_OP_CPY:
1790
2271
  case GGML_OP_CONT:
@@ -1813,7 +2294,7 @@ void ggml_metal_graph_compute(
1813
2294
  {
1814
2295
  switch (dstt) {
1815
2296
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
1816
- case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
2297
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
1817
2298
  default: GGML_ASSERT(false && "not implemented");
1818
2299
  };
1819
2300
  } break;