whisper.rn 0.5.2 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. package/cpp/ggml-alloc.c +11 -4
  2. package/cpp/ggml-backend-reg.cpp +8 -0
  3. package/cpp/ggml-backend.cpp +0 -2
  4. package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
  5. package/cpp/ggml-cpu/ggml-cpu-impl.h +3 -1
  6. package/cpp/ggml-cpu/ggml-cpu.c +50 -21
  7. package/cpp/ggml-cpu/ops.cpp +458 -349
  8. package/cpp/ggml-cpu/ops.h +4 -4
  9. package/cpp/ggml-cpu/repack.cpp +143 -29
  10. package/cpp/ggml-cpu/simd-mappings.h +25 -25
  11. package/cpp/ggml-cpu/unary-ops.cpp +16 -0
  12. package/cpp/ggml-cpu/unary-ops.h +2 -0
  13. package/cpp/ggml-cpu/vec.cpp +17 -0
  14. package/cpp/ggml-cpu/vec.h +10 -0
  15. package/cpp/ggml-impl.h +17 -1
  16. package/cpp/ggml-metal/ggml-metal-context.m +5 -6
  17. package/cpp/ggml-metal/ggml-metal-device.cpp +101 -4
  18. package/cpp/ggml-metal/ggml-metal-device.h +8 -1
  19. package/cpp/ggml-metal/ggml-metal-device.m +216 -14
  20. package/cpp/ggml-metal/ggml-metal-impl.h +90 -2
  21. package/cpp/ggml-metal/ggml-metal-ops.cpp +346 -85
  22. package/cpp/ggml-metal/ggml-metal-ops.h +2 -0
  23. package/cpp/ggml-metal/ggml-metal.cpp +5 -0
  24. package/cpp/ggml-metal/ggml-metal.metal +12436 -0
  25. package/cpp/ggml.c +154 -5
  26. package/cpp/ggml.h +73 -0
  27. package/cpp/whisper.cpp +5 -1
  28. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  29. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
  30. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  31. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  32. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  33. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  34. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
  35. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  36. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  37. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  38. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  39. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  40. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
  41. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  42. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  43. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  45. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
  46. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  47. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  48. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  49. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  50. package/package.json +1 -1
  51. package/whisper-rn.podspec +1 -1
  52. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  53. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
@@ -10,6 +10,8 @@
10
10
 
11
11
  #include <cassert>
12
12
  #include <algorithm>
13
+ #include <limits>
14
+ #include <cmath>
13
15
 
