whisper.rn 0.4.0-rc.5 → 0.4.0-rc.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/ggml-metal.m CHANGED
@@ -66,9 +66,11 @@ struct wsp_ggml_metal_context {
66
66
  WSP_GGML_METAL_DECL_KERNEL(div_row);
67
67
  WSP_GGML_METAL_DECL_KERNEL(scale);
68
68
  WSP_GGML_METAL_DECL_KERNEL(scale_4);
69
- WSP_GGML_METAL_DECL_KERNEL(silu);
69
+ WSP_GGML_METAL_DECL_KERNEL(tanh);
70
70
  WSP_GGML_METAL_DECL_KERNEL(relu);
71
71
  WSP_GGML_METAL_DECL_KERNEL(gelu);
72
+ WSP_GGML_METAL_DECL_KERNEL(gelu_quick);
73
+ WSP_GGML_METAL_DECL_KERNEL(silu);
72
74
  WSP_GGML_METAL_DECL_KERNEL(soft_max);
73
75
  WSP_GGML_METAL_DECL_KERNEL(soft_max_4);
74
76
  WSP_GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -86,6 +88,7 @@ struct wsp_ggml_metal_context {
86
88
  WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_K);
87
89
  WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K);
88
90
  WSP_GGML_METAL_DECL_KERNEL(rms_norm);
91
+ WSP_GGML_METAL_DECL_KERNEL(group_norm);
89
92
  WSP_GGML_METAL_DECL_KERNEL(norm);
90
93
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
91
94
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@@ -102,6 +105,21 @@ struct wsp_ggml_metal_context {
102
105
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
103
106
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
104
107
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
108
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
109
+ //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
110
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
111
+ //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
112
+ //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
113
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
114
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
115
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
116
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
117
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
118
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
119
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
120
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
121
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
122
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
105
123
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
106
124
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
107
125
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -130,8 +148,11 @@ struct wsp_ggml_metal_context {
130
148
  WSP_GGML_METAL_DECL_KERNEL(rope_f16);
131
149
  WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
132
150
  WSP_GGML_METAL_DECL_KERNEL(im2col_f16);
151
+ WSP_GGML_METAL_DECL_KERNEL(upscale_f32);
152
+ WSP_GGML_METAL_DECL_KERNEL(pad_f32);
133
153
  WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
134
154
  WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
155
+ WSP_GGML_METAL_DECL_KERNEL(leaky_relu_f32);
135
156
  WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
136
157
  WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
137
158
  WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
@@ -140,6 +161,7 @@ struct wsp_ggml_metal_context {
140
161
  //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
141
162
  //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
142
163
  WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);
164
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f32);
143
165
  WSP_GGML_METAL_DECL_KERNEL(concat);
144
166
  WSP_GGML_METAL_DECL_KERNEL(sqr);
145
167
  WSP_GGML_METAL_DECL_KERNEL(sum_rows);
@@ -318,9 +340,11 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
318
340
  WSP_GGML_METAL_ADD_KERNEL(div_row);
319
341
  WSP_GGML_METAL_ADD_KERNEL(scale);
320
342
  WSP_GGML_METAL_ADD_KERNEL(scale_4);
321
- WSP_GGML_METAL_ADD_KERNEL(silu);
343
+ WSP_GGML_METAL_ADD_KERNEL(tanh);
322
344
  WSP_GGML_METAL_ADD_KERNEL(relu);
323
345
  WSP_GGML_METAL_ADD_KERNEL(gelu);
346
+ WSP_GGML_METAL_ADD_KERNEL(gelu_quick);
347
+ WSP_GGML_METAL_ADD_KERNEL(silu);
324
348
  WSP_GGML_METAL_ADD_KERNEL(soft_max);
325
349
  WSP_GGML_METAL_ADD_KERNEL(soft_max_4);
326
350
  WSP_GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -338,6 +362,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
338
362
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_K);
339
363
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q6_K);
340
364
  WSP_GGML_METAL_ADD_KERNEL(rms_norm);
