llama_cpp 0.10.0 → 0.10.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -66,9 +66,11 @@ struct ggml_metal_context {
66
66
  GGML_METAL_DECL_KERNEL(div_row);
67
67
  GGML_METAL_DECL_KERNEL(scale);
68
68
  GGML_METAL_DECL_KERNEL(scale_4);
69
- GGML_METAL_DECL_KERNEL(silu);
69
+ GGML_METAL_DECL_KERNEL(tanh);
70
70
  GGML_METAL_DECL_KERNEL(relu);
71
71
  GGML_METAL_DECL_KERNEL(gelu);
72
+ GGML_METAL_DECL_KERNEL(gelu_quick);
73
+ GGML_METAL_DECL_KERNEL(silu);
72
74
  GGML_METAL_DECL_KERNEL(soft_max);
73
75
  GGML_METAL_DECL_KERNEL(soft_max_4);
74
76
  GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -86,6 +88,7 @@ struct ggml_metal_context {
86
88
  GGML_METAL_DECL_KERNEL(get_rows_q5_K);
87
89
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
88
90
  GGML_METAL_DECL_KERNEL(rms_norm);
91
+ GGML_METAL_DECL_KERNEL(group_norm);
89
92
  GGML_METAL_DECL_KERNEL(norm);
90
93
  GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
91
94
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@@ -102,6 +105,21 @@ struct ggml_metal_context {
102
105
  GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
103
106
  GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
104
107
  GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
108
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
109
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
110
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
111
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
112
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
113
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
114
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
115
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
116
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
117
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
118
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
119
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
120
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
121
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
122
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
105
123
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
106
124
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
107
125
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -130,8 +148,11 @@ struct ggml_metal_context {
130
148
  GGML_METAL_DECL_KERNEL(rope_f16);
131
149
  GGML_METAL_DECL_KERNEL(alibi_f32);
132
150
  GGML_METAL_DECL_KERNEL(im2col_f16);
151
+ GGML_METAL_DECL_KERNEL(upscale_f32);
152
+ GGML_METAL_DECL_KERNEL(pad_f32);
133
153
  GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
134
154
  GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
155
+ GGML_METAL_DECL_KERNEL(leaky_relu_f32);
135
156
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
136
157
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
137
158
  GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
@@ -140,6 +161,7 @@ struct ggml_metal_context {
140
161
  //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
141
162
  //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
142
163
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
164
+ GGML_METAL_DECL_KERNEL(cpy_f16_f32);
143
165
  GGML_METAL_DECL_KERNEL(concat);
144
166
  GGML_METAL_DECL_KERNEL(sqr);
145
167
  GGML_METAL_DECL_KERNEL(sum_rows);
@@ -177,6 +199,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
177
199
  ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
178
200
  } else {
179
201
  char* buffer2 = malloc(len+1);
202
+ va_end(args);
203
+ va_start(args, format);
180
204
  vsnprintf(buffer2, len+1, format, args);
181
205
  buffer2[len] = 0;
182
206
  ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
@@ -316,9 +340,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
316
340
  GGML_METAL_ADD_KERNEL(div_row);
317
341
  GGML_METAL_ADD_KERNEL(scale);
318
342
  GGML_METAL_ADD_KERNEL(scale_4);
319
- GGML_METAL_ADD_KERNEL(silu);
343
+ GGML_METAL_ADD_KERNEL(tanh);
320
344
  GGML_METAL_ADD_KERNEL(relu);
321
345
  GGML_METAL_ADD_KERNEL(gelu);
346
+ GGML_METAL_ADD_KERNEL(gelu_quick);
347
+ GGML_METAL_ADD_KERNEL(silu);
322
348
  GGML_METAL_ADD_KERNEL(soft_max);
323
349
  GGML_METAL_ADD_KERNEL(soft_max_4);
324
350
  GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -336,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
336
362
  GGML_METAL_ADD_KERNEL(get_rows_q5_K);
337
363
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
338
364
  GGML_METAL_ADD_KERNEL(rms_norm);
365
+ GGML_METAL_ADD_KERNEL(group_norm);
339
366
  GGML_METAL_ADD_KERNEL(norm);
340
367
  GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
341
368
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@@ -352,6 +379,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
352
379
  GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
353
380
  GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
354
381
  GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
382
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
383
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
384
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
385
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
386
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
387
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
388
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
389
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
390
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
391
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
392
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
393
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
394
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
395
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
396
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
355
397
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
356
398
  GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
357
399
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -382,8 +424,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
382
424
  GGML_METAL_ADD_KERNEL(rope_f16);
383
425
  GGML_METAL_ADD_KERNEL(alibi_f32);
384
426
  GGML_METAL_ADD_KERNEL(im2col_f16);
427
+ GGML_METAL_ADD_KERNEL(upscale_f32);
428
+ GGML_METAL_ADD_KERNEL(pad_f32);
385
429
  GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
386
430
  GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
431
+ GGML_METAL_ADD_KERNEL(leaky_relu_f32);
387
432
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
388
433
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
389
434
  GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
@@ -392,6 +437,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
392
437
  //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
393
438
  //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
394
439
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
440
+ GGML_METAL_ADD_KERNEL(cpy_f16_f32);
395
441
  GGML_METAL_ADD_KERNEL(concat);
396
442
  GGML_METAL_ADD_KERNEL(sqr);
397
443
  GGML_METAL_ADD_KERNEL(sum_rows);
@@ -416,9 +462,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
416
462
  GGML_METAL_DEL_KERNEL(div_row);
417
463
  GGML_METAL_DEL_KERNEL(scale);
418
464
  GGML_METAL_DEL_KERNEL(scale_4);
419
- GGML_METAL_DEL_KERNEL(silu);
465
+ GGML_METAL_DEL_KERNEL(tanh);
420
466
  GGML_METAL_DEL_KERNEL(relu);
421
467
  GGML_METAL_DEL_KERNEL(gelu);
468
+ GGML_METAL_DEL_KERNEL(gelu_quick);
469
+ GGML_METAL_DEL_KERNEL(silu);
422
470
  GGML_METAL_DEL_KERNEL(soft_max);
423
471
  GGML_METAL_DEL_KERNEL(soft_max_4);
424
472
  GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -436,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
436
484
  GGML_METAL_DEL_KERNEL(get_rows_q5_K);
437
485
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
438
486
  GGML_METAL_DEL_KERNEL(rms_norm);
487
+ GGML_METAL_DEL_KERNEL(group_norm);
439
488
  GGML_METAL_DEL_KERNEL(norm);
440
489
  GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
441
490
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@@ -452,6 +501,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
452
501
  GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
453
502
  GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
454
503
  GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
504
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
505
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
506
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
507
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
508
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
509
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
510
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
511
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
512
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
513
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
514
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
515
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
516
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
517
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
518
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
455
519
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
456
520
  GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
457
521
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -482,8 +546,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
482
546
  GGML_METAL_DEL_KERNEL(rope_f16);
483
547
  GGML_METAL_DEL_KERNEL(alibi_f32);
484
548
  GGML_METAL_DEL_KERNEL(im2col_f16);
549
+ GGML_METAL_DEL_KERNEL(upscale_f32);
550
+ GGML_METAL_DEL_KERNEL(pad_f32);
485
551
  GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
486
552
  GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
553
+ GGML_METAL_DEL_KERNEL(leaky_relu_f32);
487
554
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
488
555
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
489
556
  GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
@@ -492,6 +559,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
492
559
  //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
493
560
  //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
494
561
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
562
+ GGML_METAL_DEL_KERNEL(cpy_f16_f32);
495
563
  GGML_METAL_DEL_KERNEL(concat);
496
564
  GGML_METAL_DEL_KERNEL(sqr);
497
565
  GGML_METAL_DEL_KERNEL(sum_rows);
@@ -793,9 +861,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
793
861
  switch (op->op) {
794
862
  case GGML_OP_UNARY:
795
863
  switch (ggml_get_unary_op(op)) {
796
- case GGML_UNARY_OP_SILU:
864
+ case GGML_UNARY_OP_TANH:
797
865
  case GGML_UNARY_OP_RELU:
798
866
  case GGML_UNARY_OP_GELU:
867
+ case GGML_UNARY_OP_GELU_QUICK:
868
+ case GGML_UNARY_OP_SILU:
799
869
  return true;
800
870
  default:
801
871
  return false;
@@ -807,6 +877,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
807
877
  case GGML_OP_PERMUTE:
808
878
  case GGML_OP_CONCAT:
809
879
  case GGML_OP_ADD:
880
+ case GGML_OP_ACC:
810
881
  case GGML_OP_MUL:
811
882
  case GGML_OP_DIV:
812
883
  case GGML_OP_SCALE:
@@ -814,21 +885,50 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
814
885
  case GGML_OP_SUM_ROWS:
815
886
  case GGML_OP_SOFT_MAX:
816
887
  case GGML_OP_RMS_NORM:
888
+ case GGML_OP_GROUP_NORM:
817
889
  case GGML_OP_NORM:
818
890
  case GGML_OP_ALIBI:
819
891
  case GGML_OP_ROPE:
820
892
  case GGML_OP_IM2COL:
893
+ case GGML_OP_UPSCALE:
894
+ case GGML_OP_PAD:
821
895
  case GGML_OP_ARGSORT:
822
- case GGML_OP_DUP:
823
- case GGML_OP_CPY:
824
- case GGML_OP_CONT:
896
+ case GGML_OP_LEAKY_RELU:
825
897
  case GGML_OP_MUL_MAT:
826
898
  case GGML_OP_MUL_MAT_ID:
827
899
  return true;
900
+ case GGML_OP_CPY:
901
+ case GGML_OP_DUP:
902
+ case GGML_OP_CONT:
903
+ {
904
+ switch (op->src[0]->type) {
905
+ case GGML_TYPE_F32:
906
+ switch (op->type) {
907
+ case GGML_TYPE_F16:
908
+ case GGML_TYPE_F32:
909
+ case GGML_TYPE_Q8_0:
910
+ case GGML_TYPE_Q4_0:
911
+ case GGML_TYPE_Q4_1:
912
+ return true;
913
+ default:
914
+ return false;
915
+ }
916
+ case GGML_TYPE_F16:
917
+ switch (op->type) {
918
+ case GGML_TYPE_F16:
919
+ case GGML_TYPE_F32:
920
+ return true;
921
+ default:
922
+ return false;
923
+ }
924
+ default:
925
+ return false;
926
+ };
927
+ }
828
928
  case GGML_OP_DIAG_MASK_INF:
829
929
  case GGML_OP_GET_ROWS:
830
930
  {
831
- return op->ne[0] % 4 == 0;
931
+ return op->ne[3] == 1;
832
932
  }
833
933
  default:
834
934
  return false;
@@ -904,7 +1004,10 @@ void ggml_metal_graph_compute(
904
1004
  } break;
905
1005
  }
906
1006
 
907
- GGML_ASSERT(ggml_metal_supports_op(dst));
1007
+ if (!ggml_metal_supports_op(dst)) {
1008
+ GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
1009
+ GGML_ASSERT(!"unsupported op");
1010
+ }
908
1011
 
909
1012
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
910
1013
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
@@ -1001,34 +1104,39 @@ void ggml_metal_graph_compute(
1001
1104
  case GGML_OP_MUL:
1002
1105
  case GGML_OP_DIV:
1003
1106
  {
1004
- GGML_ASSERT(ggml_is_contiguous(src0));
1005
- GGML_ASSERT(ggml_is_contiguous(src1));
1107
+ const size_t offs = 0;
1006
1108
 
1007
1109
  bool bcast_row = false;
1008
1110
 
1009
1111
  int64_t nb = ne00;
1010
1112
 
1011
- if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
1113
+ id<MTLComputePipelineState> pipeline = nil;
1114
+
1115
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1116
+ GGML_ASSERT(ggml_is_contiguous(src0));
1117
+
1012
1118
  // src1 is a row
1013
1119
  GGML_ASSERT(ne11 == 1);
1014
1120
 
1015
1121
  nb = ne00 / 4;
1016
1122
  switch (dst->op) {
1017
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
1018
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
1019
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
1123
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
1124
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
1125
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
1020
1126
  default: GGML_ASSERT(false);
1021
1127
  }
1022
1128
 
1023
1129
  bcast_row = true;
1024
1130
  } else {
1025
1131
  switch (dst->op) {
1026
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
1027
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
1028
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
1132
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
1133
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
1134
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
1029
1135
  default: GGML_ASSERT(false);
1030
1136
  }
1031
1137
  }
1138
+
1139
+ [encoder setComputePipelineState:pipeline];
1032
1140
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1033
1141
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1034
1142
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1056,18 +1164,99 @@ void ggml_metal_graph_compute(
1056
1164
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1057
1165
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1058
1166
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1059
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1167
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1168
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1060
1169
 
1061
1170
  if (bcast_row) {
1062
1171
  const int64_t n = ggml_nelements(dst)/4;
1063
1172
 
1064
1173
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1065
1174
  } else {
1066
- const int nth = MIN(1024, ne0);
1175
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1067
1176
 
1068
1177
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1069
1178
  }
1070
1179
  } break;
1180
+ case GGML_OP_ACC:
1181
+ {
1182
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
1183
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1184
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
1185
+
1186
+ GGML_ASSERT(ggml_is_contiguous(src0));
1187
+ GGML_ASSERT(ggml_is_contiguous(src1));
1188
+
1189
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
1190
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
1191
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
1192
+ const size_t offs = ((int32_t *) dst->op_params)[3];
1193
+
1194
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1195
+
1196
+ if (!inplace) {
1197
+ // run a separete kernel to cpy src->dst
1198
+ // not sure how to avoid this
1199
+ // TODO: make a simpler cpy_bytes kernel
1200
+
1201
+ const int nth = MIN(1024, ne00);
1202
+
1203
+ [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
1204
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1205
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1206
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1207
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1208
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1209
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1210
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1211
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1212
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1213
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1214
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1215
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1216
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1217
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1218
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1219
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1220
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1221
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1222
+
1223
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1224
+ }
1225
+
1226
+ [encoder setComputePipelineState:ctx->pipeline_add];
1227
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1228
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1229
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1230
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1231
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1232
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1233
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1234
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1235
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1236
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1237
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1238
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1239
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1240
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1241
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1242
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1243
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1244
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1245
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1246
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1247
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1248
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1249
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1250
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1251
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1252
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1253
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1254
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1255
+
1256
+ const int nth = MIN(1024, ne0);
1257
+
1258
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1259
+ } break;
1071
1260
  case GGML_OP_SCALE:
1072
1261
  {
1073
1262
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1091,16 +1280,15 @@ void ggml_metal_graph_compute(
1091
1280
  } break;
1092
1281
  case GGML_OP_UNARY:
1093
1282
  switch (ggml_get_unary_op(gf->nodes[i])) {
1094
- case GGML_UNARY_OP_SILU:
1283
+ case GGML_UNARY_OP_TANH:
1095
1284
  {
1096
- [encoder setComputePipelineState:ctx->pipeline_silu];
1285
+ [encoder setComputePipelineState:ctx->pipeline_tanh];
1097
1286
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1098
1287
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1099
1288
 
1100
1289
  const int64_t n = ggml_nelements(dst);
1101
- GGML_ASSERT(n % 4 == 0);
1102
1290
 
1103
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1291
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1104
1292
  } break;
1105
1293
  case GGML_UNARY_OP_RELU:
1106
1294
  {
@@ -1121,6 +1309,28 @@ void ggml_metal_graph_compute(
1121
1309
  const int64_t n = ggml_nelements(dst);
1122
1310
  GGML_ASSERT(n % 4 == 0);
1123
1311
 
1312
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1313
+ } break;
1314
+ case GGML_UNARY_OP_GELU_QUICK:
1315
+ {
1316
+ [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
1317
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1318
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1319
+
1320
+ const int64_t n = ggml_nelements(dst);
1321
+ GGML_ASSERT(n % 4 == 0);
1322
+
1323
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1324
+ } break;
1325
+ case GGML_UNARY_OP_SILU:
1326
+ {
1327
+ [encoder setComputePipelineState:ctx->pipeline_silu];
1328
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1329
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1330
+
1331
+ const int64_t n = ggml_nelements(dst);
1332
+ GGML_ASSERT(n % 4 == 0);
1333
+
1124
1334
  [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1125
1335
  } break;
1126
1336
  default:
@@ -1193,7 +1403,11 @@ void ggml_metal_graph_compute(
1193
1403
  const float scale = ((float *) dst->op_params)[0];
1194
1404
 
1195
1405
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1196
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1406
+ if (id_src1) {
1407
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1408
+ } else {
1409
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1410
+ }
1197
1411
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1198
1412
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1199
1413
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
@@ -1444,7 +1658,7 @@ void ggml_metal_graph_compute(
1444
1658
  else if (src0t == GGML_TYPE_Q6_K) {
1445
1659
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1446
1660
  } else {
1447
- int64_t ny = (ne11 + nrows - 1)/nrows;
1661
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
1448
1662
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1449
1663
  }
1450
1664
  }
@@ -1456,7 +1670,7 @@ void ggml_metal_graph_compute(
1456
1670
 
1457
1671
  GGML_ASSERT(src0t == GGML_TYPE_I32);
1458
1672
 
1459
- const int n_as = ne00;
1673
+ const int n_as = ((int32_t *) dst->op_params)[1];
1460
1674
 
1461
1675
  // TODO: make this more general
1462
1676
  GGML_ASSERT(n_as <= 8);
@@ -1488,14 +1702,22 @@ void ggml_metal_graph_compute(
1488
1702
 
1489
1703
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1490
1704
  // to the matrix-vector kernel
1491
- int ne11_mm_min = 0;
1705
+ int ne11_mm_min = 1;
1492
1706
 
1493
1707
  const int idx = ((int32_t *) dst->op_params)[0];
1494
1708
 
1709
+ // batch size
1710
+ GGML_ASSERT(ne01 == ne11);
1711
+
1712
+ const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
1713
+
1495
1714
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1496
1715
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1497
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1498
- ne11 > ne11_mm_min) {
1716
+ // !!!
1717
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1718
+ // indirect matrix multiplication
1719
+ // !!!
1720
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
1499
1721
  switch (src2->type) {
1500
1722
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1501
1723
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@@ -1514,19 +1736,22 @@ void ggml_metal_graph_compute(
1514
1736
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1515
1737
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1516
1738
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1517
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
1518
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
1519
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
1520
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
1521
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1522
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1523
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1524
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1525
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1526
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1527
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1528
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1529
- [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
1739
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1740
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1741
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1742
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1743
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1744
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1745
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1746
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1747
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1748
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1749
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1750
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
1751
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1752
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1753
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1754
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1530
1755
  // TODO: how to make this an array? read Metal docs
1531
1756
  for (int j = 0; j < n_as; ++j) {
1532
1757
  struct ggml_tensor * src_cur = dst->src[2 + j];
@@ -1534,11 +1759,157 @@ void ggml_metal_graph_compute(
1534
1759
  size_t offs_src_cur = 0;
1535
1760
  id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1536
1761
 
1537
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
1762
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1538
1763
  }
1539
1764
 
1540
1765
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1541
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1766
+
1767
+ // TODO: processing one row at a time (ne11 -> 1) is not efficient
1768
+ [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1769
+ } else {
1770
+ int nth0 = 32;
1771
+ int nth1 = 1;
1772
+ int nrows = 1;
1773
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1774
+
1775
+ // use custom matrix x vector kernel
1776
+ switch (src2t) {
1777
+ case GGML_TYPE_F32:
1778
+ {
1779
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1780
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
1781
+ } break;
1782
+ case GGML_TYPE_F16:
1783
+ {
1784
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1785
+ nth0 = 32;
1786
+ nth1 = 1;
1787
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
1788
+ } break;
1789
+ case GGML_TYPE_Q4_0:
1790
+ {
1791
+ nth0 = 8;
1792
+ nth1 = 8;
1793
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
1794
+ } break;
1795
+ case GGML_TYPE_Q4_1:
1796
+ {
1797
+ nth0 = 8;
1798
+ nth1 = 8;
1799
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
1800
+ } break;
1801
+ case GGML_TYPE_Q5_0:
1802
+ {
1803
+ nth0 = 8;
1804
+ nth1 = 8;
1805
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
1806
+ } break;
1807
+ case GGML_TYPE_Q5_1:
1808
+ {
1809
+ nth0 = 8;
1810
+ nth1 = 8;
1811
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
1812
+ } break;
1813
+ case GGML_TYPE_Q8_0:
1814
+ {
1815
+ nth0 = 8;
1816
+ nth1 = 8;
1817
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
1818
+ } break;
1819
+ case GGML_TYPE_Q2_K:
1820
+ {
1821
+ nth0 = 2;
1822
+ nth1 = 32;
1823
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
1824
+ } break;
1825
+ case GGML_TYPE_Q3_K:
1826
+ {
1827
+ nth0 = 2;
1828
+ nth1 = 32;
1829
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
1830
+ } break;
1831
+ case GGML_TYPE_Q4_K:
1832
+ {
1833
+ nth0 = 4; //1;
1834
+ nth1 = 8; //32;
1835
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
1836
+ } break;
1837
+ case GGML_TYPE_Q5_K:
1838
+ {
1839
+ nth0 = 2;
1840
+ nth1 = 32;
1841
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
1842
+ } break;
1843
+ case GGML_TYPE_Q6_K:
1844
+ {
1845
+ nth0 = 2;
1846
+ nth1 = 32;
1847
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
1848
+ } break;
1849
+ default:
1850
+ {
1851
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1852
+ GGML_ASSERT(false && "not implemented");
1853
+ }
1854
+ };
1855
+
1856
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1857
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1858
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1859
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1860
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1861
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1862
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1863
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1864
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1865
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1866
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1867
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1868
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1869
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1870
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1871
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1872
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1873
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1874
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1875
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1876
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1877
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1878
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1879
+ // TODO: how to make this an array? read Metal docs
1880
+ for (int j = 0; j < n_as; ++j) {
1881
+ struct ggml_tensor * src_cur = dst->src[2 + j];
1882
+
1883
+ size_t offs_src_cur = 0;
1884
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1885
+
1886
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1887
+ }
1888
+
1889
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1890
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1891
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1892
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1893
+ }
1894
+ else if (src2t == GGML_TYPE_Q4_K) {
1895
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1896
+ }
1897
+ else if (src2t == GGML_TYPE_Q3_K) {
1898
+ #ifdef GGML_QKK_64
1899
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1900
+ #else
1901
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1902
+ #endif
1903
+ }
1904
+ else if (src2t == GGML_TYPE_Q5_K) {
1905
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1906
+ }
1907
+ else if (src2t == GGML_TYPE_Q6_K) {
1908
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1909
+ } else {
1910
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
1911
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1912
+ }
1542
1913
  }
1543
1914
  } break;
1544
1915
  case GGML_OP_GET_ROWS:
@@ -1559,16 +1930,19 @@ void ggml_metal_graph_compute(
1559
1930
  default: GGML_ASSERT(false && "not implemented");
1560
1931
  }
1561
1932
 
1562
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1563
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1564
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1933
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1934
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1935
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1565
1936
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1566
1937
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1567
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
1568
-
1569
- const int64_t n = ggml_nelements(src1);
1570
-
1571
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1938
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
1939
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
1940
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
1941
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
1942
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
1943
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
1944
+
1945
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1572
1946
  } break;
1573
1947
  case GGML_OP_RMS_NORM:
1574
1948
  {
@@ -1595,6 +1969,38 @@ void ggml_metal_graph_compute(
1595
1969
 
1596
1970
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1597
1971
  } break;
1972
+ case GGML_OP_GROUP_NORM:
1973
+ {
1974
+ GGML_ASSERT(ne00 % 4 == 0);
1975
+
1976
+ //float eps;
1977
+ //memcpy(&eps, dst->op_params, sizeof(float));
1978
+
1979
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
1980
+
1981
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
1982
+
1983
+ int nth = 32; // SIMD width
1984
+
1985
+ //while (nth < ne00/4 && nth < 1024) {
1986
+ // nth *= 2;
1987
+ //}
1988
+
1989
+ [encoder setComputePipelineState:ctx->pipeline_group_norm];
1990
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1991
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1992
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1993
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1994
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1995
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
1996
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
1997
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
1998
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
1999
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
2000
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2001
+
2002
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2003
+ } break;
1598
2004
  case GGML_OP_NORM:
1599
2005
  {
1600
2006
  float eps;
@@ -1764,6 +2170,65 @@ void ggml_metal_graph_compute(
1764
2170
 
1765
2171
  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1766
2172
  } break;
2173
+ case GGML_OP_UPSCALE:
2174
+ {
2175
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2176
+
2177
+ const int sf = dst->op_params[0];
2178
+
2179
+ [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
2180
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2181
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2182
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2183
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2184
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2185
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2186
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2187
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2188
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2189
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2190
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2191
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2192
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2193
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2194
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2195
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2196
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2197
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2198
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2199
+
2200
+ const int nth = MIN(1024, ne0);
2201
+
2202
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2203
+ } break;
2204
+ case GGML_OP_PAD:
2205
+ {
2206
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2207
+
2208
+ [encoder setComputePipelineState:ctx->pipeline_pad_f32];
2209
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2210
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2211
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2212
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2213
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2214
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2215
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2216
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2217
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2218
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2219
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2220
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2221
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2222
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2223
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2224
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2225
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2226
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2227
+
2228
+ const int nth = MIN(1024, ne0);
2229
+
2230
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2231
+ } break;
1767
2232
  case GGML_OP_ARGSORT:
1768
2233
  {
1769
2234
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -1785,6 +2250,22 @@ void ggml_metal_graph_compute(
1785
2250
 
1786
2251
  [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
1787
2252
  } break;
2253
+ case GGML_OP_LEAKY_RELU:
2254
+ {
2255
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2256
+
2257
+ float slope;
2258
+ memcpy(&slope, dst->op_params, sizeof(float));
2259
+
2260
+ [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
2261
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2262
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2263
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
2264
+
2265
+ const int64_t n = ggml_nelements(dst);
2266
+
2267
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2268
+ } break;
1788
2269
  case GGML_OP_DUP:
1789
2270
  case GGML_OP_CPY:
1790
2271
  case GGML_OP_CONT:
@@ -1813,7 +2294,7 @@ void ggml_metal_graph_compute(
1813
2294
  {
1814
2295
  switch (dstt) {
1815
2296
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
1816
- case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
2297
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
1817
2298
  default: GGML_ASSERT(false && "not implemented");
1818
2299
  };
1819
2300
  } break;