whisper.rn 0.5.1 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. package/android/src/main/jni.cpp +12 -3
  2. package/cpp/ggml-alloc.c +49 -18
  3. package/cpp/ggml-backend-impl.h +0 -3
  4. package/cpp/ggml-backend-reg.cpp +8 -0
  5. package/cpp/ggml-backend.cpp +0 -2
  6. package/cpp/ggml-backend.h +2 -0
  7. package/cpp/ggml-cpu/amx/amx.cpp +1 -0
  8. package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
  9. package/cpp/ggml-cpu/ggml-cpu-impl.h +4 -2
  10. package/cpp/ggml-cpu/ggml-cpu.c +67 -24
  11. package/cpp/ggml-cpu/ops.cpp +489 -364
  12. package/cpp/ggml-cpu/ops.h +4 -4
  13. package/cpp/ggml-cpu/repack.cpp +143 -29
  14. package/cpp/ggml-cpu/simd-mappings.h +25 -25
  15. package/cpp/ggml-cpu/unary-ops.cpp +151 -0
  16. package/cpp/ggml-cpu/unary-ops.h +7 -0
  17. package/cpp/ggml-cpu/vec.cpp +83 -0
  18. package/cpp/ggml-cpu/vec.h +20 -8
  19. package/cpp/ggml-impl.h +67 -2
  20. package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
  21. package/cpp/ggml-metal/ggml-metal-context.m +5 -6
  22. package/cpp/ggml-metal/ggml-metal-device.cpp +300 -14
  23. package/cpp/ggml-metal/ggml-metal-device.h +26 -1
  24. package/cpp/ggml-metal/ggml-metal-device.m +243 -28
  25. package/cpp/ggml-metal/ggml-metal-impl.h +177 -9
  26. package/cpp/ggml-metal/ggml-metal-ops.cpp +843 -157
  27. package/cpp/ggml-metal/ggml-metal-ops.h +8 -0
  28. package/cpp/ggml-metal/ggml-metal.cpp +8 -3
  29. package/cpp/ggml-metal/ggml-metal.metal +12436 -0
  30. package/cpp/ggml.c +317 -4
  31. package/cpp/ggml.h +139 -0
  32. package/cpp/jsi/RNWhisperJSI.cpp +7 -2
  33. package/cpp/rn-whisper.h +1 -0
  34. package/cpp/whisper.cpp +8 -2
  35. package/ios/RNWhisperContext.mm +3 -1
  36. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  37. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  38. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  39. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  45. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  46. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  47. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  53. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  54. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  55. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  62. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  64. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  70. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  71. package/lib/commonjs/version.json +1 -1
  72. package/lib/module/NativeRNWhisper.js.map +1 -1
  73. package/lib/module/version.json +1 -1
  74. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  75. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  76. package/package.json +1 -1
  77. package/src/NativeRNWhisper.ts +2 -0
  78. package/src/version.json +1 -1
  79. package/whisper-rn.podspec +1 -1
  80. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  81. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  82. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  83. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  84. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  85. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
@@ -10,6 +10,8 @@
10
10
 
11
11
  #include <cassert>
12
12
  #include <algorithm>
13
+ #include <limits>
14
+ #include <cmath>
13
15
 
14
16
  static wsp_ggml_metal_buffer_id wsp_ggml_metal_get_buffer_id(const wsp_ggml_tensor * t) {
15
17
  if (!t) {
@@ -226,6 +228,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
226
228
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
227
229
  WSP_GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
228
230
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
231
+ WSP_GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
232
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
233
+ WSP_GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
234
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
229
235
  WSP_GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
230
236
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
231
237
 
@@ -237,6 +243,14 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
237
243
  WSP_GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
238
244
  wsp_ggml_is_contiguous(node->src[1]), node->src[1]->name);
239
245
  }
246
+ if (node->src[2]) {
247
+ WSP_GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
248
+ wsp_ggml_is_contiguous(node->src[2]), node->src[2]->name);
249
+ }
250
+ if (node->src[3]) {
251
+ WSP_GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
252
+ wsp_ggml_is_contiguous(node->src[3]), node->src[3]->name);
253
+ }
240
254
  if (node) {
241
255
  WSP_GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
242
256
  node->name);
@@ -289,11 +303,19 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
289
303
  {
290
304
  n_fuse = wsp_ggml_metal_op_glu(ctx, idx);
291
305
  } break;
306
+ case WSP_GGML_OP_SUM:
307
+ {
308
+ n_fuse = wsp_ggml_metal_op_sum(ctx, idx);
309
+ } break;
292
310
  case WSP_GGML_OP_SUM_ROWS:
293
311
  case WSP_GGML_OP_MEAN:
294
312
  {
295
313
  n_fuse = wsp_ggml_metal_op_sum_rows(ctx, idx);
296
314
  } break;
315
+ case WSP_GGML_OP_CUMSUM:
316
+ {
317
+ n_fuse = wsp_ggml_metal_op_cumsum(ctx, idx);
318
+ } break;
297
319
  case WSP_GGML_OP_SOFT_MAX:
