whisper.rn 0.5.4 → 0.5.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. package/android/src/main/java/com/rnwhisper/WhisperContext.java +5 -0
  2. package/android/src/main/jni.cpp +13 -0
  3. package/cpp/ggml-alloc.c +78 -26
  4. package/cpp/ggml-alloc.h +9 -0
  5. package/cpp/ggml-backend-impl.h +1 -1
  6. package/cpp/ggml-backend-reg.cpp +19 -3
  7. package/cpp/ggml-backend.cpp +72 -20
  8. package/cpp/ggml-backend.h +2 -1
  9. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  10. package/cpp/ggml-cpu/arch/arm/repack.cpp +1004 -0
  11. package/cpp/ggml-cpu/arch/x86/repack.cpp +6 -6
  12. package/cpp/ggml-cpu/arch-fallback.h +50 -2
  13. package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
  14. package/cpp/ggml-cpu/ggml-cpu.c +139 -58
  15. package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
  16. package/cpp/ggml-cpu/ops.cpp +170 -18
  17. package/cpp/ggml-cpu/ops.h +1 -0
  18. package/cpp/ggml-cpu/repack.cpp +531 -5
  19. package/cpp/ggml-cpu/repack.h +14 -0
  20. package/cpp/ggml-cpu/simd-mappings.h +16 -18
  21. package/cpp/ggml-cpu/vec.cpp +41 -1
  22. package/cpp/ggml-cpu/vec.h +241 -138
  23. package/cpp/ggml-cpu.h +1 -0
  24. package/cpp/ggml-impl.h +0 -4
  25. package/cpp/ggml-metal/ggml-metal-context.m +26 -16
  26. package/cpp/ggml-metal/ggml-metal-device.cpp +452 -371
  27. package/cpp/ggml-metal/ggml-metal-device.h +87 -65
  28. package/cpp/ggml-metal/ggml-metal-device.m +263 -104
  29. package/cpp/ggml-metal/ggml-metal-impl.h +58 -4
  30. package/cpp/ggml-metal/ggml-metal-ops.cpp +415 -98
  31. package/cpp/ggml-metal/ggml-metal-ops.h +4 -0
  32. package/cpp/ggml-metal/ggml-metal.cpp +6 -5
  33. package/cpp/ggml-metal/ggml-metal.metal +404 -34
  34. package/cpp/ggml.c +110 -31
  35. package/cpp/ggml.h +51 -12
  36. package/cpp/jsi/RNWhisperJSI.cpp +1 -0
  37. package/cpp/whisper.cpp +16 -3
  38. package/ios/CMakeLists.txt +21 -1
  39. package/ios/RNWhisperContext.mm +5 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  44. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  45. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
  48. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  53. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  54. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  56. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  57. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
  58. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  61. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  62. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  64. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
  65. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
  67. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  70. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  71. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  72. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  73. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
  74. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  75. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  76. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
  77. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  78. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  79. package/lib/commonjs/jest-mock.js +2 -0
  80. package/lib/commonjs/jest-mock.js.map +1 -1
  81. package/lib/commonjs/version.json +1 -1
  82. package/lib/module/NativeRNWhisper.js.map +1 -1
  83. package/lib/module/jest-mock.js +2 -0
  84. package/lib/module/jest-mock.js.map +1 -1
  85. package/lib/module/version.json +1 -1
  86. package/lib/typescript/NativeRNWhisper.d.ts +1 -0
  87. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  88. package/package.json +1 -1
  89. package/src/NativeRNWhisper.ts +1 -0
  90. package/src/jest-mock.ts +2 -0
  91. package/src/version.json +1 -1
@@ -221,7 +221,7 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
221
221
  }
222
222
 
223
223
  if (ctx->debug_graph > 0) {
224
- WSP_GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, wsp_ggml_op_name(node->op), is_concurrent ? "(concurrent)" : "");
224
+ WSP_GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, wsp_ggml_op_name(node->op), wsp_ggml_get_name(node), is_concurrent ? "(concurrent)" : "");
225
225
  }