365
+ WSP_GGML_METAL_ADD_KERNEL(group_norm);
341
366
  WSP_GGML_METAL_ADD_KERNEL(norm);
342
367
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
343
368
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@@ -354,6 +379,21 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
354
379
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
355
380
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
356
381
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
382
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
383
+ //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
384
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
385
+ //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
386
+ //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
387
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
388
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
389
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
390
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
391
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
392
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
393
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
394
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
395
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
396
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
357
397
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
358
398
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
359
399
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -384,8 +424,11 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
384
424
  WSP_GGML_METAL_ADD_KERNEL(rope_f16);
385
425
  WSP_GGML_METAL_ADD_KERNEL(alibi_f32);
386
426
  WSP_GGML_METAL_ADD_KERNEL(im2col_f16);
427
+ WSP_GGML_METAL_ADD_KERNEL(upscale_f32);
428
+ WSP_GGML_METAL_ADD_KERNEL(pad_f32);
387
429
  WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
388
430
  WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
431
+ WSP_GGML_METAL_ADD_KERNEL(leaky_relu_f32);
389
432
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16);
390
433
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32);
391
434
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
@@ -394,6 +437,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
394
437
  //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
395
438
  //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
396
439
  WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16);
440
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f32);
397
441
  WSP_GGML_METAL_ADD_KERNEL(concat);
398
442
  WSP_GGML_METAL_ADD_KERNEL(sqr);
399
443
  WSP_GGML_METAL_ADD_KERNEL(sum_rows);
@@ -416,9 +460,11 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
416
460
  WSP_GGML_METAL_DEL_KERNEL(div_row);
417
461
  WSP_GGML_METAL_DEL_KERNEL(scale);
418
462
  WSP_GGML_METAL_DEL_KERNEL(scale_4);
419
- WSP_GGML_METAL_DEL_KERNEL(silu);
463
+ WSP_GGML_METAL_DEL_KERNEL(tanh);
420
464
  WSP_GGML_METAL_DEL_KERNEL(relu);
421
465
  WSP_GGML_METAL_DEL_KERNEL(gelu);
466
+ WSP_GGML_METAL_DEL_KERNEL(gelu_quick);
467
+ WSP_GGML_METAL_DEL_KERNEL(silu);
422
468
  WSP_GGML_METAL_DEL_KERNEL(soft_max);
423
469
  WSP_GGML_METAL_DEL_KERNEL(soft_max_4);
424
470
  WSP_GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -436,6 +482,7 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
436
482
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_K);
437
483
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K);
438
484
  WSP_GGML_METAL_DEL_KERNEL(rms_norm);
485
+ WSP_GGML_METAL_DEL_KERNEL(group_norm);
439
486
  WSP_GGML_METAL_DEL_KERNEL(norm);
440
487
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
441
488
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@@ -452,6 +499,21 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
452
499
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
453
500
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
454
501
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
502
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
503
+ //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
504
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
505
+ //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
506
+ //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
507
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
508
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
509
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
510
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
511
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
512
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
513
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
514
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
515
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
516
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
455
517
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
456
518
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
457
519
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -482,8 +544,11 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
482
544
  WSP_GGML_METAL_DEL_KERNEL(rope_f16);
483
545
  WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
484
546
  WSP_GGML_METAL_DEL_KERNEL(im2col_f16);
547
+ WSP_GGML_METAL_DEL_KERNEL(upscale_f32);
548
+ WSP_GGML_METAL_DEL_KERNEL(pad_f32);
485
549
  WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
486
550
  WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
551
+ WSP_GGML_METAL_DEL_KERNEL(leaky_relu_f32);
487
552
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
488
553
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
489
554
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
@@ -492,6 +557,7 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
492
557
  //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
493
558
  //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
494
559
  WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
560
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f32);
495
561
  WSP_GGML_METAL_DEL_KERNEL(concat);