298
320
  {
299
321
  n_fuse = wsp_ggml_metal_op_soft_max(ctx, idx);
@@ -348,10 +370,18 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
348
370
  {
349
371
  n_fuse = wsp_ggml_metal_op_im2col(ctx, idx);
350
372
  } break;
373
+ case WSP_GGML_OP_CONV_2D:
374
+ {
375
+ n_fuse = wsp_ggml_metal_op_conv_2d(ctx, idx);
376
+ } break;
351
377
  case WSP_GGML_OP_CONV_TRANSPOSE_1D:
352
378
  {
353
379
  n_fuse = wsp_ggml_metal_op_conv_transpose_1d(ctx, idx);
354
380
  } break;
381
+ case WSP_GGML_OP_CONV_TRANSPOSE_2D:
382
+ {
383
+ n_fuse = wsp_ggml_metal_op_conv_transpose_2d(ctx, idx);
384
+ } break;
355
385
  case WSP_GGML_OP_UPSCALE:
356
386
  {
357
387
  n_fuse = wsp_ggml_metal_op_upscale(ctx, idx);
@@ -398,6 +428,14 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
398
428
  {
399
429
  n_fuse = wsp_ggml_metal_op_argmax(ctx, idx);
400
430
  } break;
431
+ case WSP_GGML_OP_OPT_STEP_ADAMW:
432
+ {
433
+ n_fuse = wsp_ggml_metal_op_opt_step_adamw(ctx, idx);
434
+ } break;
435
+ case WSP_GGML_OP_OPT_STEP_SGD:
436
+ {
437
+ n_fuse = wsp_ggml_metal_op_opt_step_sgd(ctx, idx);
438
+ } break;
401
439
  default:
402
440
  {
403
441
  WSP_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, wsp_ggml_op_name(node->op));
@@ -506,7 +544,7 @@ int wsp_ggml_metal_op_repeat(wsp_ggml_metal_op_t ctx, int idx) {
506
544
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
507
545
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
508
546
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
509
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
547
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
510
548
 
511
549
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_repeat(lib, op->type);
512
550
 
@@ -552,7 +590,7 @@ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
552
590
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
553
591
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
554
592
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
555
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
593
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
556
594
 
557
595
  WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
558
596
  WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
@@ -577,6 +615,7 @@ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
577
615
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
578
616
 
579
617
  wsp_ggml_metal_kargs_cpy args = {
618
+ /*.nk0 =*/ ne00,
580
619
  /*.ne00 =*/ ne00,
581
620
  /*.ne01 =*/ ne01,
582
621
  /*.ne02 =*/ ne02,
@@ -660,7 +699,7 @@ int wsp_ggml_metal_op_scale(wsp_ggml_metal_op_t ctx, int idx) {
660
699
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
661
700
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
662
701
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
663
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
702
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
664
703
 
665
704
  float scale;
666
705
  float bias;
@@ -699,7 +738,7 @@ int wsp_ggml_metal_op_clamp(wsp_ggml_metal_op_t ctx, int idx) {
699
738
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
700
739
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
701
740
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
702
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
741
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
703
742
 
704
743
  float min;
705
744
  float max;
@@ -738,7 +777,7 @@ int wsp_ggml_metal_op_unary(wsp_ggml_metal_op_t ctx, int idx) {
738
777
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
739
778
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
740
779
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
741
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
780
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
742
781
 
743
782
  int64_t n = wsp_ggml_nelements(op);
744
783
 
@@ -768,7 +807,7 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
768
807
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
769
808
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
770
809
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
771
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
810
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
772
811
 
773
812
  if (op->src[1]) {
774
813
  WSP_GGML_ASSERT(wsp_ggml_are_same_shape(op->src[0], op->src[1]));
@@ -800,18 +839,6 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
800
839
 
801
840
  const int32_t nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
802
841
 
803
- //[encoder setComputePipelineState:pipeline];
804
- //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
805
- //if (src1) {
806
- // [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
807
- //} else {
808
- // [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
809
- //}
810
- //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
811
- //[encoder setBytes:&args length:sizeof(args) atIndex:3];
812
-
813
- //[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
814
-
815
842
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
816
843
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
817
844
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -827,6 +854,43 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
827
854
  return 1;
828
855
  }
829
856
 
857
+ int wsp_ggml_metal_op_sum(wsp_ggml_metal_op_t ctx, int idx) {
858
+ wsp_ggml_tensor * op = ctx->node(idx);
859
+
860
+ wsp_ggml_metal_library_t lib = ctx->lib;
861
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
862
+
863
+ const uint64_t n = (uint64_t) wsp_ggml_nelements(op->src[0]);
864
+
865
+ wsp_ggml_metal_kargs_sum args = {
866
+ /*.np =*/ n,
867
+ };
868
+
869
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_sum(lib, op);
870
+
871
+ int nth = 32; // SIMD width
872
+
873
+ while (nth < (int) n && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
874
+ nth *= 2;
875
+ }
876
+
877
+ nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
878
+ nth = std::min(nth, (int) n);
879
+
880
+ const int nsg = (nth + 31) / 32;
881
+
882
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
883
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
884
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
885
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
886
+
887
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
888
+
889
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
890
+
891
+ return 1;
892
+ }
893
+
830
894
  int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
831
895
  wsp_ggml_tensor * op = ctx->node(idx);
832
896
 
@@ -836,7 +900,7 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
836
900
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
837
901
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
838
902
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
839
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
903
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
840
904
 
841
905
  wsp_ggml_metal_kargs_sum_rows args = {
842
906
  /*.ne00 =*/ ne00,
@@ -870,14 +934,6 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
870
934
 
871
935
  const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
872
936
 
873
- //[encoder setComputePipelineState:pipeline];
874
- //[encoder setBytes:&args length:sizeof(args) atIndex:0];
875
- //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
876
- //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
877
- //[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
878
-
879
- //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
880
-
881
937
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
882
938
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
883
939
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -890,6 +946,149 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
890
946
  return 1;
891
947
  }
892
948
 
949
+ int wsp_ggml_metal_op_cumsum(wsp_ggml_metal_op_t ctx, int idx) {
950
+ wsp_ggml_tensor * op = ctx->node(idx);
951
+
952
+ wsp_ggml_metal_library_t lib = ctx->lib;
953
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
954
+
955
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
956
+
957
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
958
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
959
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
960
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
961
+
962
+ wsp_ggml_metal_pipeline_t pipeline_blk = wsp_ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
963
+
964
+ int nth = 1;
965
+ while (nth < ne00 && 2*nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
966
+ nth *= 2;
967
+ }
968
+
969
+ WSP_GGML_ASSERT(ne00 <= nth*nth);
970
+
971
+ const int64_t net0 = (ne00 + nth - 1) / nth;
972
+ const int64_t net1 = ne01;
973
+ const int64_t net2 = ne02;
974
+ const int64_t net3 = ne03;
975
+
976
+ const uint64_t nbt0 = sizeof(float);
977
+ const uint64_t nbt1 = net0*nbt0;
978
+ const uint64_t nbt2 = net1*nbt1;
979
+ const uint64_t nbt3 = net2*nbt2;
980
+
981
+ const size_t smem = WSP_GGML_PAD(32*sizeof(float), 16);
982
+
983
+ wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
984
+ wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
985
+
986
+ wsp_ggml_metal_buffer_id bid_tmp = bid_dst;
987
+ bid_tmp.offs += wsp_ggml_nbytes(op);
988
+
989
+ {
990
+ wsp_ggml_metal_kargs_cumsum_blk args = {
991
+ /*.ne00 =*/ ne00,
992
+ /*.ne01 =*/ ne01,
993
+ /*.ne02 =*/ ne02,
994
+ /*.ne03 =*/ ne03,
995
+ /*.nb00 =*/ nb00,
996
+ /*.nb01 =*/ nb01,
997
+ /*.nb02 =*/ nb02,
998
+ /*.nb03 =*/ nb03,
999
+ /*.net0 =*/ net0,
1000
+ /*.net1 =*/ net1,
1001
+ /*.net2 =*/ net2,
1002
+ /*.net3 =*/ net3,
1003
+ /*.nbt0 =*/ nbt0,
1004
+ /*.nbt1 =*/ nbt1,
1005
+ /*.nbt2 =*/ nbt2,
1006
+ /*.nbt3 =*/ nbt3,
1007
+ /*.outb =*/ ne00 > nth,
1008
+ };
1009
+
1010
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1011
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1012
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
1013
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1014
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
1015
+
1016
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1017
+
1018
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1019
+ }
1020
+
1021
+ if (ne00 > nth) {
1022
+ wsp_ggml_metal_op_concurrency_reset(ctx);
1023
+
1024
+ {
1025
+ wsp_ggml_metal_kargs_cumsum_blk args = {
1026
+ /*.ne00 =*/ net0,
1027
+ /*.ne01 =*/ net1,
1028
+ /*.ne02 =*/ net2,
1029
+ /*.ne03 =*/ net3,
1030
+ /*.nb00 =*/ nbt0,
1031
+ /*.nb01 =*/ nbt1,
1032
+ /*.nb02 =*/ nbt2,
1033
+ /*.nb03 =*/ nbt3,
1034
+ /*.net0 =*/ net0,
1035
+ /*.net1 =*/ net1,
1036
+ /*.net2 =*/ net2,
1037
+ /*.net3 =*/ net3,
1038
+ /*.nbt0 =*/ nbt0,
1039
+ /*.nbt1 =*/ nbt1,
1040
+ /*.nbt2 =*/ nbt2,
1041
+ /*.nbt3 =*/ nbt3,
1042
+ /*.outb =*/ false,
1043
+ };
1044
+
1045
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1046
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1047
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1048
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1049
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
1050
+
1051
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1052
+
1053
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
1054
+ }
1055
+
1056
+ wsp_ggml_metal_op_concurrency_reset(ctx);
1057
+
1058
+ {
1059
+ wsp_ggml_metal_pipeline_t pipeline_add = wsp_ggml_metal_library_get_pipeline_cumsum_add(lib, op);
1060
+
1061
+ wsp_ggml_metal_kargs_cumsum_add args = {
1062
+ /*.ne00 =*/ ne00,
1063
+ /*.ne01 =*/ ne01,
1064
+ /*.ne02 =*/ ne02,
1065
+ /*.ne03 =*/ ne03,
1066
+ /*.nb00 =*/ nb00,
1067
+ /*.nb01 =*/ nb01,
1068
+ /*.nb02 =*/ nb02,
1069
+ /*.nb03 =*/ nb03,
1070
+ /*.net0 =*/ net0,
1071
+ /*.net1 =*/ net1,
1072
+ /*.net2 =*/ net2,
1073
+ /*.net3 =*/ net3,
1074
+ /*.nbt0 =*/ nbt0,
1075
+ /*.nbt1 =*/ nbt1,
1076
+ /*.nbt2 =*/ nbt2,
1077
+ /*.nbt3 =*/ nbt3,
1078
+ };
1079
+
1080
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_add);
1081
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1082
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1083
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1084
+
1085
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1086
+ }
1087
+ }
1088
+
1089
+ return 1;
1090
+ }
1091
+
893
1092
  int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
894
1093
  wsp_ggml_tensor * op = ctx->node(idx);
895
1094
 
@@ -901,28 +1100,36 @@ int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
901
1100
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
902
1101
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
903
1102
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
904
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1103
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
905
1104
 
906
1105
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
907
1106
 
908
1107
  wsp_ggml_metal_kargs_get_rows args = {
909
- /*.ne00 =*/ ne00,
910
- /*.nb01 =*/ nb01,
911
- /*.nb02 =*/ nb02,
912
- /*.ne10 =*/ ne10,
913
- /*.nb10 =*/ nb10,
914
- /*.nb11 =*/ nb11,
915
- /*.nb1 =*/ nb1,
916
- /*.nb2 =*/ nb2,
1108
+ /*.ne00t =*/ wsp_ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
1109
+ /*.ne00 =*/ ne00,
1110
+ /*.nb01 =*/ nb01,
1111
+ /*.nb02 =*/ nb02,
1112
+ /*.nb03 =*/ nb03,
1113
+ /*.ne10 =*/ ne10,
1114
+ /*.nb10 =*/ nb10,
1115
+ /*.nb11 =*/ nb11,
1116
+ /*.nb12 =*/ nb12,
1117
+ /*.nb1 =*/ nb1,
1118
+ /*.nb2 =*/ nb2,
1119
+ /*.nb3 =*/ nb3,
917
1120
  };
918
1121
 
1122
+ const int nth = std::min(args.ne00t, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1123
+
1124
+ const int nw0 = (args.ne00t + nth - 1)/nth;
1125
+
919
1126
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
920
1127
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
921
1128
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
922
1129
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
923
1130
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
924
1131
 
925
- wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
1132
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
926
1133
 
927
1134
  return 1;
928
1135
  }
@@ -938,7 +1145,7 @@ int wsp_ggml_metal_op_set_rows(wsp_ggml_metal_op_t ctx, int idx) {
938
1145
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
939
1146
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
940
1147
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
941
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1148
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
942
1149
 
943
1150
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
944
1151
 
@@ -1002,7 +1209,7 @@ int wsp_ggml_metal_op_soft_max(wsp_ggml_metal_op_t ctx, int idx) {
1002
1209
  WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1003
1210
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1004
1211
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1005
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1212
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1006
1213
 
1007
1214
  float scale;
1008
1215
  float max_bias;
@@ -1090,7 +1297,7 @@ int wsp_ggml_metal_op_ssm_conv(wsp_ggml_metal_op_t ctx, int idx) {
1090
1297
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1091
1298
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1092
1299
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1093
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1300
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1094
1301
 
1095
1302
  wsp_ggml_metal_kargs_ssm_conv args = {
1096
1303
  /*.ne00 =*/ ne00,
@@ -1117,7 +1324,7 @@ int wsp_ggml_metal_op_ssm_conv(wsp_ggml_metal_op_t ctx, int idx) {
1117
1324
  wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1118
1325
  wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1119
1326
  wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1120
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
1327
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
1121
1328
 
1122
1329
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1123
1330
 
@@ -1145,7 +1352,7 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
1145
1352
  WSP_GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
1146
1353
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
1147
1354
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1148
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1355
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1149
1356
 
1150
1357
  const wsp_ggml_tensor * src3 = op->src[3];
1151
1358
  const wsp_ggml_tensor * src4 = op->src[4];
@@ -1172,25 +1379,36 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
1172
1379
  /*.n_seq_tokens =*/ n_seq_tokens,
1173
1380
  /*.n_seqs =*/ n_seqs,
1174
1381
  /*.s_off =*/ wsp_ggml_nelements(op->src[1]) * sizeof(float),
1382
+ /*.nb00 =*/ nb00,
1175
1383
  /*.nb01 =*/ nb01,
1176
1384
  /*.nb02 =*/ nb02,
1177
1385
  /*.nb03 =*/ nb03,
1386
+ /*.nb10 =*/ nb10,
1178
1387
  /*.nb11 =*/ nb11,
1179
1388
  /*.nb12 =*/ nb12,
1389
+ /*.ns12 =*/ nb12/nb10,
1180
1390
  /*.nb13 =*/ nb13,
1391
+ /*.nb20 =*/ nb20,
1181
1392
  /*.nb21 =*/ nb21,
1393
+ /*.ns21 =*/ nb21/nb20,
1182
1394
  /*.nb22 =*/ nb22,
1395
+ /*.ne30 =*/ ne30,
1183
1396
  /*.nb31 =*/ nb31,
1184
1397
  /*.nb41 =*/ nb41,
1185
1398
  /*.nb42 =*/ nb42,
1399
+ /*.ns42 =*/ nb42/nb40,
1186
1400
  /*.nb43 =*/ nb43,
1187
1401
  /*.nb51 =*/ nb51,
1188
1402
  /*.nb52 =*/ nb52,
1403
+ /*.ns52 =*/ nb52/nb50,
1189
1404
  /*.nb53 =*/ nb53,
1405
+ /*.nb0 =*/ nb0,
1190
1406
  };
1191
1407
 
1192
1408
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_ssm_scan(lib, op);
1193
1409
 
1410
+ WSP_GGML_ASSERT(d_state <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1411
+
1194
1412
  const size_t sms = wsp_ggml_metal_pipeline_get_smem(pipeline);
1195
1413
 
1196
1414
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -1206,13 +1424,7 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
1206
1424
 
1207
1425
  wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
1208
1426
 
1209
- if (ne30 == 1) {
1210
- // Mamba-2
1211
- wsp_ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1212
- } else {
1213
- WSP_GGML_ASSERT(d_inner == 1);
1214
- wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
1215
- }
1427
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1216
1428
 
1217
1429
  return 1;
1218
1430
  }
@@ -1226,7 +1438,7 @@ int wsp_ggml_metal_op_rwkv(wsp_ggml_metal_op_t ctx, int idx) {
1226
1438
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1227
1439
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1228
1440
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1229
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1441
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1230
1442
 
1231
1443
  const int64_t B = op->op == WSP_GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
1232
1444
  const int64_t T = op->src[0]->ne[2];
@@ -1267,32 +1479,29 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
1267
1479
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1268
1480
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1269
1481
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1270
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1482
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1271
1483
 
1272
1484
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1273
1485
 
1274
1486
  WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(op->src[0]->type) == 0);
1275
1487
 
1276
- // TODO: support
1277
- //const int32_t nk00 = ne00/wsp_ggml_blck_size(op->type);
1278
- const int32_t nk00 = ne00;
1279
-
1280
- int nth = 32; // SIMD width
1281
-
1282
- while (nth < nk00 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1283
- nth *= 2;
1488
+ int64_t nk0 = ne00;
1489
+ if (wsp_ggml_is_quantized(op->src[0]->type)) {
1490
+ nk0 = ne00/16;
1491
+ } else if (wsp_ggml_is_quantized(op->type)) {
1492
+ nk0 = ne00/wsp_ggml_blck_size(op->type);
1284
1493
  }
1285
1494
 
1286
- nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1495
+ int nth = std::min<int>(nk0, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1287
1496
 
1288
1497
  // when rows are small, we can batch them together in a single threadgroup
1289
1498
  int nrptg = 1;
1290
1499
 
1291
1500
  // TODO: relax this constraint in the future
1292
1501
  if (wsp_ggml_blck_size(op->src[0]->type) == 1 && wsp_ggml_blck_size(op->type) == 1) {
1293
- if (nth > nk00) {
1294
- nrptg = (nth + nk00 - 1)/nk00;
1295
- nth = nk00;
1502
+ if (nth > nk0) {
1503
+ nrptg = (nth + nk0 - 1)/nk0;
1504
+ nth = nk0;
1296
1505
 
1297
1506
  if (nrptg*nth > wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1298
1507
  nrptg--;
@@ -1300,10 +1509,11 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
1300
1509
  }
1301
1510
  }
1302
1511
 
1303
- nth = std::min(nth, nk00);
1512
+ nth = std::min<int>(nth, nk0);
1304
1513
 
1305
1514
  wsp_ggml_metal_kargs_cpy args = {
1306
- /*.ne00 =*/ nk00,
1515
+ /*.nk0 =*/ nk0,
1516
+ /*.ne00 =*/ ne00,
1307
1517
  /*.ne01 =*/ ne01,
1308
1518
  /*.ne02 =*/ ne02,
1309
1519
  /*.ne03 =*/ ne03,
@@ -1321,12 +1531,14 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
1321
1531
  /*.nb3 =*/ nb3,
1322
1532
  };
1323
1533
 
1534
+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
1535
+
1324
1536
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1325
1537
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1326
1538
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1327
1539
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
1328
1540
 
1329
- wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
1541
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1330
1542
 
1331
1543
  return 1;
1332
1544
  }
@@ -1340,7 +1552,7 @@ int wsp_ggml_metal_op_pool_2d(wsp_ggml_metal_op_t ctx, int idx) {
1340
1552
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1341
1553
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1342
1554
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1343
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1555
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1344
1556
 
1345
1557
  const int32_t * opts = op->op_params;
1346
1558
  wsp_ggml_op_pool op_pool = (wsp_ggml_op_pool) opts[0];
@@ -1404,7 +1616,7 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
1404
1616
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1405
1617
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1406
1618
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1407
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1619
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1408
1620
 
1409
1621
  WSP_GGML_ASSERT(ne00 == ne10);
1410
1622
 
@@ -1520,9 +1732,8 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
1520
1732
  !wsp_ggml_is_transposed(op->src[1]) &&
1521
1733
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1522
1734
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1523
- props_dev->has_simdgroup_mm && ne00 >= 64 &&
1524
- (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
1525
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1735
+ props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
1736
+ //WSP_GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1526
1737
 
1527
1738
  // some Metal matrix data types require aligned pointers
1528
1739
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1646,7 +1857,7 @@ int wsp_ggml_metal_op_mul_mat_id(wsp_ggml_metal_op_t ctx, int idx) {
1646
1857
  WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1647
1858
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1648
1859
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1649
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1860
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1650
1861
 
1651
1862
  // src2 = ids
1652
1863
  WSP_GGML_ASSERT(op->src[2]->type == WSP_GGML_TYPE_I32);
@@ -1875,20 +2086,114 @@ bool wsp_ggml_metal_op_flash_attn_ext_use_vec(const wsp_ggml_tensor * op) {
1875
2086
  return (ne01 < 20) && (ne00 % 32 == 0);
1876
2087
  }
1877
2088
 
2089
+ size_t wsp_ggml_metal_op_flash_attn_ext_extra_pad(const wsp_ggml_tensor * op) {
2090
+ assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
2091
+
2092
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2093
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2094
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2095
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2096
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2097
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2098
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2099
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2100
+
2101
+ size_t res = 0;
2102
+
2103
+ const bool has_mask = op->src[3] != nullptr;
2104
+
2105
+ if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
2106
+ // note: always reserve the padding space to avoid graph reallocations
2107
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
2108
+ const bool has_kvpad = true;
2109
+
2110
+ if (has_kvpad) {
2111
+ res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
2112
+ nb11*ne12*ne13 +
2113
+ nb21*ne22*ne23 +
2114
+ (has_mask ? wsp_ggml_type_size(WSP_GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
2115
+ }
2116
+ } else {
2117
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
2118
+ const bool has_kvpad = true;
2119
+
2120
+ if (has_kvpad) {
2121
+ res += OP_FLASH_ATTN_EXT_NCPSG*(
2122
+ nb11*ne12*ne13 +
2123
+ nb21*ne22*ne23 +
2124
+ (has_mask ? wsp_ggml_type_size(WSP_GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
2125
+ }
2126
+ }
2127
+
2128
+ return res;
2129
+ }
2130
+
2131
+ size_t wsp_ggml_metal_op_flash_attn_ext_extra_blk(const wsp_ggml_tensor * op) {
2132
+ assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
2133
+
2134
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2135
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2136
+ //WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2137
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2138
+ //WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2139
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2140
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2141
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2142
+
2143
+ size_t res = 0;
2144
+
2145
+ const bool has_mask = op->src[3] != nullptr;
2146
+
2147
+ if (!has_mask) {
2148
+ return res;
2149
+ }
2150
+
2151
+ const bool is_vec = wsp_ggml_metal_op_flash_attn_ext_use_vec(op);
2152
+
2153
+ // this optimization is not useful for the vector kernels
2154
+ // note: always reserve the blk buffer to avoid graph reallocations
2155
+ //if (is_vec) {
2156
+ // return res;
2157
+ //}
2158
+
2159
+ const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
2160
+ const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
2161
+
2162
+ const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
2163
+ const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
2164
+
2165
+ res += WSP_GGML_PAD(wsp_ggml_type_size(WSP_GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
2166
+
2167
+ return res;
2168
+ }
2169
+
1878
2170
  size_t wsp_ggml_metal_op_flash_attn_ext_extra_tmp(const wsp_ggml_tensor * op) {
1879
2171
  assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
1880
2172
 
1881
- const int64_t nwg = 32;
2173
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2174
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2175
+ //WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2176
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2177
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2178
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2179
+ //WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2180
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2181
+
2182
+ size_t res = 0;
2183
+
2184
+ // note: always reserve the temp buffer to avoid graph reallocations
2185
+ //if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
2186
+ if (true) {
2187
+ const int64_t nwg = 32;
2188
+ const int64_t ne01_max = std::min(ne01, 32);
1882
2189
 
1883
- const int64_t ne01 = op->src[0]->ne[1];
1884
- const int64_t ne02 = op->src[0]->ne[2];
1885
- const int64_t ne03 = op->src[0]->ne[3];
1886
- const int64_t ne20 = op->src[2]->ne[0];
2190
+ // temp buffer for writing the results from each workgroup
2191
+ // - ne20: the size of the Value head
2192
+ // - + 2: the S and M values for each intermediate result
2193
+ res += wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
2194
+ }
1887
2195
 
1888
- // temp buffer for writing the results from each workgroup
1889
- // - ne20: the size of the Value head
1890
- // - + 2: the S and M values for each intermediate result
1891
- return wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
2196
+ return res;
1892
2197
  }
1893
2198
 
1894
2199
  int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
@@ -1910,8 +2215,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
1910
2215
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1911
2216
  WSP_GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
1912
2217
 
1913
- WSP_GGML_ASSERT(ne00 % 4 == 0);
1914
- WSP_GGML_ASSERT(ne11 % 32 == 0);
2218
+ WSP_GGML_ASSERT(ne00 % 4 == 0);
1915
2219
 
1916
2220
  WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
1917
2221
  WSP_GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@@ -1921,8 +2225,8 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
1921
2225
  WSP_GGML_ASSERT(ne12 == ne22);
1922
2226
 
1923
2227
  WSP_GGML_ASSERT(!op->src[3] || op->src[3]->type == WSP_GGML_TYPE_F16);
1924
- WSP_GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= WSP_GGML_PAD(op->src[0]->ne[1], 8) &&
1925
- "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2228
+ WSP_GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
2229
+ "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
1926
2230
 
1927
2231
  float scale;
1928
2232
  float max_bias;
@@ -1949,15 +2253,107 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
1949
2253
 
1950
2254
  WSP_GGML_ASSERT(ne01 < 65536);
1951
2255
 
2256
+ wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
2257
+ wsp_ggml_metal_buffer_id bid_src1 = wsp_ggml_metal_get_buffer_id(op->src[1]);
2258
+ wsp_ggml_metal_buffer_id bid_src2 = wsp_ggml_metal_get_buffer_id(op->src[2]);
2259
+ wsp_ggml_metal_buffer_id bid_src3 = has_mask ? wsp_ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
2260
+ wsp_ggml_metal_buffer_id bid_src4 = has_sinks ? wsp_ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
2261
+
2262
+ wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
2263
+
2264
+ wsp_ggml_metal_buffer_id bid_pad = bid_dst;
2265
+ bid_pad.offs += wsp_ggml_nbytes(op);
2266
+
2267
+ wsp_ggml_metal_buffer_id bid_blk = bid_pad;
2268
+ bid_blk.offs += wsp_ggml_metal_op_flash_attn_ext_extra_pad(op);
2269
+
2270
+ wsp_ggml_metal_buffer_id bid_tmp = bid_blk;
2271
+ bid_tmp.offs += wsp_ggml_metal_op_flash_attn_ext_extra_blk(op);
2272
+
1952
2273
  if (!wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
1953
2274
  // half8x8 kernel
1954
- const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
1955
- const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
2275
+ const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
2276
+ const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
1956
2277
 
1957
2278
  WSP_GGML_ASSERT(nqptg <= 32);
1958
2279
  WSP_GGML_ASSERT(nqptg % 8 == 0);
1959
2280
  WSP_GGML_ASSERT(ncpsg % 32 == 0);
1960
2281
 
2282
+ bool need_sync = false;
2283
+
2284
+ const bool has_kvpad = ne11 % ncpsg != 0;
2285
+
2286
+ if (has_kvpad) {
2287
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2288
+
2289
+ wsp_ggml_metal_kargs_flash_attn_ext_pad args0 = {
2290
+ /*.ne11 =*/ne11,
2291
+ /*.ne_12_2 =*/ne12,
2292
+ /*.ne_12_3 =*/ne13,
2293
+ /*.nb11 =*/nb11,
2294
+ /*.nb12 =*/nb12,
2295
+ /*.nb13 =*/nb13,
2296
+ /*.nb21 =*/nb21,
2297
+ /*.nb22 =*/nb22,
2298
+ /*.nb23 =*/nb23,
2299
+ /*.ne31 =*/ne31,
2300
+ /*.ne32 =*/ne32,
2301
+ /*.ne33 =*/ne33,
2302
+ /*.nb31 =*/nb31,
2303
+ /*.nb32 =*/nb32,
2304
+ /*.nb33 =*/nb33,
2305
+ };
2306
+
2307
+ wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2308
+
2309
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2310
+ wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2311
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2312
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2313
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2314
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2315
+
2316
+ assert(ne12 == ne22);
2317
+ assert(ne13 == ne23);
2318
+
2319
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2320
+
2321
+ need_sync = true;
2322
+ }
2323
+
2324
+ if (has_mask) {
2325
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
2326
+
2327
+ wsp_ggml_metal_kargs_flash_attn_ext_blk args0 = {
2328
+ /*.ne01 =*/ ne01,
2329
+ /*.ne30 =*/ ne30,
2330
+ /*.ne31 =*/ ne31,
2331
+ /*.ne32 =*/ ne32,
2332
+ /*.ne33 =*/ ne33,
2333
+ /*.nb31 =*/ nb31,
2334
+ /*.nb32 =*/ nb32,
2335
+ /*.nb33 =*/ nb33,
2336
+ };
2337
+
2338
+ wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
2339
+
2340
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2341
+ wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2342
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
2343
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
2344
+
2345
+ const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
2346
+ const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
2347
+
2348
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
2349
+
2350
+ need_sync = true;
2351
+ }
2352
+
2353
+ if (need_sync) {
2354
+ wsp_ggml_metal_op_concurrency_reset(ctx);
2355
+ }
2356
+
1961
2357
  const int is_q = wsp_ggml_is_quantized(op->src[1]->type) ? 1 : 0;
1962
2358
 
1963
2359
  // 2*(2*ncpsg)
@@ -2007,6 +2403,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2007
2403
  /*.nb21 =*/ nb21,
2008
2404
  /*.nb22 =*/ nb22,
2009
2405
  /*.nb23 =*/ nb23,
2406
+ /*.ne31 =*/ ne31,
2010
2407
  /*.ne32 =*/ ne32,
2011
2408
  /*.ne33 =*/ ne33,
2012
2409
  /*.nb31 =*/ nb31,
@@ -2023,24 +2420,18 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2023
2420
  /*.logit_softcap =*/ logit_softcap,
2024
2421
  };
2025
2422
 
2026
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
2423
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
2027
2424
 
2028
2425
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2029
2426
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2030
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
2031
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
2032
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), 3);
2033
- if (op->src[3]) {
2034
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[3]), 4);
2035
- } else {
2036
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 4);
2037
- }
2038
- if (op->src[4]) {
2039
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[4]), 5);
2040
- } else {
2041
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 5);
2042
- }
2043
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 6);
2427
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2428
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2429
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2430
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2431
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2432
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
2433
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
2434
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
2044
2435
 
2045
2436
  wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2046
2437
 
@@ -2048,14 +2439,60 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2048
2439
  #undef FATTN_SMEM
2049
2440
  } else {
2050
2441
  // half4x4 kernel
2051
- const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
2052
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2053
- const int64_t nkpsg = 1*ncpsg;
2442
+ const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
2443
+ const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
2444
+ const int nkpsg = 1*ncpsg;
2054
2445
 
2055
2446
  WSP_GGML_ASSERT(nqptg <= 32);
2056
2447
  WSP_GGML_ASSERT(nqptg % 1 == 0);
2057
2448
  WSP_GGML_ASSERT(ncpsg % 32 == 0);
2058
2449
 
2450
+ bool need_sync = false;
2451
+
2452
+ const bool has_kvpad = ne11 % ncpsg != 0;
2453
+
2454
+ if (has_kvpad) {
2455
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2456
+
2457
+ wsp_ggml_metal_kargs_flash_attn_ext_pad args0 = {
2458
+ /*.ne11 =*/ne11,
2459
+ /*.ne_12_2 =*/ne12,
2460
+ /*.ne_12_3 =*/ne13,
2461
+ /*.nb11 =*/nb11,
2462
+ /*.nb12 =*/nb12,
2463
+ /*.nb13 =*/nb13,
2464
+ /*.nb21 =*/nb21,
2465
+ /*.nb22 =*/nb22,
2466
+ /*.nb23 =*/nb23,
2467
+ /*.ne31 =*/ne31,
2468
+ /*.ne32 =*/ne32,
2469
+ /*.ne33 =*/ne33,
2470
+ /*.nb31 =*/nb31,
2471
+ /*.nb32 =*/nb32,
2472
+ /*.nb33 =*/nb33,
2473
+ };
2474
+
2475
+ wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2476
+
2477
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2478
+ wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2479
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2480
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2481
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2482
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2483
+
2484
+ assert(ne12 == ne22);
2485
+ assert(ne13 == ne23);
2486
+
2487
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2488
+
2489
+ need_sync = true;
2490
+ }
2491
+
2492
+ if (need_sync) {
2493
+ wsp_ggml_metal_op_concurrency_reset(ctx);
2494
+ }
2495
+
2059
2496
  // ne00 + 2*ncpsg*(nsg)
2060
2497
  // for each query, we load it as f16 in shared memory (ne00)
2061
2498
  // and store the soft_max values and the mask
@@ -2120,6 +2557,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2120
2557
  /*.nb21 =*/ nb21,
2121
2558
  /*.nb22 =*/ nb22,
2122
2559
  /*.nb23 =*/ nb23,
2560
+ /*.ne31 =*/ ne31,
2123
2561
  /*.ne32 =*/ ne32,
2124
2562
  /*.ne33 =*/ ne33,
2125
2563
  /*.nb31 =*/ nb31,
@@ -2136,25 +2574,17 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2136
2574
  /*.logit_softcap =*/ logit_softcap,
2137
2575
  };
2138
2576
 
2139
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
2577
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
2140
2578
 
2141
2579
  WSP_GGML_ASSERT(nsg*32 <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2142
2580
 
2143
2581
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2144
2582
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2145
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
2146
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
2147
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), 3);
2148
- if (op->src[3]) {
2149
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[3]), 4);
2150
- } else {
2151
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 4);
2152
- }
2153
- if (op->src[4]) {
2154
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[4]), 5);
2155
- } else {
2156
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 5);
2157
- }
2583
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2584
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2585
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2586
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2587
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2158
2588
 
2159
2589
  const size_t smem = FATTN_SMEM(nsg);
2160
2590
 
@@ -2162,23 +2592,25 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2162
2592
  WSP_GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
2163
2593
 
2164
2594
  if (nwg == 1) {
2595
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
2596
+
2165
2597
  // using 1 workgroup -> write the result directly into dst
2166
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 6);
2598
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2599
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
2167
2600
 
2168
2601
  wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2169
2602
 
2170
2603
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
2171
2604
  } else {
2172
2605
  // sanity checks
2606
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
2607
+
2173
2608
  WSP_GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
2174
2609
  WSP_GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
2175
2610
 
2176
- wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
2177
-
2178
2611
  // write the results from each workgroup into a temp buffer
2179
- wsp_ggml_metal_buffer_id bid_tmp = bid_dst;
2180
- bid_tmp.offs += wsp_ggml_nbytes(op);
2181
- wsp_ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
2612
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2613
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
2182
2614
 
2183
2615
  wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2184
2616
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
@@ -2385,7 +2817,7 @@ int wsp_ggml_metal_op_l2_norm(wsp_ggml_metal_op_t ctx, int idx) {
2385
2817
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2386
2818
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2387
2819
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2388
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2820
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2389
2821
 
2390
2822
  float eps;
2391
2823
  memcpy(&eps, op->op_params, sizeof(float));
@@ -2433,7 +2865,7 @@ int wsp_ggml_metal_op_group_norm(wsp_ggml_metal_op_t ctx, int idx) {
2433
2865
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2434
2866
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2435
2867
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2436
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2868
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2437
2869
 
2438
2870
  const int32_t ngrp = ((const int32_t *) op->op_params)[0];
2439
2871
 
@@ -2488,7 +2920,7 @@ int wsp_ggml_metal_op_norm(wsp_ggml_metal_op_t ctx, int idx) {
2488
2920
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2489
2921
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2490
2922
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2491
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2923
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2492
2924
 
2493
2925
  float eps;
2494
2926
  memcpy(&eps, op->op_params, sizeof(float));
@@ -2624,7 +3056,7 @@ int wsp_ggml_metal_op_rope(wsp_ggml_metal_op_t ctx, int idx) {
2624
3056
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2625
3057
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2626
3058
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2627
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3059
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2628
3060
 
2629
3061
  // make sure we have one or more position id(ne10) per token(ne02)
2630
3062
  WSP_GGML_ASSERT(ne10 % ne02 == 0);
@@ -2688,6 +3120,7 @@ int wsp_ggml_metal_op_rope(wsp_ggml_metal_op_t ctx, int idx) {
2688
3120
  /* sect_1 =*/ sect_1,
2689
3121
  /* sect_2 =*/ sect_2,
2690
3122
  /* sect_3 =*/ sect_3,
3123
+ /* src2 =*/ op->src[2] != nullptr,
2691
3124
  };
2692
3125
 
2693
3126
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_rope(lib, op);
@@ -2717,7 +3150,7 @@ int wsp_ggml_metal_op_im2col(wsp_ggml_metal_op_t ctx, int idx) {
2717
3150
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2718
3151
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2719
3152
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2720
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3153
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2721
3154
 
2722
3155
  const int32_t s0 = ((const int32_t *)(op->op_params))[0];
2723
3156
  const int32_t s1 = ((const int32_t *)(op->op_params))[1];
@@ -2778,6 +3211,84 @@ int wsp_ggml_metal_op_im2col(wsp_ggml_metal_op_t ctx, int idx) {
2778
3211
  return 1;
2779
3212
  }
2780
3213
 
3214
+ int wsp_ggml_metal_op_conv_2d(wsp_ggml_metal_op_t ctx, int idx) {
3215
+ wsp_ggml_tensor * op = ctx->node(idx);
3216
+
3217
+ wsp_ggml_metal_library_t lib = ctx->lib;
3218
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3219
+
3220
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3221
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3222
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3223
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3224
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3225
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3226
+
3227
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
3228
+ WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
3229
+ WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
3230
+ WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
3231
+
3232
+ const int32_t s0 = ((const int32_t *) op->op_params)[0];
3233
+ const int32_t s1 = ((const int32_t *) op->op_params)[1];
3234
+ const int32_t p0 = ((const int32_t *) op->op_params)[2];
3235
+ const int32_t p1 = ((const int32_t *) op->op_params)[3];
3236
+ const int32_t d0 = ((const int32_t *) op->op_params)[4];
3237
+ const int32_t d1 = ((const int32_t *) op->op_params)[5];
3238
+
3239
+ wsp_ggml_metal_kargs_conv_2d args = {
3240
+ /*.nb00 =*/ nb00,
3241
+ /*.nb01 =*/ nb01,
3242
+ /*.nb02 =*/ nb02,
3243
+ /*.nb03 =*/ nb03,
3244
+ /*.nb10 =*/ nb10,
3245
+ /*.nb11 =*/ nb11,
3246
+ /*.nb12 =*/ nb12,
3247
+ /*.nb13 =*/ nb13,
3248
+ /*.nb0 =*/ nb0,
3249
+ /*.nb1 =*/ nb1,
3250
+ /*.nb2 =*/ nb2,
3251
+ /*.nb3 =*/ nb3,
3252
+ /*.IW =*/ ne10,
3253
+ /*.IH =*/ ne11,
3254
+ /*.KW =*/ ne00,
3255
+ /*.KH =*/ ne01,
3256
+ /*.IC =*/ ne02,
3257
+ /*.OC =*/ ne03,
3258
+ /*.OW =*/ ne0,
3259
+ /*.OH =*/ ne1,
3260
+ /*.N =*/ ne3,
3261
+ /*.s0 =*/ s0,
3262
+ /*.s1 =*/ s1,
3263
+ /*.p0 =*/ p0,
3264
+ /*.p1 =*/ p1,
3265
+ /*.d0 =*/ d0,
3266
+ /*.d1 =*/ d1,
3267
+ };
3268
+
3269
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_2d(lib, op);
3270
+
3271
+ int nth = wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
3272
+ nth = std::min(nth, 256);
3273
+ nth = std::max(nth, 1);
3274
+
3275
+ const uint64_t n_out = wsp_ggml_nelements(op);
3276
+
3277
+ uint64_t tg = (n_out + nth - 1)/nth;
3278
+ tg = std::max<uint64_t>(tg, 1);
3279
+ tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
3280
+
3281
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3282
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3283
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3284
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
3285
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
3286
+
3287
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
3288
+
3289
+ return 1;
3290
+ }
3291
+
2781
3292
  int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
2782
3293
  wsp_ggml_tensor * op = ctx->node(idx);
2783
3294
 
@@ -2789,7 +3300,7 @@ int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
2789
3300
  WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2790
3301
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2791
3302
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2792
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3303
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2793
3304
 
2794
3305
  const int32_t s0 = ((const int32_t *)(op->op_params))[0];
2795
3306
 
@@ -2823,6 +3334,62 @@ int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
2823
3334
  return 1;
2824
3335
  }
2825
3336
 
3337
+ int wsp_ggml_metal_op_conv_transpose_2d(wsp_ggml_metal_op_t ctx, int idx) {
3338
+ wsp_ggml_tensor * op = ctx->node(idx);
3339
+
3340
+ wsp_ggml_metal_library_t lib = ctx->lib;
3341
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3342
+
3343
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3344
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3345
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3346
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3347
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3348
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3349
+
3350
+ const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3351
+
3352
+ const int32_t IC = op->src[1]->ne[2];
3353
+ const int32_t IH = op->src[1]->ne[1];
3354
+ const int32_t IW = op->src[1]->ne[0];
3355
+
3356
+ const int32_t KH = op->src[0]->ne[1];
3357
+ const int32_t KW = op->src[0]->ne[0];
3358
+
3359
+ const int32_t OW = op->ne[0];
3360
+ const int32_t OH = op->ne[1];
3361
+ const int32_t OC = op->ne[2];
3362
+
3363
+ wsp_ggml_metal_kargs_conv_transpose_2d args = {
3364
+ /*.IC =*/ IC,
3365
+ /*.IH =*/ IH,
3366
+ /*.IW =*/ IW,
3367
+ /*.KH =*/ KH,
3368
+ /*.KW =*/ KW,
3369
+ /*.OC =*/ OC,
3370
+ /*.s0 =*/ s0,
3371
+ /*.nb0 =*/ nb0,
3372
+ /*.nb1 =*/ nb1,
3373
+ /*.nb2 =*/ nb2,
3374
+ };
3375
+
3376
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
3377
+
3378
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3379
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3380
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3381
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
3382
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
3383
+
3384
+ // Metal requires buffer size to be multiple of 16 bytes
3385
+ const size_t smem = WSP_GGML_PAD(KW * KH * sizeof(float), 16);
3386
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3387
+
3388
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
3389
+
3390
+ return 1;
3391
+ }
3392
+
2826
3393
  int wsp_ggml_metal_op_upscale(wsp_ggml_metal_op_t ctx, int idx) {
2827
3394
  wsp_ggml_tensor * op = ctx->node(idx);
2828
3395
 
@@ -2832,7 +3399,7 @@ int wsp_ggml_metal_op_upscale(wsp_ggml_metal_op_t ctx, int idx) {
2832
3399
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2833
3400
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2834
3401
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2835
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3402
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2836
3403
 
2837
3404
  const float sf0 = (float)ne0/op->src[0]->ne[0];
2838
3405
  const float sf1 = (float)ne1/op->src[0]->ne[1];
@@ -2885,7 +3452,7 @@ int wsp_ggml_metal_op_pad(wsp_ggml_metal_op_t ctx, int idx) {
2885
3452
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2886
3453
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2887
3454
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2888
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3455
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2889
3456
 
2890
3457
  wsp_ggml_metal_kargs_pad args = {
2891
3458
  /*.ne00 =*/ ne00,
@@ -2929,7 +3496,7 @@ int wsp_ggml_metal_op_pad_reflect_1d(wsp_ggml_metal_op_t ctx, int idx) {
2929
3496
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2930
3497
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2931
3498
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2932
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3499
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2933
3500
 
2934
3501
  wsp_ggml_metal_kargs_pad_reflect_1d args = {
2935
3502
  /*.ne00 =*/ ne00,
@@ -2973,7 +3540,7 @@ int wsp_ggml_metal_op_arange(wsp_ggml_metal_op_t ctx, int idx) {
2973
3540
  wsp_ggml_metal_encoder_t enc = ctx->enc;
2974
3541
 
2975
3542
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2976
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3543
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2977
3544
 
2978
3545
  float start;
2979
3546
  float step;
@@ -2991,12 +3558,6 @@ int wsp_ggml_metal_op_arange(wsp_ggml_metal_op_t ctx, int idx) {
2991
3558
 
2992
3559
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_arange(lib, op);
2993
3560
 
2994
- //[encoder setComputePipelineState:pipeline];
2995
- //[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
2996
- //[encoder setBytes:&args length:sizeof(args) atIndex:1];
2997
-
2998
- //[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2999
-
3000
3561
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3001
3562
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3002
3563
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 1);
@@ -3015,7 +3576,7 @@ int wsp_ggml_metal_op_timestep_embedding(wsp_ggml_metal_op_t ctx, int idx) {
3015
3576
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3016
3577
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3017
3578
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3018
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3579
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3019
3580
 
3020
3581
  const int dim = op->op_params[0];
3021
3582
  const int max_period = op->op_params[1];
@@ -3049,7 +3610,7 @@ int wsp_ggml_metal_op_argmax(wsp_ggml_metal_op_t ctx, int idx) {
3049
3610
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3050
3611
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3051
3612
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3052
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3613
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3053
3614
 
3054
3615
  wsp_ggml_metal_kargs_argmax args = {
3055
3616
  /*.ne00 = */ ne00,
@@ -3085,38 +3646,93 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
3085
3646
  wsp_ggml_metal_library_t lib = ctx->lib;
3086
3647
  wsp_ggml_metal_encoder_t enc = ctx->enc;
3087
3648
 
3649
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
3650
+
3088
3651
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3089
3652
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3090
3653
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3091
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3654
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3655
+
3656
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_argsort(lib, op);
3092
3657
 
3093
3658
  // bitonic sort requires the number of elements to be power of 2
3094
- int64_t ne00_padded = 1;
3095
- while (ne00_padded < ne00) {
3096
- ne00_padded *= 2;
3659
+ int nth = 1;
3660
+ while (nth < ne00 && 2*nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3661
+ nth *= 2;
3097
3662
  }
3098
3663
 
3099
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_argsort(lib, op);
3100
-
3101
- const int64_t nrows = wsp_ggml_nrows(op->src[0]);
3664
+ const int npr = (ne00 + nth - 1)/nth;
3102
3665
 
3103
3666
  // Metal kernels require the buffer size to be multiple of 16 bytes
3104
3667
  // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3105
- const size_t smem = WSP_GGML_PAD(ne00_padded*sizeof(int32_t), 16);
3668
+ const size_t smem = WSP_GGML_PAD(nth*sizeof(int32_t), 16);
3669
+
3670
+ wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
3671
+ wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
3672
+
3673
+ wsp_ggml_metal_buffer_id bid_tmp = bid_dst;
3674
+ bid_tmp.offs += wsp_ggml_nbytes(op);
3675
+
3676
+ if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3677
+ std::swap(bid_dst, bid_tmp);
3678
+ }
3106
3679
 
3107
3680
  wsp_ggml_metal_kargs_argsort args = {
3108
- /*.ncols =*/ ne00,
3109
- /*.ncols_pad =*/ ne00_padded
3681
+ /*.ne00 =*/ ne00,
3682
+ /*.ne01 =*/ ne01,
3683
+ /*.ne02 =*/ ne02,
3684
+ /*.ne03 =*/ ne03,
3685
+ /*.nb00 =*/ nb00,
3686
+ /*.nb01 =*/ nb01,
3687
+ /*.nb02 =*/ nb02,
3688
+ /*.nb03 =*/ nb03,
3110
3689
  };
3111
3690
 
3112
3691
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3113
3692
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3114
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3115
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
3693
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3694
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3116
3695
 
3117
3696
  wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3118
3697
 
3119
- wsp_ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1);
3698
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3699
+
3700
+ wsp_ggml_metal_pipeline_t pipeline_merge = wsp_ggml_metal_library_get_pipeline_argsort_merge(lib, op);
3701
+
3702
+ int len = nth;
3703
+
3704
+ while (len < ne00) {
3705
+ wsp_ggml_metal_op_concurrency_reset(ctx);
3706
+
3707
+ wsp_ggml_metal_kargs_argsort_merge args_merge = {
3708
+ .ne00 = ne00,
3709
+ .ne01 = ne01,
3710
+ .ne02 = ne02,
3711
+ .ne03 = ne03,
3712
+ .nb00 = nb00,
3713
+ .nb01 = nb01,
3714
+ .nb02 = nb02,
3715
+ .nb03 = nb03,
3716
+ .len = len,
3717
+ };
3718
+
3719
+ // merges per row
3720
+ const int nm = (ne00 + 2*len - 1) / (2*len);
3721
+
3722
+ const int nth = std::min(512, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
3723
+
3724
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
3725
+ wsp_ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
3726
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3727
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3728
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
3729
+
3730
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
3731
+
3732
+ std::swap(bid_dst, bid_tmp);
3733
+
3734
+ len <<= 1;
3735
+ }
3120
3736
 
3121
3737
  return 1;
3122
3738
  }
@@ -3130,7 +3746,7 @@ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
3130
3746
  WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3131
3747
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3132
3748
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3133
- WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3749
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3134
3750
 
3135
3751
  float slope;
3136
3752
  memcpy(&slope, op->op_params, sizeof(float));
@@ -3156,3 +3772,73 @@ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
3156
3772
 
3157
3773
  return 1;
3158
3774
  }
3775
+
3776
+ int wsp_ggml_metal_op_opt_step_adamw(wsp_ggml_metal_op_t ctx, int idx) {
3777
+ wsp_ggml_tensor * op = ctx->node(idx);
3778
+
3779
+ wsp_ggml_metal_library_t lib = ctx->lib;
3780
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3781
+
3782
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3783
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3784
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3785
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3786
+
3787
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
3788
+
3789
+ const int64_t np = wsp_ggml_nelements(op->src[0]);
3790
+ wsp_ggml_metal_kargs_opt_step_adamw args = {
3791
+ /*.np =*/ np,
3792
+ };
3793
+
3794
+ int ida = 0;
3795
+
3796
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3797
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
3798
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), ida++);
3799
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), ida++);
3800
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), ida++);
3801
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[3]), ida++);
3802
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[4]), ida++);
3803
+
3804
+ const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3805
+ const int64_t n = (np + nth - 1) / nth;
3806
+
3807
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
3808
+
3809
+ return 1;
3810
+ }
3811
+
3812
+ int wsp_ggml_metal_op_opt_step_sgd(wsp_ggml_metal_op_t ctx, int idx) {
3813
+ wsp_ggml_tensor * op = ctx->node(idx);
3814
+
3815
+ wsp_ggml_metal_library_t lib = ctx->lib;
3816
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3817
+
3818
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3819
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3820
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3821
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3822
+
3823
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
3824
+
3825
+ const int64_t np = wsp_ggml_nelements(op->src[0]);
3826
+ wsp_ggml_metal_kargs_opt_step_sgd args = {
3827
+ /*.np =*/ np,
3828
+ };
3829
+
3830
+ int ida = 0;
3831
+
3832
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3833
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
3834
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), ida++);
3835
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), ida++);
3836
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), ida++);
3837
+
3838
+ const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3839
+ const int64_t n = (np + nth - 1) / nth;
3840
+
3841
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
3842
+
3843
+ return 1;
3844
+ }