226
226
  if (ctx->debug_graph > 1) {
227
227
  WSP_GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
@@ -286,6 +286,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
286
286
  {
287
287
  n_fuse = wsp_ggml_metal_op_scale(ctx, idx);
288
288
  } break;
289
+ case WSP_GGML_OP_FILL:
290
+ {
291
+ n_fuse = wsp_ggml_metal_op_fill(ctx, idx);
292
+ } break;
289
293
  case WSP_GGML_OP_CLAMP:
290
294
  {
291
295
  n_fuse = wsp_ggml_metal_op_clamp(ctx, idx);
@@ -406,10 +410,18 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
406
410
  {
407
411
  n_fuse = wsp_ggml_metal_op_argsort(ctx, idx);
408
412
  } break;
413
+ case WSP_GGML_OP_TOP_K:
414
+ {
415
+ n_fuse = wsp_ggml_metal_op_top_k(ctx, idx);
416
+ } break;
409
417
  case WSP_GGML_OP_LEAKY_RELU:
410
418
  {
411
419
  n_fuse = wsp_ggml_metal_op_leaky_relu(ctx, idx);
412
420
  } break;
421
+ case WSP_GGML_OP_TRI:
422
+ {
423
+ n_fuse = wsp_ggml_metal_op_tri(ctx, idx);
424
+ } break;
413
425
  case WSP_GGML_OP_FLASH_ATTN_EXT:
414
426
  {
415
427
  n_fuse = wsp_ggml_metal_op_flash_attn_ext(ctx, idx);
@@ -436,7 +448,11 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
436
448
  {
437
449
  n_fuse = wsp_ggml_metal_op_opt_step_sgd(ctx, idx);
438
450
  } break;
439
- default:
451
+ case WSP_GGML_OP_COUNT_EQUAL:
452
+ {
453
+ n_fuse = wsp_ggml_metal_op_count_equal(ctx, idx);
454
+ } break;
455
+ default:
440
456
  {
441
457
  WSP_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, wsp_ggml_op_name(node->op));
442
458
  WSP_GGML_ABORT("fatal error");
@@ -520,7 +536,7 @@ int wsp_ggml_metal_op_concat(wsp_ggml_metal_op_t ctx, int idx) {
520
536
  /*.dim =*/ dim,
521
537
  };
522
538
 
523
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_base(lib, WSP_GGML_OP_CONCAT);
539
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_base(lib, WSP_GGML_OP_CONCAT);
524
540
 
525
541
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
526
542
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -546,7 +562,7 @@ int wsp_ggml_metal_op_repeat(wsp_ggml_metal_op_t ctx, int idx) {
546
562
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
547
563
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
548
564
 
549
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_repeat(lib, op->type);
565
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_repeat(lib, op->type);
550
566
 
551
567
  wsp_ggml_metal_kargs_repeat args = {
552
568
  /*.ne00 =*/ ne00,
@@ -612,7 +628,7 @@ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
612
628
  // TODO: make a simpler cpy_bytes kernel
613
629
 
614
630
  //const id<MTLComputePipelineState> pipeline = ctx->pipelines[WSP_GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
615
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
631
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
616
632
 
617
633
  wsp_ggml_metal_kargs_cpy args = {
618
634
  /*.nk0 =*/ ne00,
@@ -675,7 +691,7 @@ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
675
691
  /*.o1 =*/ { 0 },
676
692
  };
677
693
 
678
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_bin(lib, WSP_GGML_OP_ADD, 1, false);
694
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_bin(lib, WSP_GGML_OP_ADD, 1, false);
679
695
 
680
696
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
681
697
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -717,7 +733,42 @@ int wsp_ggml_metal_op_scale(wsp_ggml_metal_op_t ctx, int idx) {
717
733
  n /= 4;
718
734
  }
719
735
 
720
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
736
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
737
+
738
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
739
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
740
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
741
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
742
+
743
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
744
+
745
+ return 1;
746
+ }
747
+
748
+ int wsp_ggml_metal_op_fill(wsp_ggml_metal_op_t ctx, int idx) {
749
+ wsp_ggml_tensor * op = ctx->node(idx);
750
+
751
+ wsp_ggml_metal_library_t lib = ctx->lib;
752
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
753
+
754
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
755
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
756
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
757
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
758
+
759
+ const float val = wsp_ggml_get_op_params_f32(op, 0);
760
+
761
+ wsp_ggml_metal_kargs_fill args = {
762
+ /*.val =*/ val
763
+ };
764
+
765
+ int64_t n = wsp_ggml_nelements(op);
766
+
767
+ if (n % 4 == 0) {
768
+ n /= 4;
769
+ }
770
+
771
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
721
772
 
722
773
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
723
774
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -756,7 +807,7 @@ int wsp_ggml_metal_op_clamp(wsp_ggml_metal_op_t ctx, int idx) {
756
807
  n /= 4;
757
808
  }
758
809
 
759
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
810
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
760
811
 
761
812
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
762
813
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -785,7 +836,7 @@ int wsp_ggml_metal_op_unary(wsp_ggml_metal_op_t ctx, int idx) {
785
836
  n /= 4;
786
837
  }
787
838
 
788
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
839
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
789
840
 
790
841
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
791
842
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 0);
@@ -813,7 +864,7 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
813
864
  WSP_GGML_ASSERT(wsp_ggml_are_same_shape(op->src[0], op->src[1]));
814
865
  }
815
866
 
816
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_glu(lib, op);
867
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_glu(lib, op);
817
868
 
818
869
  const int32_t swp = wsp_ggml_get_op_params_i32(op, 1);
819
870
  const float alpha = wsp_ggml_get_op_params_f32(op, 2);
@@ -866,7 +917,7 @@ int wsp_ggml_metal_op_sum(wsp_ggml_metal_op_t ctx, int idx) {
866
917
  /*.np =*/ n,
867
918
  };
868
919
 
869
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_sum(lib, op);
920
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_sum(lib, op);
870
921
 
871
922
  int nth = 32; // SIMD width
872
923
 
@@ -921,7 +972,7 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
921
972
  /*.nb3 =*/ nb3,
922
973
  };
923
974
 
924
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_sum_rows(lib, op);
975
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_sum_rows(lib, op);
925
976
 
926
977
  int nth = 32; // SIMD width
927
978
 
@@ -932,7 +983,7 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
932
983
  nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
933
984
  nth = std::min(nth, ne00);
934
985
 
935
- const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
986
+ const size_t smem = pipeline.smem;
936
987
 
937
988
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
938
989
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -959,7 +1010,7 @@ int wsp_ggml_metal_op_cumsum(wsp_ggml_metal_op_t ctx, int idx) {
959
1010
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
960
1011
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
961
1012
 
962
- wsp_ggml_metal_pipeline_t pipeline_blk = wsp_ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
1013
+ auto pipeline_blk = wsp_ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
963
1014
 
964
1015
  int nth = 1;
965
1016
  while (nth < ne00 && 2*nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
@@ -1056,7 +1107,7 @@ int wsp_ggml_metal_op_cumsum(wsp_ggml_metal_op_t ctx, int idx) {
1056
1107
  wsp_ggml_metal_op_concurrency_reset(ctx);
1057
1108
 
1058
1109
  {
1059
- wsp_ggml_metal_pipeline_t pipeline_add = wsp_ggml_metal_library_get_pipeline_cumsum_add(lib, op);
1110
+ auto pipeline_add = wsp_ggml_metal_library_get_pipeline_cumsum_add(lib, op);
1060
1111
 
1061
1112
  wsp_ggml_metal_kargs_cumsum_add args = {
1062
1113
  /*.ne00 =*/ ne00,
@@ -1102,7 +1153,7 @@ int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
1102
1153
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1103
1154
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1104
1155
 
1105
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
1156
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
1106
1157
 
1107
1158
  wsp_ggml_metal_kargs_get_rows args = {
1108
1159
  /*.ne00t =*/ wsp_ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
@@ -1147,7 +1198,7 @@ int wsp_ggml_metal_op_set_rows(wsp_ggml_metal_op_t ctx, int idx) {
1147
1198
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1148
1199
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1149
1200
 
1150
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
1201
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
1151
1202
 
1152
1203
  const int32_t nk0 = ne0/wsp_ggml_blck_size(op->type);
1153
1204
 
@@ -1248,7 +1299,7 @@ int wsp_ggml_metal_op_soft_max(wsp_ggml_metal_op_t ctx, int idx) {
1248
1299
  /*.n_head_log2 =*/ n_head_log2,
1249
1300
  };
1250
1301
 
1251
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_soft_max(lib, op);
1302
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_soft_max(lib, op);
1252
1303
 
1253
1304
  int nth = 32; // SIMD width
1254
1305
 
@@ -1262,7 +1313,7 @@ int wsp_ggml_metal_op_soft_max(wsp_ggml_metal_op_t ctx, int idx) {
1262
1313
  }
1263
1314
  }
1264
1315
 
1265
- const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
1316
+ const size_t smem = pipeline.smem;
1266
1317
 
1267
1318
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1268
1319
  wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
@@ -1318,15 +1369,43 @@ int wsp_ggml_metal_op_ssm_conv(wsp_ggml_metal_op_t ctx, int idx) {
1318
1369
  /*.nb2 =*/ nb2,
1319
1370
  };
1320
1371
 
1321
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_ssm_conv(lib, op);
1372
+ // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
1373
+ const bool use_batched = (ne1 > 1);
1322
1374
 
1323
- wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1324
- wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1325
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1326
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1327
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
1375
+ if (use_batched) {
1376
+ // Determine the smallest power of 2 that's >= ne1, but <= 256
1377
+ int BATCH_SIZE;
1378
+ if (ne1 > 128) BATCH_SIZE = 256;
1379
+ else if (ne1 > 64 ) BATCH_SIZE = 128;
1380
+ else if (ne1 > 32 ) BATCH_SIZE = 64;
1381
+ else if (ne1 > 16 ) BATCH_SIZE = 32;
1382
+ else if (ne1 > 8 ) BATCH_SIZE = 16;
1383
+ else if (ne1 > 4 ) BATCH_SIZE = 8;
1384
+ else BATCH_SIZE = 2;
1385
+
1386
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
1387
+
1388
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1389
+ wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1390
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1391
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1392
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
1328
1393
 
1329
- wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1394
+ // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
1395
+ // Each threadgroup has BATCH_SIZE threads, each handling one token
1396
+ const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
1397
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
1398
+ } else {
1399
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_ssm_conv(lib, op);
1400
+
1401
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1402
+ wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1403
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1404
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1405
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
1406
+
1407
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1408
+ }
1330
1409
 
1331
1410
  return 1;
1332
1411
  }
@@ -1405,11 +1484,11 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
1405
1484
  /*.nb0 =*/ nb0,
1406
1485
  };
1407
1486
 
1408
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_ssm_scan(lib, op);
1487
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_ssm_scan(lib, op);
1409
1488
 
1410
1489
  WSP_GGML_ASSERT(d_state <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1411
1490
 
1412
- const size_t sms = wsp_ggml_metal_pipeline_get_smem(pipeline);
1491
+ const size_t smem = pipeline.smem;
1413
1492
 
1414
1493
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1415
1494
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -1422,7 +1501,7 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
1422
1501
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[6]), 7);
1423
1502
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 8);
1424
1503
 
1425
- wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
1504
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1426
1505
 
1427
1506
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1428
1507
 
@@ -1445,7 +1524,7 @@ int wsp_ggml_metal_op_rwkv(wsp_ggml_metal_op_t ctx, int idx) {
1445
1524
  const int64_t C = op->ne[0];
1446
1525
  const int64_t H = op->src[0]->ne[1];
1447
1526
 
1448
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_rwkv(lib, op);
1527
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_rwkv(lib, op);
1449
1528
 
1450
1529
  int ida = 0;
1451
1530
 
@@ -1481,7 +1560,7 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
1481
1560
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1482
1561
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1483
1562
 
1484
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1563
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1485
1564
 
1486
1565
  WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(op->src[0]->type) == 0);
1487
1566
 
@@ -1588,7 +1667,7 @@ int wsp_ggml_metal_op_pool_2d(wsp_ggml_metal_op_t ctx, int idx) {
1588
1667
  /* .np = */ np
1589
1668
  };
1590
1669
 
1591
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
1670
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
1592
1671
 
1593
1672
  const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1594
1673
  const int ntg = (np + nth - 1) / nth;
@@ -1697,7 +1776,7 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
1697
1776
  WSP_GGML_ABORT("unsupported ne11");
1698
1777
  };
1699
1778
 
1700
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
1779
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
1701
1780
 
1702
1781
  wsp_ggml_metal_kargs_mul_mv_ext args = {
1703
1782
  /*.ne00 =*/ ne00,
@@ -1744,7 +1823,7 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
1744
1823
  // default: break;
1745
1824
  //}
1746
1825
 
1747
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_mul_mm(lib, op);
1826
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_mul_mm(lib, op);
1748
1827
 
1749
1828
  wsp_ggml_metal_kargs_mul_mm args = {
1750
1829
  /*.ne00 =*/ ne00,
@@ -1769,18 +1848,18 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
1769
1848
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1770
1849
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
1771
1850
 
1772
- const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
1851
+ const size_t smem = pipeline.smem;
1773
1852
 
1774
1853
  wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1775
1854
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
1776
1855
  } else {
1777
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_mul_mv(lib, op);
1856
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_mul_mv(lib, op);
1778
1857
 
1779
- const int nr0 = wsp_ggml_metal_pipeline_get_nr0(pipeline);
1780
- const int nr1 = wsp_ggml_metal_pipeline_get_nr1(pipeline);
1781
- const int nsg = wsp_ggml_metal_pipeline_get_nsg(pipeline);
1858
+ const int nr0 = pipeline.nr0;
1859
+ const int nr1 = pipeline.nr1;
1860
+ const int nsg = pipeline.nsg;
1782
1861
 
1783
- const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
1862
+ const size_t smem = pipeline.smem;
1784
1863
 
1785
1864
  wsp_ggml_metal_kargs_mul_mv args = {
1786
1865
  /*.ne00 =*/ ne00,
@@ -1911,9 +1990,9 @@ int wsp_ggml_metal_op_mul_mat_id(wsp_ggml_metal_op_t ctx, int idx) {
1911
1990
  nb21,
1912
1991
  };
1913
1992
 
1914
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
1993
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
1915
1994
 
1916
- const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
1995
+ const size_t smem = pipeline.smem;
1917
1996
 
1918
1997
  WSP_GGML_ASSERT(ne02 <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1919
1998
 
@@ -1934,7 +2013,7 @@ int wsp_ggml_metal_op_mul_mat_id(wsp_ggml_metal_op_t ctx, int idx) {
1934
2013
  wsp_ggml_metal_op_concurrency_reset(ctx);
1935
2014
 
1936
2015
  {
1937
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
2016
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
1938
2017
 
1939
2018
  wsp_ggml_metal_kargs_mul_mm_id args = {
1940
2019
  /*.ne00 =*/ ne00,
@@ -1963,20 +2042,20 @@ int wsp_ggml_metal_op_mul_mat_id(wsp_ggml_metal_op_t ctx, int idx) {
1963
2042
  wsp_ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
1964
2043
  wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
1965
2044
 
1966
- const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
2045
+ const size_t smem = pipeline.smem;
1967
2046
 
1968
2047
  wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1969
2048
 
1970
2049
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
1971
2050
  }
1972
2051
  } else {
1973
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
2052
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
1974
2053
 
1975
- const int nr0 = wsp_ggml_metal_pipeline_get_nr0(pipeline);
1976
- const int nr1 = wsp_ggml_metal_pipeline_get_nr1(pipeline);
1977
- const int nsg = wsp_ggml_metal_pipeline_get_nsg(pipeline);
2054
+ const int nr0 = pipeline.nr0;
2055
+ const int nr1 = pipeline.nr1;
2056
+ const int nsg = pipeline.nsg;
1978
2057
 
1979
- const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
2058
+ const size_t smem = pipeline.smem;
1980
2059
 
1981
2060
  wsp_ggml_metal_kargs_mul_mv_id args = {
1982
2061
  /*.nei0 =*/ ne20,
@@ -2060,7 +2139,7 @@ int wsp_ggml_metal_op_add_id(wsp_ggml_metal_op_t ctx, int idx) {
2060
2139
  /*.nb21 =*/ nb21,
2061
2140
  };
2062
2141
 
2063
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_base(lib, WSP_GGML_OP_ADD_ID);
2142
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_base(lib, WSP_GGML_OP_ADD_ID);
2064
2143
 
2065
2144
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2066
2145
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2102,7 +2181,11 @@ size_t wsp_ggml_metal_op_flash_attn_ext_extra_pad(const wsp_ggml_tensor * op) {
2102
2181
 
2103
2182
  const bool has_mask = op->src[3] != nullptr;
2104
2183
 
2105
- if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
2184
+ // note: the non-vec kernel requires more extra memory, so always reserve for it
2185
+ WSP_GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
2186
+
2187
+ //if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
2188
+ if (false) {
2106
2189
  // note: always reserve the padding space to avoid graph reallocations
2107
2190
  //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
2108
2191
  const bool has_kvpad = true;
@@ -2304,7 +2387,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2304
2387
  /*.nb33 =*/nb33,
2305
2388
  };
2306
2389
 
2307
- wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2390
+ auto pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2308
2391
 
2309
2392
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2310
2393
  wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@@ -2335,7 +2418,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2335
2418
  /*.nb33 =*/ nb33,
2336
2419
  };
2337
2420
 
2338
- wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
2421
+ auto pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
2339
2422
 
2340
2423
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2341
2424
  wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@@ -2420,7 +2503,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2420
2503
  /*.logit_softcap =*/ logit_softcap,
2421
2504
  };
2422
2505
 
2423
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
2506
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
2424
2507
 
2425
2508
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2426
2509
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2472,7 +2555,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2472
2555
  /*.nb33 =*/nb33,
2473
2556
  };
2474
2557
 
2475
- wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2558
+ auto pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2476
2559
 
2477
2560
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2478
2561
  wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@@ -2574,7 +2657,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2574
2657
  /*.logit_softcap =*/ logit_softcap,
2575
2658
  };
2576
2659
 
2577
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
2660
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
2578
2661
 
2579
2662
  WSP_GGML_ASSERT(nsg*32 <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2580
2663
 
@@ -2626,7 +2709,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2626
2709
  nrows,
2627
2710
  };
2628
2711
 
2629
- wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
2712
+ auto pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
2630
2713
 
2631
2714
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2632
2715
  wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@@ -2758,7 +2841,7 @@ int wsp_ggml_metal_op_bin(wsp_ggml_metal_op_t ctx, int idx) {
2758
2841
  // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
2759
2842
  bid_src1.offs = 0;
2760
2843
 
2761
- wsp_ggml_metal_pipeline_t pipeline = nullptr;
2844
+ struct wsp_ggml_metal_pipeline_with_params pipeline;
2762
2845
 
2763
2846
  if (wsp_ggml_nelements(op->src[1]) == ne10 && wsp_ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
2764
2847
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
@@ -2831,7 +2914,7 @@ int wsp_ggml_metal_op_l2_norm(wsp_ggml_metal_op_t ctx, int idx) {
2831
2914
  /*.eps =*/ eps,
2832
2915
  };
2833
2916
 
2834
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_l2_norm(lib, op);
2917
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_l2_norm(lib, op);
2835
2918
 
2836
2919
  while (nth < ne00/4 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
2837
2920
  nth *= 2;
@@ -2840,7 +2923,7 @@ int wsp_ggml_metal_op_l2_norm(wsp_ggml_metal_op_t ctx, int idx) {
2840
2923
  nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2841
2924
  nth = std::min(nth, ne00/4);
2842
2925
 
2843
- const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
2926
+ const size_t smem = pipeline.smem;
2844
2927
 
2845
2928
  const int64_t nrows = wsp_ggml_nrows(op->src[0]);
2846
2929
 
@@ -2883,7 +2966,7 @@ int wsp_ggml_metal_op_group_norm(wsp_ggml_metal_op_t ctx, int idx) {
2883
2966
  /*.eps =*/ eps,
2884
2967
  };
2885
2968
 
2886
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_group_norm(lib, op);
2969
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_group_norm(lib, op);
2887
2970
 
2888
2971
  int nth = 32; // SIMD width
2889
2972
  //while (nth < ne00/4 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
@@ -2893,7 +2976,7 @@ int wsp_ggml_metal_op_group_norm(wsp_ggml_metal_op_t ctx, int idx) {
2893
2976
  //nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2894
2977
  //nth = std::min(nth, ne00/4);
2895
2978
 
2896
- const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
2979
+ const size_t smem = pipeline.smem;
2897
2980
 
2898
2981
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2899
2982
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3018,7 +3101,7 @@ int wsp_ggml_metal_op_norm(wsp_ggml_metal_op_t ctx, int idx) {
3018
3101
  }
3019
3102
  }
3020
3103
 
3021
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
3104
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
3022
3105
 
3023
3106
  int nth = 32; // SIMD width
3024
3107
 
@@ -3029,7 +3112,7 @@ int wsp_ggml_metal_op_norm(wsp_ggml_metal_op_t ctx, int idx) {
3029
3112
  nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3030
3113
  nth = std::min(nth, args.ne00_t);
3031
3114
 
3032
- const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
3115
+ const size_t smem = pipeline.smem;
3033
3116
 
3034
3117
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3035
3118
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3123,7 +3206,7 @@ int wsp_ggml_metal_op_rope(wsp_ggml_metal_op_t ctx, int idx) {
3123
3206
  /* src2 =*/ op->src[2] != nullptr,
3124
3207
  };
3125
3208
 
3126
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_rope(lib, op);
3209
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_rope(lib, op);
3127
3210
 
3128
3211
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3129
3212
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3195,7 +3278,7 @@ int wsp_ggml_metal_op_im2col(wsp_ggml_metal_op_t ctx, int idx) {
3195
3278
  /*.KHW =*/ KH * KW,
3196
3279
  };
3197
3280
 
3198
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_im2col(lib, op);
3281
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_im2col(lib, op);
3199
3282
 
3200
3283
  WSP_GGML_ASSERT(KH*KW <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3201
3284
 
@@ -3266,7 +3349,7 @@ int wsp_ggml_metal_op_conv_2d(wsp_ggml_metal_op_t ctx, int idx) {
3266
3349
  /*.d1 =*/ d1,
3267
3350
  };
3268
3351
 
3269
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_2d(lib, op);
3352
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_conv_2d(lib, op);
3270
3353
 
3271
3354
  int nth = wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
3272
3355
  nth = std::min(nth, 256);
@@ -3321,7 +3404,7 @@ int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
3321
3404
  /*.nb1 =*/ nb1,
3322
3405
  };
3323
3406
 
3324
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
3407
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
3325
3408
 
3326
3409
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3327
3410
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3373,7 +3456,7 @@ int wsp_ggml_metal_op_conv_transpose_2d(wsp_ggml_metal_op_t ctx, int idx) {
3373
3456
  /*.nb2 =*/ nb2,
3374
3457
  };
3375
3458
 
3376
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
3459
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
3377
3460
 
3378
3461
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3379
3462
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3429,7 +3512,7 @@ int wsp_ggml_metal_op_upscale(wsp_ggml_metal_op_t ctx, int idx) {
3429
3512
  /*.sf3 =*/ sf3
3430
3513
  };
3431
3514
 
3432
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_upscale(lib, op);
3515
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_upscale(lib, op);
3433
3516
 
3434
3517
  const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3435
3518
 
@@ -3473,7 +3556,7 @@ int wsp_ggml_metal_op_pad(wsp_ggml_metal_op_t ctx, int idx) {
3473
3556
  /*.nb3 =*/ nb3
3474
3557
  };
3475
3558
 
3476
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_pad(lib, op);
3559
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_pad(lib, op);
3477
3560
 
3478
3561
  const int nth = std::min(1024, ne0);
3479
3562
 
@@ -3519,7 +3602,7 @@ int wsp_ggml_metal_op_pad_reflect_1d(wsp_ggml_metal_op_t ctx, int idx) {
3519
3602
  /*.p1 =*/ ((const int32_t *)(op->op_params))[1]
3520
3603
  };
3521
3604
 
3522
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
3605
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
3523
3606
 
3524
3607
  const int nth = std::min(1024, ne0);
3525
3608
 
@@ -3556,7 +3639,7 @@ int wsp_ggml_metal_op_arange(wsp_ggml_metal_op_t ctx, int idx) {
3556
3639
 
3557
3640
  const int nth = std::min(1024, ne0);
3558
3641
 
3559
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_arange(lib, op);
3642
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_arange(lib, op);
3560
3643
 
3561
3644
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3562
3645
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3587,7 +3670,7 @@ int wsp_ggml_metal_op_timestep_embedding(wsp_ggml_metal_op_t ctx, int idx) {
3587
3670
  /*.max_period =*/ max_period,
3588
3671
  };
3589
3672
 
3590
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
3673
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
3591
3674
 
3592
3675
  const int nth = std::max(1, std::min(1024, dim/2));
3593
3676
 
@@ -3617,7 +3700,7 @@ int wsp_ggml_metal_op_argmax(wsp_ggml_metal_op_t ctx, int idx) {
3617
3700
  /*.nb01 = */ nb01,
3618
3701
  };
3619
3702
 
3620
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_argmax(lib, op);
3703
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_argmax(lib, op);
3621
3704
 
3622
3705
  const int64_t nrows = wsp_ggml_nrows(op->src[0]);
3623
3706
 
@@ -3626,7 +3709,7 @@ int wsp_ggml_metal_op_argmax(wsp_ggml_metal_op_t ctx, int idx) {
3626
3709
  nth *= 2;
3627
3710
  }
3628
3711
 
3629
- const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
3712
+ const size_t smem = pipeline.smem;
3630
3713
 
3631
3714
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3632
3715
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3653,7 +3736,7 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
3653
3736
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3654
3737
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3655
3738
 
3656
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_argsort(lib, op);
3739
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_argsort(lib, op);
3657
3740
 
3658
3741
  // bitonic sort requires the number of elements to be power of 2
3659
3742
  int nth = 1;
@@ -3678,14 +3761,19 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
3678
3761
  }
3679
3762
 
3680
3763
  wsp_ggml_metal_kargs_argsort args = {
3681
- /*.ne00 =*/ ne00,
3682
- /*.ne01 =*/ ne01,
3683
- /*.ne02 =*/ ne02,
3684
- /*.ne03 =*/ ne03,
3685
- /*.nb00 =*/ nb00,
3686
- /*.nb01 =*/ nb01,
3687
- /*.nb02 =*/ nb02,
3688
- /*.nb03 =*/ nb03,
3764
+ /*.ne00 =*/ ne00,
3765
+ /*.ne01 =*/ ne01,
3766
+ /*.ne02 =*/ ne02,
3767
+ /*.ne03 =*/ ne03,
3768
+ /*.nb00 =*/ nb00,
3769
+ /*.nb01 =*/ nb01,
3770
+ /*.nb02 =*/ nb02,
3771
+ /*.nb03 =*/ nb03,
3772
+ /*.ne0 =*/ ne0,
3773
+ /*.ne1 =*/ ne1,
3774
+ /*.ne2 =*/ ne2,
3775
+ /*.ne3 =*/ ne3,
3776
+ /*.top_k =*/ nth,
3689
3777
  };
3690
3778
 
3691
3779
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -3697,7 +3785,7 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
3697
3785
 
3698
3786
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3699
3787
 
3700
- wsp_ggml_metal_pipeline_t pipeline_merge = wsp_ggml_metal_library_get_pipeline_argsort_merge(lib, op);
3788
+ auto pipeline_merge = wsp_ggml_metal_library_get_pipeline_argsort_merge(lib, op);
3701
3789
 
3702
3790
  int len = nth;
3703
3791
 
@@ -3705,15 +3793,20 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
3705
3793
  wsp_ggml_metal_op_concurrency_reset(ctx);
3706
3794
 
3707
3795
  wsp_ggml_metal_kargs_argsort_merge args_merge = {
3708
- .ne00 = ne00,
3709
- .ne01 = ne01,
3710
- .ne02 = ne02,
3711
- .ne03 = ne03,
3712
- .nb00 = nb00,
3713
- .nb01 = nb01,
3714
- .nb02 = nb02,
3715
- .nb03 = nb03,
3716
- .len = len,
3796
+ /*.ne00 =*/ ne00,
3797
+ /*.ne01 =*/ ne01,
3798
+ /*.ne02 =*/ ne02,
3799
+ /*.ne03 =*/ ne03,
3800
+ /*.nb00 =*/ nb00,
3801
+ /*.nb01 =*/ nb01,
3802
+ /*.nb02 =*/ nb02,
3803
+ /*.nb03 =*/ nb03,
3804
+ /*.ne0 =*/ ne0,
3805
+ /*.ne1 =*/ ne1,
3806
+ /*.ne2 =*/ ne2,
3807
+ /*.ne3 =*/ ne3,
3808
+ /*.top_k =*/ ne00,
3809
+ /*.len =*/ len,
3717
3810
  };
3718
3811
 
3719
3812
  // merges per row
@@ -3737,6 +3830,118 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
3737
3830
  return 1;
3738
3831
  }
3739
3832
 
3833
+ int wsp_ggml_metal_op_top_k(wsp_ggml_metal_op_t ctx, int idx) {
3834
+ wsp_ggml_tensor * op = ctx->node(idx);
3835
+
3836
+ wsp_ggml_metal_library_t lib = ctx->lib;
3837
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3838
+
3839
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
3840
+
3841
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3842
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3843
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3844
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3845
+
3846
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_top_k(lib, op);
3847
+
3848
+ // bitonic sort requires the number of elements to be power of 2
3849
+ int nth = 1;
3850
+ while (nth < ne00 && 2*nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3851
+ nth *= 2;
3852
+ }
3853
+
3854
+ // blocks per row
3855
+ const int npr = (ne00 + nth - 1)/nth;
3856
+
3857
+ const size_t smem = WSP_GGML_PAD(nth*sizeof(int32_t), 16);
3858
+
3859
+ wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
3860
+ wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
3861
+
3862
+ wsp_ggml_metal_buffer_id bid_tmp = bid_dst;
3863
+ bid_tmp.offs += sizeof(int32_t)*wsp_ggml_nelements(op->src[0]);
3864
+
3865
+ if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3866
+ std::swap(bid_dst, bid_tmp);
3867
+ }
3868
+
3869
+ const int top_k = ne0;
3870
+
3871
+ wsp_ggml_metal_kargs_argsort args = {
3872
+ /*.ne00 =*/ ne00,
3873
+ /*.ne01 =*/ ne01,
3874
+ /*.ne02 =*/ ne02,
3875
+ /*.ne03 =*/ ne03,
3876
+ /*.nb00 =*/ nb00,
3877
+ /*.nb01 =*/ nb01,
3878
+ /*.nb02 =*/ nb02,
3879
+ /*.nb03 =*/ nb03,
3880
+ /*.ne0 =*/ ne0,
3881
+ /*.ne1 =*/ ne1,
3882
+ /*.ne2 =*/ ne2,
3883
+ /*.ne3 =*/ ne3,
3884
+ /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
3885
+ };
3886
+
3887
+ if (npr > 1) {
3888
+ args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
3889
+ }
3890
+
3891
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3892
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3893
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3894
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3895
+
3896
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3897
+
3898
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3899
+
3900
+ auto pipeline_merge = wsp_ggml_metal_library_get_pipeline_top_k_merge(lib, op);
3901
+
3902
+ int len = args.top_k;
3903
+
3904
+ while (len < args.ne0) {
3905
+ wsp_ggml_metal_op_concurrency_reset(ctx);
3906
+
3907
+ // merges per row
3908
+ const int nm = (args.ne0 + 2*len - 1) / (2*len);
3909
+
3910
+ const int nth = std::min(512, std::min(len, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
3911
+
3912
+ wsp_ggml_metal_kargs_argsort_merge args_merge = {
3913
+ /*.ne00 =*/ ne00,
3914
+ /*.ne01 =*/ ne01,
3915
+ /*.ne02 =*/ ne02,
3916
+ /*.ne03 =*/ ne03,
3917
+ /*.nb00 =*/ nb00,
3918
+ /*.nb01 =*/ nb01,
3919
+ /*.nb02 =*/ nb02,
3920
+ /*.nb03 =*/ nb03,
3921
+ /*.ne0 =*/ args.ne0,
3922
+ /*.ne1 =*/ ne1,
3923
+ /*.ne2 =*/ ne2,
3924
+ /*.ne3 =*/ ne3,
3925
+ /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
3926
+ /*.len =*/ len,
3927
+ };
3928
+
3929
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
3930
+ wsp_ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
3931
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3932
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3933
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
3934
+
3935
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
3936
+
3937
+ std::swap(bid_dst, bid_tmp);
3938
+
3939
+ len <<= 1;
3940
+ }
3941
+
3942
+ return 1;
3943
+ }
3944
+
3740
3945
  int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
3741
3946
  wsp_ggml_tensor * op = ctx->node(idx);
3742
3947
 
@@ -3755,7 +3960,7 @@ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
3755
3960
  /*.slope =*/ slope
3756
3961
  };
3757
3962
 
3758
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
3963
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
3759
3964
 
3760
3965
  int64_t n = wsp_ggml_nelements(op);
3761
3966
 
@@ -3773,6 +3978,57 @@ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
3773
3978
  return 1;
3774
3979
  }
3775
3980
 
3981
+ int wsp_ggml_metal_op_tri(wsp_ggml_metal_op_t ctx, int idx) {
3982
+ wsp_ggml_tensor * op = ctx->node(idx);
3983
+
3984
+ wsp_ggml_metal_library_t lib = ctx->lib;
3985
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3986
+
3987
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3988
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3989
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3990
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3991
+
3992
+ wsp_ggml_metal_kargs_tri args = {
3993
+ /*.ne00 =*/ ne00,
3994
+ /*.ne01 =*/ ne01,
3995
+ /*.ne02 =*/ ne02,
3996
+ /*.ne03 =*/ ne03,
3997
+ /*.nb00 =*/ nb00,
3998
+ /*.nb01 =*/ nb01,
3999
+ /*.nb02 =*/ nb02,
4000
+ /*.nb03 =*/ nb03,
4001
+ /*.ne0 =*/ ne0,
4002
+ /*.ne1 =*/ ne1,
4003
+ /*.ne2 =*/ ne2,
4004
+ /*.ne3 =*/ ne3,
4005
+ /*.nb0 =*/ nb0,
4006
+ /*.nb1 =*/ nb1,
4007
+ /*.nb2 =*/ nb2,
4008
+ /*.nb3 =*/ nb3,
4009
+ };
4010
+
4011
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_tri(lib, op);
4012
+
4013
+ int nth = 32; // SIMD width
4014
+
4015
+ while (nth < ne00 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
4016
+ nth *= 2;
4017
+ }
4018
+
4019
+ nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4020
+ nth = std::min(nth, ne00);
4021
+
4022
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
4023
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
4024
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
4025
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
4026
+
4027
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
4028
+
4029
+ return 1;
4030
+ }
4031
+
3776
4032
  int wsp_ggml_metal_op_opt_step_adamw(wsp_ggml_metal_op_t ctx, int idx) {
3777
4033
  wsp_ggml_tensor * op = ctx->node(idx);
3778
4034
 
@@ -3784,7 +4040,7 @@ int wsp_ggml_metal_op_opt_step_adamw(wsp_ggml_metal_op_t ctx, int idx) {
3784
4040
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3785
4041
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3786
4042
 
3787
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
4043
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
3788
4044
 
3789
4045
  const int64_t np = wsp_ggml_nelements(op->src[0]);
3790
4046
  wsp_ggml_metal_kargs_opt_step_adamw args = {
@@ -3820,7 +4076,7 @@ int wsp_ggml_metal_op_opt_step_sgd(wsp_ggml_metal_op_t ctx, int idx) {
3820
4076
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3821
4077
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3822
4078
 
3823
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
4079
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
3824
4080
 
3825
4081
  const int64_t np = wsp_ggml_nelements(op->src[0]);
3826
4082
  wsp_ggml_metal_kargs_opt_step_sgd args = {
@@ -3842,3 +4098,64 @@ int wsp_ggml_metal_op_opt_step_sgd(wsp_ggml_metal_op_t ctx, int idx) {
3842
4098
 
3843
4099
  return 1;
3844
4100
  }
4101
+
4102
+ int wsp_ggml_metal_op_count_equal(wsp_ggml_metal_op_t ctx, int idx) {
4103
+ wsp_ggml_tensor * op = ctx->node(idx);
4104
+
4105
+ wsp_ggml_metal_library_t lib = ctx->lib;
4106
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
4107
+
4108
+ WSP_GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
4109
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4110
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
4111
+
4112
+ {
4113
+ wsp_ggml_metal_kargs_memset args = { /*.val =*/ 0 };
4114
+
4115
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_memset(lib, op);
4116
+
4117
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
4118
+ wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
4119
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 1);
4120
+
4121
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
4122
+ }
4123
+
4124
+ wsp_ggml_metal_op_concurrency_reset(ctx);
4125
+
4126
+ {
4127
+ wsp_ggml_metal_kargs_count_equal args = {
4128
+ /*.ne00 =*/ ne00,
4129
+ /*.ne01 =*/ ne01,
4130
+ /*.ne02 =*/ ne02,
4131
+ /*.ne03 =*/ ne03,
4132
+ /*.nb00 =*/ nb00,
4133
+ /*.nb01 =*/ nb01,
4134
+ /*.nb02 =*/ nb02,
4135
+ /*.nb03 =*/ nb03,
4136
+ /*.nb10 =*/ nb10,
4137
+ /*.nb11 =*/ nb11,
4138
+ /*.nb12 =*/ nb12,
4139
+ /*.nb13 =*/ nb13,
4140
+ };
4141
+
4142
+ auto pipeline = wsp_ggml_metal_library_get_pipeline_count_equal(lib, op);
4143
+
4144
+ const size_t smem = pipeline.smem;
4145
+
4146
+ const int nth = 32*pipeline.nsg;
4147
+
4148
+ WSP_GGML_ASSERT(nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4149
+
4150
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
4151
+ wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
4152
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
4153
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
4154
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
4155
+
4156
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
4157
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
4158
+ }
4159
+
4160
+ return 1;
4161
+ }