496
562
  WSP_GGML_METAL_DEL_KERNEL(sqr);
497
563
  WSP_GGML_METAL_DEL_KERNEL(sum_rows);
@@ -783,9 +849,11 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
783
849
  switch (op->op) {
784
850
  case WSP_GGML_OP_UNARY:
785
851
  switch (wsp_ggml_get_unary_op(op)) {
786
- case WSP_GGML_UNARY_OP_SILU:
852
+ case WSP_GGML_UNARY_OP_TANH:
787
853
  case WSP_GGML_UNARY_OP_RELU:
788
854
  case WSP_GGML_UNARY_OP_GELU:
855
+ case WSP_GGML_UNARY_OP_GELU_QUICK:
856
+ case WSP_GGML_UNARY_OP_SILU:
789
857
  return true;
790
858
  default:
791
859
  return false;
@@ -797,6 +865,7 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
797
865
  case WSP_GGML_OP_PERMUTE:
798
866
  case WSP_GGML_OP_CONCAT:
799
867
  case WSP_GGML_OP_ADD:
868
+ case WSP_GGML_OP_ACC:
800
869
  case WSP_GGML_OP_MUL:
801
870
  case WSP_GGML_OP_DIV:
802
871
  case WSP_GGML_OP_SCALE:
@@ -804,21 +873,50 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
804
873
  case WSP_GGML_OP_SUM_ROWS:
805
874
  case WSP_GGML_OP_SOFT_MAX:
806
875
  case WSP_GGML_OP_RMS_NORM:
876
+ case WSP_GGML_OP_GROUP_NORM:
807
877
  case WSP_GGML_OP_NORM:
808
878
  case WSP_GGML_OP_ALIBI:
809
879
  case WSP_GGML_OP_ROPE:
810
880
  case WSP_GGML_OP_IM2COL:
881
+ case WSP_GGML_OP_UPSCALE:
882
+ case WSP_GGML_OP_PAD:
811
883
  case WSP_GGML_OP_ARGSORT:
812
- case WSP_GGML_OP_DUP:
813
- case WSP_GGML_OP_CPY:
814
- case WSP_GGML_OP_CONT:
884
+ case WSP_GGML_OP_LEAKY_RELU:
815
885
  case WSP_GGML_OP_MUL_MAT:
816
886
  case WSP_GGML_OP_MUL_MAT_ID:
817
887
  return true;
888
+ case WSP_GGML_OP_CPY:
889
+ case WSP_GGML_OP_DUP:
890
+ case WSP_GGML_OP_CONT:
891
+ {
892
+ switch (op->src[0]->type) {
893
+ case WSP_GGML_TYPE_F32:
894
+ switch (op->type) {
895
+ case WSP_GGML_TYPE_F16:
896
+ case WSP_GGML_TYPE_F32:
897
+ case WSP_GGML_TYPE_Q8_0:
898
+ case WSP_GGML_TYPE_Q4_0:
899
+ case WSP_GGML_TYPE_Q4_1:
900
+ return true;
901
+ default:
902
+ return false;
903
+ }
904
+ case WSP_GGML_TYPE_F16:
905
+ switch (op->type) {
906
+ case WSP_GGML_TYPE_F16:
907
+ case WSP_GGML_TYPE_F32:
908
+ return true;
909
+ default:
910
+ return false;
911
+ }
912
+ default:
913
+ return false;
914
+ };
915
+ }
818
916
  case WSP_GGML_OP_DIAG_MASK_INF:
819
917
  case WSP_GGML_OP_GET_ROWS:
820
918
  {
821
- return op->ne[0] % 4 == 0;
919
+ return op->ne[3] == 1;
822
920
  }
823
921
  default:
824
922
  return false;
@@ -894,7 +992,10 @@ void wsp_ggml_metal_graph_compute(
894
992
  } break;
895
993
  }
896
994
 
897
- WSP_GGML_ASSERT(wsp_ggml_metal_supports_op(dst));
995
+ if (!wsp_ggml_metal_supports_op(dst)) {
996
+ WSP_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, wsp_ggml_op_desc(dst));
997
+ WSP_GGML_ASSERT(!"unsupported op");
998
+ }
898
999
 