14
16
  static wsp_ggml_metal_buffer_id wsp_ggml_metal_get_buffer_id(const wsp_ggml_tensor * t) {
15
17
  if (!t) {
@@ -310,6 +312,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
310
312
  {
311
313
  n_fuse = wsp_ggml_metal_op_sum_rows(ctx, idx);
312
314
  } break;
315
+ case WSP_GGML_OP_CUMSUM:
316
+ {
317
+ n_fuse = wsp_ggml_metal_op_cumsum(ctx, idx);
318
+ } break;
313
319
  case WSP_GGML_OP_SOFT_MAX:
314
320
  {
315
321
  n_fuse = wsp_ggml_metal_op_soft_max(ctx, idx);
@@ -364,6 +370,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
364
370
  {
365
371
  n_fuse = wsp_ggml_metal_op_im2col(ctx, idx);
366
372
  } break;
373
+ case WSP_GGML_OP_CONV_2D:
374
+ {
375
+ n_fuse = wsp_ggml_metal_op_conv_2d(ctx, idx);
376
+ } break;
367
377
  case WSP_GGML_OP_CONV_TRANSPOSE_1D:
368
378
  {
369
379
  n_fuse = wsp_ggml_metal_op_conv_transpose_1d(ctx, idx);
@@ -534,7 +544,7 @@ int wsp_ggml_metal_op_repeat(wsp_ggml_metal_op_t ctx, int idx) {
534
544
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
535
545
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
536
546
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
537
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
547
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
538
548
 
539
549
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_repeat(lib, op->type);
540
550
 
@@ -580,7 +590,7 @@ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
580
590
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
581
591
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
582
592
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
583
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
593
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
584
594
 
585
595
  WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
586
596
  WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
@@ -689,7 +699,7 @@ int wsp_ggml_metal_op_scale(wsp_ggml_metal_op_t ctx, int idx) {
689
699
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
690
700
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
691
701
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
692
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
702
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
693
703
 
694
704
  float scale;
695
705
  float bias;
@@ -728,7 +738,7 @@ int wsp_ggml_metal_op_clamp(wsp_ggml_metal_op_t ctx, int idx) {
728
738
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
729
739
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
730
740
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
731
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
741
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
732
742
 
733
743
  float min;
734
744
  float max;
@@ -767,7 +777,7 @@ int wsp_ggml_metal_op_unary(wsp_ggml_metal_op_t ctx, int idx) {
767
777
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
768
778
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
769
779
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
770
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
780
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
771
781
 
772
782
  int64_t n = wsp_ggml_nelements(op);
773
783
 
@@ -797,7 +807,7 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
797
807
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
798
808
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
799
809
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
800
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
810
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
801
811
 
802
812
  if (op->src[1]) {
803
813
  WSP_GGML_ASSERT(wsp_ggml_are_same_shape(op->src[0], op->src[1]));
@@ -829,18 +839,6 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
829
839
 
830
840
  const int32_t nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
831
841
 
832
- //[encoder setComputePipelineState:pipeline];
833
- //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
834
- //if (src1) {
835
- // [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
836
- //} else {
837
- // [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
838
- //}
839
- //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
840
- //[encoder setBytes:&args length:sizeof(args) atIndex:3];
841
-
842
- //[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
843
-
844
842
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
845
843
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
846
844
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -902,7 +900,7 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
902
900
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
903
901
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
904
902
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
905
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
903
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
906
904
 
907
905
  wsp_ggml_metal_kargs_sum_rows args = {
908
906
  /*.ne00 =*/ ne00,
@@ -936,14 +934,6 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
936
934
 
937
935
  const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
938
936
 
939
- //[encoder setComputePipelineState:pipeline];
940
- //[encoder setBytes:&args length:sizeof(args) atIndex:0];
941
- //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
942
- //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
943
- //[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
944
-
945
- //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
946
-
947
937
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
948
938
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
949
939
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -956,6 +946,149 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
956
946
  return 1;
957
947
  }
958
948
 
949
+ int wsp_ggml_metal_op_cumsum(wsp_ggml_metal_op_t ctx, int idx) {
950
+ wsp_ggml_tensor * op = ctx->node(idx);
951
+
952
+ wsp_ggml_metal_library_t lib = ctx->lib;
953
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
954
+
955
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
956
+
957
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
958
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
959
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
960
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
961
+
962
+ wsp_ggml_metal_pipeline_t pipeline_blk = wsp_ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
963
+
964
+ int nth = 1;
965
+ while (nth < ne00 && 2*nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
966
+ nth *= 2;
967
+ }
968
+
969
+ WSP_GGML_ASSERT(ne00 <= nth*nth);
970
+
971
+ const int64_t net0 = (ne00 + nth - 1) / nth;
972
+ const int64_t net1 = ne01;
973
+ const int64_t net2 = ne02;
974
+ const int64_t net3 = ne03;
975
+
976
+ const uint64_t nbt0 = sizeof(float);
977
+ const uint64_t nbt1 = net0*nbt0;
978
+ const uint64_t nbt2 = net1*nbt1;
979
+ const uint64_t nbt3 = net2*nbt2;
980
+
981
+ const size_t smem = WSP_GGML_PAD(32*sizeof(float), 16);
982
+
983
+ wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
984
+ wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
985
+
986
+ wsp_ggml_metal_buffer_id bid_tmp = bid_dst;
987
+ bid_tmp.offs += wsp_ggml_nbytes(op);
988
+
989
+ {
990
+ wsp_ggml_metal_kargs_cumsum_blk args = {
991
+ /*.ne00 =*/ ne00,
992
+ /*.ne01 =*/ ne01,
993
+ /*.ne02 =*/ ne02,
994
+ /*.ne03 =*/ ne03,
995
+ /*.nb00 =*/ nb00,
996
+ /*.nb01 =*/ nb01,
997
+ /*.nb02 =*/ nb02,
998
+ /*.nb03 =*/ nb03,
999
+ /*.net0 =*/ net0,
1000
+ /*.net1 =*/ net1,
1001
+ /*.net2 =*/ net2,
1002
+ /*.net3 =*/ net3,
1003
+ /*.nbt0 =*/ nbt0,
1004
+ /*.nbt1 =*/ nbt1,
1005
+ /*.nbt2 =*/ nbt2,
1006
+ /*.nbt3 =*/ nbt3,
1007
+ /*.outb =*/ ne00 > nth,
1008
+ };
1009
+
1010
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1011
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1012
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
1013
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1014
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
1015
+
1016
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1017
+
1018
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1019
+ }
1020
+
1021
+ if (ne00 > nth) {
1022
+ wsp_ggml_metal_op_concurrency_reset(ctx);
1023
+
1024
+ {
1025
+ wsp_ggml_metal_kargs_cumsum_blk args = {
1026
+ /*.ne00 =*/ net0,
1027
+ /*.ne01 =*/ net1,
1028
+ /*.ne02 =*/ net2,
1029
+ /*.ne03 =*/ net3,
1030
+ /*.nb00 =*/ nbt0,
1031
+ /*.nb01 =*/ nbt1,
1032
+ /*.nb02 =*/ nbt2,
1033
+ /*.nb03 =*/ nbt3,
1034
+ /*.net0 =*/ net0,
1035
+ /*.net1 =*/ net1,
1036
+ /*.net2 =*/ net2,
1037
+ /*.net3 =*/ net3,
1038
+ /*.nbt0 =*/ nbt0,
1039
+ /*.nbt1 =*/ nbt1,
1040
+ /*.nbt2 =*/ nbt2,
1041
+ /*.nbt3 =*/ nbt3,
1042
+ /*.outb =*/ false,
1043
+ };
1044
+
1045
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1046
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1047
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1048
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1049
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
1050
+
1051
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1052
+
1053
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
1054
+ }
1055
+
1056
+ wsp_ggml_metal_op_concurrency_reset(ctx);
1057
+
1058
+ {
1059
+ wsp_ggml_metal_pipeline_t pipeline_add = wsp_ggml_metal_library_get_pipeline_cumsum_add(lib, op);
1060
+
1061
+ wsp_ggml_metal_kargs_cumsum_add args = {
1062
+ /*.ne00 =*/ ne00,
1063
+ /*.ne01 =*/ ne01,
1064
+ /*.ne02 =*/ ne02,
1065
+ /*.ne03 =*/ ne03,
1066
+ /*.nb00 =*/ nb00,
1067
+ /*.nb01 =*/ nb01,
1068
+ /*.nb02 =*/ nb02,
1069
+ /*.nb03 =*/ nb03,
1070
+ /*.net0 =*/ net0,
1071
+ /*.net1 =*/ net1,
1072
+ /*.net2 =*/ net2,
1073
+ /*.net3 =*/ net3,
1074
+ /*.nbt0 =*/ nbt0,
1075
+ /*.nbt1 =*/ nbt1,
1076
+ /*.nbt2 =*/ nbt2,
1077
+ /*.nbt3 =*/ nbt3,
1078
+ };
1079
+
1080
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_add);
1081
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1082
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1083
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1084
+
1085
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1086
+ }
1087
+ }
1088
+
1089
+ return 1;
1090
+ }
1091
+
959
1092
  int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
960
1093
  wsp_ggml_tensor * op = ctx->node(idx);
961
1094
 
@@ -967,7 +1100,7 @@ int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
967
1100
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
968
1101
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
969
1102
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
970
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1103
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
971
1104
 
972
1105
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
973
1106
 
@@ -1012,7 +1145,7 @@ int wsp_ggml_metal_op_set_rows(wsp_ggml_metal_op_t ctx, int idx) {
1012
1145
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1013
1146
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1014
1147
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1015
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1148
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1016
1149
 
1017
1150
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
1018
1151
 
@@ -1076,7 +1209,7 @@ int wsp_ggml_metal_op_soft_max(wsp_ggml_metal_op_t ctx, int idx) {
1076
1209
  WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1077
1210
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1078
1211
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1079
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1212
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1080
1213
 
1081
1214
  float scale;
1082
1215
  float max_bias;
@@ -1164,7 +1297,7 @@ int wsp_ggml_metal_op_ssm_conv(wsp_ggml_metal_op_t ctx, int idx) {
1164
1297
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1165
1298
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1166
1299
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1167
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1300
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1168
1301
 
1169
1302
  wsp_ggml_metal_kargs_ssm_conv args = {
1170
1303
  /*.ne00 =*/ ne00,
@@ -1219,7 +1352,7 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
1219
1352
  WSP_GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
1220
1353
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
1221
1354
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1222
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1355
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1223
1356
 
1224
1357
  const wsp_ggml_tensor * src3 = op->src[3];
1225
1358
  const wsp_ggml_tensor * src4 = op->src[4];
@@ -1305,7 +1438,7 @@ int wsp_ggml_metal_op_rwkv(wsp_ggml_metal_op_t ctx, int idx) {
1305
1438
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1306
1439
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1307
1440
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1308
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1441
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1309
1442
 
1310
1443
  const int64_t B = op->op == WSP_GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
1311
1444
  const int64_t T = op->src[0]->ne[2];
@@ -1346,7 +1479,7 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
1346
1479
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1347
1480
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1348
1481
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1349
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1482
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1350
1483
 
1351
1484
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1352
1485
 
@@ -1419,7 +1552,7 @@ int wsp_ggml_metal_op_pool_2d(wsp_ggml_metal_op_t ctx, int idx) {
1419
1552
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1420
1553
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1421
1554
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1422
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1555
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1423
1556
 
1424
1557
  const int32_t * opts = op->op_params;
1425
1558
  wsp_ggml_op_pool op_pool = (wsp_ggml_op_pool) opts[0];
@@ -1483,7 +1616,7 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
1483
1616
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1484
1617
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1485
1618
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1486
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1619
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1487
1620
 
1488
1621
  WSP_GGML_ASSERT(ne00 == ne10);
1489
1622
 
@@ -1724,7 +1857,7 @@ int wsp_ggml_metal_op_mul_mat_id(wsp_ggml_metal_op_t ctx, int idx) {
1724
1857
  WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1725
1858
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1726
1859
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1727
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1860
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1728
1861
 
1729
1862
  // src2 = ids
1730
1863
  WSP_GGML_ASSERT(op->src[2]->type == WSP_GGML_TYPE_I32);
@@ -1970,7 +2103,9 @@ size_t wsp_ggml_metal_op_flash_attn_ext_extra_pad(const wsp_ggml_tensor * op) {
1970
2103
  const bool has_mask = op->src[3] != nullptr;
1971
2104
 
1972
2105
  if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
1973
- const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
2106
+ // note: always reserve the padding space to avoid graph reallocations
2107
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
2108
+ const bool has_kvpad = true;
1974
2109
 
1975
2110
  if (has_kvpad) {
1976
2111
  res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
@@ -1979,7 +2114,8 @@ size_t wsp_ggml_metal_op_flash_attn_ext_extra_pad(const wsp_ggml_tensor * op) {
1979
2114
  (has_mask ? wsp_ggml_type_size(WSP_GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
1980
2115
  }
1981
2116
  } else {
1982
- const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
2117
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
2118
+ const bool has_kvpad = true;
1983
2119
 
1984
2120
  if (has_kvpad) {
1985
2121
  res += OP_FLASH_ATTN_EXT_NCPSG*(
@@ -2015,9 +2151,10 @@ size_t wsp_ggml_metal_op_flash_attn_ext_extra_blk(const wsp_ggml_tensor * op) {
2015
2151
  const bool is_vec = wsp_ggml_metal_op_flash_attn_ext_use_vec(op);
2016
2152
 
2017
2153
  // this optimization is not useful for the vector kernels
2018
- if (is_vec) {
2019
- return res;
2020
- }
2154
+ // note: always reserve the blk buffer to avoid graph reallocations
2155
+ //if (is_vec) {
2156
+ // return res;
2157
+ //}
2021
2158
 
2022
2159
  const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
2023
2160
  const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
@@ -2044,13 +2181,16 @@ size_t wsp_ggml_metal_op_flash_attn_ext_extra_tmp(const wsp_ggml_tensor * op) {
2044
2181
 
2045
2182
  size_t res = 0;
2046
2183
 
2047
- if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
2184
+ // note: always reserve the temp buffer to avoid graph reallocations
2185
+ //if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
2186
+ if (true) {
2048
2187
  const int64_t nwg = 32;
2188
+ const int64_t ne01_max = std::min(ne01, 32);
2049
2189
 
2050
2190
  // temp buffer for writing the results from each workgroup
2051
2191
  // - ne20: the size of the Value head
2052
2192
  // - + 2: the S and M values for each intermediate result
2053
- res += wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
2193
+ res += wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
2054
2194
  }
2055
2195
 
2056
2196
  return res;
@@ -2179,8 +2319,6 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2179
2319
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2180
2320
 
2181
2321
  need_sync = true;
2182
- } else {
2183
- assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
2184
2322
  }
2185
2323
 
2186
2324
  if (has_mask) {
@@ -2210,8 +2348,6 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2210
2348
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
2211
2349
 
2212
2350
  need_sync = true;
2213
- } else {
2214
- assert(wsp_ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
2215
2351
  }
2216
2352
 
2217
2353
  if (need_sync) {
@@ -2351,8 +2487,6 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2351
2487
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2352
2488
 
2353
2489
  need_sync = true;
2354
- } else {
2355
- assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
2356
2490
  }
2357
2491
 
2358
2492
  if (need_sync) {
@@ -2683,7 +2817,7 @@ int wsp_ggml_metal_op_l2_norm(wsp_ggml_metal_op_t ctx, int idx) {
2683
2817
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2684
2818
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2685
2819
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2686
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2820
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2687
2821
 
2688
2822
  float eps;
2689
2823
  memcpy(&eps, op->op_params, sizeof(float));
@@ -2731,7 +2865,7 @@ int wsp_ggml_metal_op_group_norm(wsp_ggml_metal_op_t ctx, int idx) {
2731
2865
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2732
2866
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2733
2867
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2734
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2868
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2735
2869
 
2736
2870
  const int32_t ngrp = ((const int32_t *) op->op_params)[0];
2737
2871
 
@@ -2786,7 +2920,7 @@ int wsp_ggml_metal_op_norm(wsp_ggml_metal_op_t ctx, int idx) {
2786
2920
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2787
2921
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2788
2922
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2789
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2923
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2790
2924
 
2791
2925
  float eps;
2792
2926
  memcpy(&eps, op->op_params, sizeof(float));
@@ -2922,7 +3056,7 @@ int wsp_ggml_metal_op_rope(wsp_ggml_metal_op_t ctx, int idx) {
2922
3056
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2923
3057
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2924
3058
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2925
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3059
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2926
3060
 
2927
3061
  // make sure we have one or more position id(ne10) per token(ne02)
2928
3062
  WSP_GGML_ASSERT(ne10 % ne02 == 0);
@@ -3016,7 +3150,7 @@ int wsp_ggml_metal_op_im2col(wsp_ggml_metal_op_t ctx, int idx) {
3016
3150
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3017
3151
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3018
3152
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3019
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3153
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3020
3154
 
3021
3155
  const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3022
3156
  const int32_t s1 = ((const int32_t *)(op->op_params))[1];
@@ -3077,6 +3211,84 @@ int wsp_ggml_metal_op_im2col(wsp_ggml_metal_op_t ctx, int idx) {
3077
3211
  return 1;
3078
3212
  }
3079
3213
 
3214
+ int wsp_ggml_metal_op_conv_2d(wsp_ggml_metal_op_t ctx, int idx) {
3215
+ wsp_ggml_tensor * op = ctx->node(idx);
3216
+
3217
+ wsp_ggml_metal_library_t lib = ctx->lib;
3218
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3219
+
3220
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3221
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3222
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3223
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3224
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3225
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3226
+
3227
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
3228
+ WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
3229
+ WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
3230
+ WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
3231
+
3232
+ const int32_t s0 = ((const int32_t *) op->op_params)[0];
3233
+ const int32_t s1 = ((const int32_t *) op->op_params)[1];
3234
+ const int32_t p0 = ((const int32_t *) op->op_params)[2];
3235
+ const int32_t p1 = ((const int32_t *) op->op_params)[3];
3236
+ const int32_t d0 = ((const int32_t *) op->op_params)[4];
3237
+ const int32_t d1 = ((const int32_t *) op->op_params)[5];
3238
+
3239
+ wsp_ggml_metal_kargs_conv_2d args = {
3240
+ /*.nb00 =*/ nb00,
3241
+ /*.nb01 =*/ nb01,
3242
+ /*.nb02 =*/ nb02,
3243
+ /*.nb03 =*/ nb03,
3244
+ /*.nb10 =*/ nb10,
3245
+ /*.nb11 =*/ nb11,
3246
+ /*.nb12 =*/ nb12,
3247
+ /*.nb13 =*/ nb13,
3248
+ /*.nb0 =*/ nb0,
3249
+ /*.nb1 =*/ nb1,
3250
+ /*.nb2 =*/ nb2,
3251
+ /*.nb3 =*/ nb3,
3252
+ /*.IW =*/ ne10,
3253
+ /*.IH =*/ ne11,
3254
+ /*.KW =*/ ne00,
3255
+ /*.KH =*/ ne01,
3256
+ /*.IC =*/ ne02,
3257
+ /*.OC =*/ ne03,
3258
+ /*.OW =*/ ne0,
3259
+ /*.OH =*/ ne1,
3260
+ /*.N =*/ ne3,
3261
+ /*.s0 =*/ s0,
3262
+ /*.s1 =*/ s1,
3263
+ /*.p0 =*/ p0,
3264
+ /*.p1 =*/ p1,
3265
+ /*.d0 =*/ d0,
3266
+ /*.d1 =*/ d1,
3267
+ };
3268
+
3269
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_2d(lib, op);
3270
+
3271
+ int nth = wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
3272
+ nth = std::min(nth, 256);
3273
+ nth = std::max(nth, 1);
3274
+
3275
+ const uint64_t n_out = wsp_ggml_nelements(op);
3276
+
3277
+ uint64_t tg = (n_out + nth - 1)/nth;
3278
+ tg = std::max<uint64_t>(tg, 1);
3279
+ tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
3280
+
3281
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3282
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3283
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3284
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
3285
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
3286
+
3287
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
3288
+
3289
+ return 1;
3290
+ }
3291
+
3080
3292
  int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
3081
3293
  wsp_ggml_tensor * op = ctx->node(idx);
3082
3294
 
@@ -3088,7 +3300,7 @@ int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
3088
3300
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3089
3301
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3090
3302
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3091
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3303
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3092
3304
 
3093
3305
  const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3094
3306
 
@@ -3133,7 +3345,7 @@ int wsp_ggml_metal_op_conv_transpose_2d(wsp_ggml_metal_op_t ctx, int idx) {
3133
3345
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3134
3346
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3135
3347
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3136
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3348
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3137
3349
 
3138
3350
  const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3139
3351
 
@@ -3187,7 +3399,7 @@ int wsp_ggml_metal_op_upscale(wsp_ggml_metal_op_t ctx, int idx) {
3187
3399
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3188
3400
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3189
3401
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3190
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3402
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3191
3403
 
3192
3404
  const float sf0 = (float)ne0/op->src[0]->ne[0];
3193
3405
  const float sf1 = (float)ne1/op->src[0]->ne[1];
@@ -3240,7 +3452,7 @@ int wsp_ggml_metal_op_pad(wsp_ggml_metal_op_t ctx, int idx) {
3240
3452
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3241
3453
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3242
3454
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3243
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3455
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3244
3456
 
3245
3457
  wsp_ggml_metal_kargs_pad args = {
3246
3458
  /*.ne00 =*/ ne00,
@@ -3284,7 +3496,7 @@ int wsp_ggml_metal_op_pad_reflect_1d(wsp_ggml_metal_op_t ctx, int idx) {
3284
3496
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3285
3497
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3286
3498
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3287
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3499
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3288
3500
 
3289
3501
  wsp_ggml_metal_kargs_pad_reflect_1d args = {
3290
3502
  /*.ne00 =*/ ne00,
@@ -3328,7 +3540,7 @@ int wsp_ggml_metal_op_arange(wsp_ggml_metal_op_t ctx, int idx) {
3328
3540
  wsp_ggml_metal_encoder_t enc = ctx->enc;
3329
3541
 
3330
3542
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3331
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3543
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3332
3544
 
3333
3545
  float start;
3334
3546
  float step;
@@ -3346,12 +3558,6 @@ int wsp_ggml_metal_op_arange(wsp_ggml_metal_op_t ctx, int idx) {
3346
3558
 
3347
3559
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_arange(lib, op);
3348
3560
 
3349
- //[encoder setComputePipelineState:pipeline];
3350
- //[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
3351
- //[encoder setBytes:&args length:sizeof(args) atIndex:1];
3352
-
3353
- //[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3354
-
3355
3561
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3356
3562
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3357
3563
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 1);
@@ -3370,7 +3576,7 @@ int wsp_ggml_metal_op_timestep_embedding(wsp_ggml_metal_op_t ctx, int idx) {
3370
3576
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3371
3577
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3372
3578
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3373
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3579
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3374
3580
 
3375
3581
  const int dim = op->op_params[0];
3376
3582
  const int max_period = op->op_params[1];
@@ -3404,7 +3610,7 @@ int wsp_ggml_metal_op_argmax(wsp_ggml_metal_op_t ctx, int idx) {
3404
3610
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3405
3611
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3406
3612
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3407
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3613
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3408
3614
 
3409
3615
  wsp_ggml_metal_kargs_argmax args = {
3410
3616
  /*.ne00 = */ ne00,
@@ -3440,38 +3646,93 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
3440
3646
  wsp_ggml_metal_library_t lib = ctx->lib;
3441
3647
  wsp_ggml_metal_encoder_t enc = ctx->enc;
3442
3648
 
3649
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
3650
+
3443
3651
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3444
3652
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3445
3653
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3446
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3654
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3655
+
3656
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_argsort(lib, op);
3447
3657
 
3448
3658
  // bitonic sort requires the number of elements to be power of 2
3449
- int64_t ne00_padded = 1;
3450
- while (ne00_padded < ne00) {
3451
- ne00_padded *= 2;
3659
+ int nth = 1;
3660
+ while (nth < ne00 && 2*nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3661
+ nth *= 2;
3452
3662
  }
3453
3663
 
3454
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_argsort(lib, op);
3455
-
3456
- const int64_t nrows = wsp_ggml_nrows(op->src[0]);
3664
+ const int npr = (ne00 + nth - 1)/nth;
3457
3665
 
3458
3666
  // Metal kernels require the buffer size to be multiple of 16 bytes
3459
3667
  // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3460
- const size_t smem = WSP_GGML_PAD(ne00_padded*sizeof(int32_t), 16);
3668
+ const size_t smem = WSP_GGML_PAD(nth*sizeof(int32_t), 16);
3669
+
3670
+ wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
3671
+ wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
3672
+
3673
+ wsp_ggml_metal_buffer_id bid_tmp = bid_dst;
3674
+ bid_tmp.offs += wsp_ggml_nbytes(op);
3675
+
3676
+ if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3677
+ std::swap(bid_dst, bid_tmp);
3678
+ }
3461
3679
 
3462
3680
  wsp_ggml_metal_kargs_argsort args = {
3463
- /*.ncols =*/ ne00,
3464
- /*.ncols_pad =*/ ne00_padded
3681
+ /*.ne00 =*/ ne00,
3682
+ /*.ne01 =*/ ne01,
3683
+ /*.ne02 =*/ ne02,
3684
+ /*.ne03 =*/ ne03,
3685
+ /*.nb00 =*/ nb00,
3686
+ /*.nb01 =*/ nb01,
3687
+ /*.nb02 =*/ nb02,
3688
+ /*.nb03 =*/ nb03,
3465
3689
  };
3466
3690
 
3467
3691
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3468
3692
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3469
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3470
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
3693
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3694
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3471
3695
 
3472
3696
  wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3473
3697
 
3474
- wsp_ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1);
3698
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3699
+
3700
+ wsp_ggml_metal_pipeline_t pipeline_merge = wsp_ggml_metal_library_get_pipeline_argsort_merge(lib, op);
3701
+
3702
+ int len = nth;
3703
+
3704
+ while (len < ne00) {
3705
+ wsp_ggml_metal_op_concurrency_reset(ctx);
3706
+
3707
+ wsp_ggml_metal_kargs_argsort_merge args_merge = {
3708
+ .ne00 = ne00,
3709
+ .ne01 = ne01,
3710
+ .ne02 = ne02,
3711
+ .ne03 = ne03,
3712
+ .nb00 = nb00,
3713
+ .nb01 = nb01,
3714
+ .nb02 = nb02,
3715
+ .nb03 = nb03,
3716
+ .len = len,
3717
+ };
3718
+
3719
+ // merges per row
3720
+ const int nm = (ne00 + 2*len - 1) / (2*len);
3721
+
3722
+ const int nth = std::min(512, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
3723
+
3724
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
3725
+ wsp_ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
3726
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3727
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3728
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
3729
+
3730
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
3731
+
3732
+ std::swap(bid_dst, bid_tmp);
3733
+
3734
+ len <<= 1;
3735
+ }
3475
3736
 
3476
3737
  return 1;
3477
3738
  }
@@ -3485,7 +3746,7 @@ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
3485
3746
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3486
3747
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3487
3748
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3488
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3749
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3489
3750
 
3490
3751
  float slope;
3491
3752
  memcpy(&slope, op->op_params, sizeof(float));
@@ -3521,7 +3782,7 @@ int wsp_ggml_metal_op_opt_step_adamw(wsp_ggml_metal_op_t ctx, int idx) {
3521
3782
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3522
3783
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3523
3784
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3524
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3785
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3525
3786
 
3526
3787
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
3527
3788
 
@@ -3557,7 +3818,7 @@ int wsp_ggml_metal_op_opt_step_sgd(wsp_ggml_metal_op_t ctx, int idx) {
3557
3818
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3558
3819
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3559
3820
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3560
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3821
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3561
3822
 
3562
3823
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
3563
3824