899
1000
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
900
1001
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
@@ -991,34 +1092,39 @@ void wsp_ggml_metal_graph_compute(
991
1092
  case WSP_GGML_OP_MUL:
992
1093
  case WSP_GGML_OP_DIV:
993
1094
  {
994
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
995
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
1095
+ const size_t offs = 0;
996
1096
 
997
1097
  bool bcast_row = false;
998
1098
 
999
1099
  int64_t nb = ne00;
1000
1100
 
1001
- if (wsp_ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
1101
+ id<MTLComputePipelineState> pipeline = nil;
1102
+
1103
+ if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1104
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
1105
+
1002
1106
  // src1 is a row
1003
1107
  WSP_GGML_ASSERT(ne11 == 1);
1004
1108
 
1005
1109
  nb = ne00 / 4;
1006
1110
  switch (dst->op) {
1007
- case WSP_GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
1008
- case WSP_GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
1009
- case WSP_GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
1111
+ case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
1112
+ case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
1113
+ case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
1010
1114
  default: WSP_GGML_ASSERT(false);
1011
1115
  }
1012
1116
 
1013
1117
  bcast_row = true;
1014
1118
  } else {
1015
1119
  switch (dst->op) {
1016
- case WSP_GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
1017
- case WSP_GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
1018
- case WSP_GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
1120
+ case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
1121
+ case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
1122
+ case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
1019
1123
  default: WSP_GGML_ASSERT(false);
1020
1124
  }
1021
1125
  }
1126
+
1127
+ [encoder setComputePipelineState:pipeline];
1022
1128
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1023
1129
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1024
1130
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1046,18 +1152,99 @@ void wsp_ggml_metal_graph_compute(
1046
1152
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1047
1153
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1048
1154
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1049
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1155
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1156
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1050
1157
 
1051
1158
  if (bcast_row) {
1052
1159
  const int64_t n = wsp_ggml_nelements(dst)/4;
1053
1160
 
1054
1161
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1055
1162
  } else {
1056
- const int nth = MIN(1024, ne0);
1163
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1057
1164
 
1058
1165
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1059
1166
  }
1060
1167
  } break;
1168
+ case WSP_GGML_OP_ACC:
1169
+ {
1170
+ WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32);
1171
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1172
+ WSP_GGML_ASSERT(dstt == WSP_GGML_TYPE_F32);
1173
+
1174
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
1175
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
1176
+
1177
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
1178
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
1179
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
1180
+ const size_t offs = ((int32_t *) dst->op_params)[3];
1181
+
1182
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1183
+
1184
+ if (!inplace) {
1185
+ // run a separate kernel to cpy src->dst
1186
+ // not sure how to avoid this
1187
+ // TODO: make a simpler cpy_bytes kernel
1188
+
1189
+ const int nth = MIN(1024, ne00);
1190
+
1191
+ [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
1192
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1193
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1194
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1195
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1196
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1197
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1198
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1199
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1200
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1201
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1202
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1203
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1204
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1205
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1206
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1207
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1208
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1209
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1210
+
1211
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1212
+ }
1213
+
1214
+ [encoder setComputePipelineState:ctx->pipeline_add];
1215
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1216
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1217
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1218
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1219
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1220
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1221
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1222
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1223
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1224
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1225
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1226
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1227
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1228
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1229
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1230
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1231
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1232
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1233
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1234
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1235
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1236
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1237
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1238
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1239
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1240
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1241
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1242
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1243
+
1244
+ const int nth = MIN(1024, ne0);
1245
+
1246
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1247
+ } break;
1061
1248
  case WSP_GGML_OP_SCALE:
1062
1249
  {
1063
1250
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
@@ -1081,16 +1268,15 @@ void wsp_ggml_metal_graph_compute(
1081
1268
  } break;
1082
1269
  case WSP_GGML_OP_UNARY:
1083
1270
  switch (wsp_ggml_get_unary_op(gf->nodes[i])) {
1084
- case WSP_GGML_UNARY_OP_SILU:
1271
+ case WSP_GGML_UNARY_OP_TANH:
1085
1272
  {
1086
- [encoder setComputePipelineState:ctx->pipeline_silu];
1273
+ [encoder setComputePipelineState:ctx->pipeline_tanh];
1087
1274
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1088
1275
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1089
1276
 
1090
1277
  const int64_t n = wsp_ggml_nelements(dst);
1091
- WSP_GGML_ASSERT(n % 4 == 0);
1092
1278
 
1093
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1279
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1094
1280
  } break;
1095
1281
  case WSP_GGML_UNARY_OP_RELU:
1096
1282
  {
@@ -1111,6 +1297,28 @@ void wsp_ggml_metal_graph_compute(
1111
1297
  const int64_t n = wsp_ggml_nelements(dst);
1112
1298
  WSP_GGML_ASSERT(n % 4 == 0);
1113
1299
 
1300
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1301
+ } break;
1302
+ case WSP_GGML_UNARY_OP_GELU_QUICK:
1303
+ {
1304
+ [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
1305
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1306
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1307
+
1308
+ const int64_t n = wsp_ggml_nelements(dst);
1309
+ WSP_GGML_ASSERT(n % 4 == 0);
1310
+
1311
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1312
+ } break;
1313
+ case WSP_GGML_UNARY_OP_SILU:
1314
+ {
1315
+ [encoder setComputePipelineState:ctx->pipeline_silu];
1316
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1317
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1318
+
1319
+ const int64_t n = wsp_ggml_nelements(dst);
1320
+ WSP_GGML_ASSERT(n % 4 == 0);
1321
+
1114
1322
  [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1115
1323
  } break;
1116
1324
  default:
@@ -1185,6 +1393,8 @@ void wsp_ggml_metal_graph_compute(
1185
1393
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1186
1394
  if (id_src1) {
1187
1395
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1396
+ } else {
1397
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1188
1398
  }
1189
1399
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1190
1400
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
@@ -1436,7 +1646,7 @@ void wsp_ggml_metal_graph_compute(
1436
1646
  else if (src0t == WSP_GGML_TYPE_Q6_K) {
1437
1647
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1438
1648
  } else {
1439
- int64_t ny = (ne11 + nrows - 1)/nrows;
1649
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
1440
1650
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1441
1651
  }
1442
1652
  }
@@ -1448,7 +1658,7 @@ void wsp_ggml_metal_graph_compute(
1448
1658
 
1449
1659
  WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32);
1450
1660
 
1451
- const int n_as = ne00;
1661
+ const int n_as = ((int32_t *) dst->op_params)[1];
1452
1662
 
1453
1663
  // TODO: make this more general
1454
1664
  WSP_GGML_ASSERT(n_as <= 8);
@@ -1480,14 +1690,22 @@ void wsp_ggml_metal_graph_compute(
1480
1690
 
1481
1691
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1482
1692
  // to the matrix-vector kernel
1483
- int ne11_mm_min = 0;
1693
+ int ne11_mm_min = 1;
1484
1694
 
1485
1695
  const int idx = ((int32_t *) dst->op_params)[0];
1486
1696
 
1697
+ // batch size
1698
+ WSP_GGML_ASSERT(ne01 == ne11);
1699
+
1700
+ const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
1701
+
1487
1702
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1488
1703
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1489
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1490
- ne11 > ne11_mm_min) {
1704
+ // !!!
1705
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1706
+ // indirect matrix multiplication
1707
+ // !!!
1708
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
1491
1709
  switch (src2->type) {
1492
1710
  case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1493
1711
  case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@@ -1506,19 +1724,22 @@ void wsp_ggml_metal_graph_compute(
1506
1724
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1507
1725
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1508
1726
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1509
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
1510
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
1511
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
1512
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
1513
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1514
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1515
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1516
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1517
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1518
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1519
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1520
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1521
- [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
1727
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1728
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1729
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1730
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1731
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1732
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1733
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1734
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1735
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1736
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1737
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1738
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
1739
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1740
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1741
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1742
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1522
1743
  // TODO: how to make this an array? read Metal docs
1523
1744
  for (int j = 0; j < n_as; ++j) {
1524
1745
  struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
@@ -1526,11 +1747,157 @@ void wsp_ggml_metal_graph_compute(
1526
1747
  size_t offs_src_cur = 0;
1527
1748
  id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1528
1749
 
1529
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
1750
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1530
1751
  }
1531
1752
 
1532
1753
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1533
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1754
+
1755
+ // TODO: processing one row at a time (ne11 -> 1) is not efficient
1756
+ [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1757
+ } else {
1758
+ int nth0 = 32;
1759
+ int nth1 = 1;
1760
+ int nrows = 1;
1761
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1762
+
1763
+ // use custom matrix x vector kernel
1764
+ switch (src2t) {
1765
+ case WSP_GGML_TYPE_F32:
1766
+ {
1767
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1768
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
1769
+ } break;
1770
+ case WSP_GGML_TYPE_F16:
1771
+ {
1772
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1773
+ nth0 = 32;
1774
+ nth1 = 1;
1775
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
1776
+ } break;
1777
+ case WSP_GGML_TYPE_Q4_0:
1778
+ {
1779
+ nth0 = 8;
1780
+ nth1 = 8;
1781
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
1782
+ } break;
1783
+ case WSP_GGML_TYPE_Q4_1:
1784
+ {
1785
+ nth0 = 8;
1786
+ nth1 = 8;
1787
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
1788
+ } break;
1789
+ case WSP_GGML_TYPE_Q5_0:
1790
+ {
1791
+ nth0 = 8;
1792
+ nth1 = 8;
1793
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
1794
+ } break;
1795
+ case WSP_GGML_TYPE_Q5_1:
1796
+ {
1797
+ nth0 = 8;
1798
+ nth1 = 8;
1799
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
1800
+ } break;
1801
+ case WSP_GGML_TYPE_Q8_0:
1802
+ {
1803
+ nth0 = 8;
1804
+ nth1 = 8;
1805
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
1806
+ } break;
1807
+ case WSP_GGML_TYPE_Q2_K:
1808
+ {
1809
+ nth0 = 2;
1810
+ nth1 = 32;
1811
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
1812
+ } break;
1813
+ case WSP_GGML_TYPE_Q3_K:
1814
+ {
1815
+ nth0 = 2;
1816
+ nth1 = 32;
1817
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
1818
+ } break;
1819
+ case WSP_GGML_TYPE_Q4_K:
1820
+ {
1821
+ nth0 = 4; //1;
1822
+ nth1 = 8; //32;
1823
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
1824
+ } break;
1825
+ case WSP_GGML_TYPE_Q5_K:
1826
+ {
1827
+ nth0 = 2;
1828
+ nth1 = 32;
1829
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
1830
+ } break;
1831
+ case WSP_GGML_TYPE_Q6_K:
1832
+ {
1833
+ nth0 = 2;
1834
+ nth1 = 32;
1835
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
1836
+ } break;
1837
+ default:
1838
+ {
1839
+ WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1840
+ WSP_GGML_ASSERT(false && "not implemented");
1841
+ }
1842
+ };
1843
+
1844
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1845
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1846
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1847
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1848
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1849
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1850
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1851
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1852
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1853
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1854
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1855
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1856
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1857
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1858
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1859
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1860
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1861
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1862
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1863
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1864
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1865
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1866
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1867
+ // TODO: how to make this an array? read Metal docs
1868
+ for (int j = 0; j < n_as; ++j) {
1869
+ struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
1870
+
1871
+ size_t offs_src_cur = 0;
1872
+ id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1873
+
1874
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1875
+ }
1876
+
1877
+ if (src2t == WSP_GGML_TYPE_Q4_0 || src2t == WSP_GGML_TYPE_Q4_1 ||
1878
+ src2t == WSP_GGML_TYPE_Q5_0 || src2t == WSP_GGML_TYPE_Q5_1 || src2t == WSP_GGML_TYPE_Q8_0 ||
1879
+ src2t == WSP_GGML_TYPE_Q2_K) { // || src2t == WSP_GGML_TYPE_Q4_K) {
1880
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1881
+ }
1882
+ else if (src2t == WSP_GGML_TYPE_Q4_K) {
1883
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1884
+ }
1885
+ else if (src2t == WSP_GGML_TYPE_Q3_K) {
1886
+ #ifdef WSP_GGML_QKK_64
1887
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1888
+ #else
1889
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1890
+ #endif
1891
+ }
1892
+ else if (src2t == WSP_GGML_TYPE_Q5_K) {
1893
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1894
+ }
1895
+ else if (src2t == WSP_GGML_TYPE_Q6_K) {
1896
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1897
+ } else {
1898
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
1899
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1900
+ }
1534
1901
  }
1535
1902
  } break;
1536
1903
  case WSP_GGML_OP_GET_ROWS:
@@ -1551,16 +1918,19 @@ void wsp_ggml_metal_graph_compute(
1551
1918
  default: WSP_GGML_ASSERT(false && "not implemented");
1552
1919
  }
1553
1920
 
1554
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1555
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1556
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1921
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1922
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1923
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1557
1924
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1558
1925
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1559
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
1560
-
1561
- const int64_t n = wsp_ggml_nelements(src1);
1562
-
1563
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1926
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
1927
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
1928
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
1929
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
1930
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
1931
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
1932
+
1933
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1564
1934
  } break;
1565
1935
  case WSP_GGML_OP_RMS_NORM:
1566
1936
  {
@@ -1587,6 +1957,38 @@ void wsp_ggml_metal_graph_compute(
1587
1957
 
1588
1958
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1589
1959
  } break;
1960
+ case WSP_GGML_OP_GROUP_NORM: // encodes one group_norm kernel dispatch: src0 -> dst, one threadgroup per group
1961
+ {
1962
+ WSP_GGML_ASSERT(ne00 % 4 == 0); // ne00 must be a multiple of 4 (presumably the kernel processes 4 elements at a time - confirm against the shader)
1963
+
1964
+ //float eps;
1965
+ //memcpy(&eps, dst->op_params, sizeof(float));
1966
+
1967
+ const float eps = 1e-6f; // TODO: temporarily hardcoded (the op_params read above is disabled, so this fixed value is what the kernel receives)
1968
+
1969
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0]; // first int32 slot of op_params - the same leading bytes the disabled eps read above would have consumed
1970
+
1971
+ int nth = 32; // SIMD width
1972
+
1973
+ //while (nth < ne00/4 && nth < 1024) {
1974
+ // nth *= 2;
1975
+ //}
1976
+
1977
+ [encoder setComputePipelineState:ctx->pipeline_group_norm];
1978
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; // input buffer
1979
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; // output buffer
1980
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; // slots 2-4: ne00..ne02 of src0, passed as int64_t
1981
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1982
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1983
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; // slots 5-7: nb00..nb02 of src0, passed as uint64_t
1984
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
1985
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
1986
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
1987
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
1988
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; // 32 floats of threadgroup memory (presumably scratch for the in-kernel reduction - confirm against the shader)
1989
+
1990
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; // grid: n_groups x 1 x 1 threadgroups, nth threads each
1991
+ } break;
1590
1992
  case WSP_GGML_OP_NORM:
1591
1993
  {
1592
1994
  float eps;
@@ -1756,6 +2158,65 @@ void wsp_ggml_metal_graph_compute(
1756
2158
 
1757
2159
  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1758
2160
  } break;
2161
+ case WSP_GGML_OP_UPSCALE: // encodes one upscale_f32 kernel dispatch: src0 -> dst
2162
+ {
2163
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32); // only F32 input is handled here
2164
+
2165
+ const int sf = dst->op_params[0]; // first op_params entry (presumably the integer scale factor - confirm against the op definition)
2166
+
2167
+ [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
2168
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; // input buffer
2169
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; // output buffer
2170
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; // slots 2-5: ne00..ne03 of src0
2171
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2172
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2173
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2174
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; // slots 6-9: nb00..nb03 of src0
2175
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2176
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2177
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2178
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; // slots 10-13: ne0..ne3 of dst
2179
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2180
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2181
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2182
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; // slots 14-17: nb0..nb3 of dst
2183
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2184
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2185
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2186
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18]; // sf goes last, in slot 18
2187
+
2188
+ const int nth = MIN(1024, ne0); // threads per threadgroup: ne0, capped at 1024
2189
+
2190
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; // grid: ne1 x ne2 x ne3 threadgroups, nth threads each
2191
+ } break;
2192
+ case WSP_GGML_OP_PAD: // encodes one pad_f32 kernel dispatch: src0 -> dst
2193
+ {
2194
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32); // only F32 input is handled here
2195
+
2196
+ [encoder setComputePipelineState:ctx->pipeline_pad_f32];
2197
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; // input buffer
2198
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; // output buffer
2199
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; // slots 2-5: ne00..ne03 of src0
2200
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2201
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2202
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2203
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; // slots 6-9: nb00..nb03 of src0
2204
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2205
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2206
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2207
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; // slots 10-13: ne0..ne3 of dst
2208
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2209
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2210
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2211
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; // slots 14-17: nb0..nb3 of dst
2212
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2213
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2214
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2215
+
2216
+ const int nth = MIN(1024, ne0); // threads per threadgroup: ne0, capped at 1024
2217
+
2218
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; // grid: ne1 x ne2 x ne3 threadgroups, nth threads each
2219
+ } break;
1759
2220
  case WSP_GGML_OP_ARGSORT:
1760
2221
  {
1761
2222
  WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
@@ -1777,6 +2238,22 @@ void wsp_ggml_metal_graph_compute(
1777
2238
 
1778
2239
  [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
1779
2240
  } break;
2241
+ case WSP_GGML_OP_LEAKY_RELU: // encodes one leaky_relu_f32 kernel dispatch: src0 -> dst
2242
+ {
2243
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32); // only F32 input is handled here
2244
+
2245
+ float slope;
2246
+ memcpy(&slope, dst->op_params, sizeof(float)); // slope is copied bytewise from the start of op_params rather than read through a cast pointer
2247
+
2248
+ [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
2249
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; // input buffer
2250
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; // output buffer
2251
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; // slope passed by value in slot 2
2252
+
2253
+ const int64_t n = wsp_ggml_nelements(dst); // threadgroup count is taken from the element count of dst
2254
+
2255
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; // grid: n threadgroups of a single thread each
2256
+ } break;
1780
2257
  case WSP_GGML_OP_DUP:
1781
2258
  case WSP_GGML_OP_CPY:
1782
2259
  case WSP_GGML_OP_CONT:
@@ -1805,7 +2282,7 @@ void wsp_ggml_metal_graph_compute(
1805
2282
  {
1806
2283
  switch (dstt) {
1807
2284
  case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
1808
- case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
2285
+ case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
1809
2286
  default: WSP_GGML_ASSERT(false && "not implemented");
1810
2287
  };
1811
2288
  } break;