whisper.rn 0.5.1 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. package/android/src/main/jni.cpp +12 -3
  2. package/cpp/ggml-alloc.c +38 -14
  3. package/cpp/ggml-backend-impl.h +0 -3
  4. package/cpp/ggml-backend.h +2 -0
  5. package/cpp/ggml-cpu/amx/amx.cpp +1 -0
  6. package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
  7. package/cpp/ggml-cpu/ggml-cpu.c +17 -3
  8. package/cpp/ggml-cpu/ops.cpp +33 -17
  9. package/cpp/ggml-cpu/unary-ops.cpp +135 -0
  10. package/cpp/ggml-cpu/unary-ops.h +5 -0
  11. package/cpp/ggml-cpu/vec.cpp +66 -0
  12. package/cpp/ggml-cpu/vec.h +10 -8
  13. package/cpp/ggml-impl.h +51 -2
  14. package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
  15. package/cpp/ggml-metal/ggml-metal-device.cpp +199 -10
  16. package/cpp/ggml-metal/ggml-metal-device.h +18 -0
  17. package/cpp/ggml-metal/ggml-metal-device.m +27 -14
  18. package/cpp/ggml-metal/ggml-metal-impl.h +87 -7
  19. package/cpp/ggml-metal/ggml-metal-ops.cpp +513 -88
  20. package/cpp/ggml-metal/ggml-metal-ops.h +6 -0
  21. package/cpp/ggml-metal/ggml-metal.cpp +3 -3
  22. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  23. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  24. package/cpp/ggml.c +166 -2
  25. package/cpp/ggml.h +66 -0
  26. package/cpp/jsi/RNWhisperJSI.cpp +7 -2
  27. package/cpp/rn-whisper.h +1 -0
  28. package/cpp/whisper.cpp +4 -2
  29. package/ios/RNWhisperContext.mm +3 -1
  30. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  31. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  32. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +51 -2
  33. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +66 -0
  34. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  35. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  36. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  37. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  38. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  39. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +51 -2
  40. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +66 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  43. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  45. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  46. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +51 -2
  47. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +66 -0
  48. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  49. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  50. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  51. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  52. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  53. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +51 -2
  54. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +66 -0
  55. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  56. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  58. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  59. package/lib/commonjs/version.json +1 -1
  60. package/lib/module/NativeRNWhisper.js.map +1 -1
  61. package/lib/module/version.json +1 -1
  62. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  63. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  64. package/package.json +1 -1
  65. package/src/NativeRNWhisper.ts +2 -0
  66. package/src/version.json +1 -1
@@ -226,6 +226,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
226
226
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
227
227
  WSP_GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
228
228
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
229
+ WSP_GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
230
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
231
+ WSP_GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
232
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
229
233
  WSP_GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
230
234
  WSP_GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
231
235
 
@@ -237,6 +241,14 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
237
241
  WSP_GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
238
242
  wsp_ggml_is_contiguous(node->src[1]), node->src[1]->name);
239
243
  }
244
+ if (node->src[2]) {
245
+ WSP_GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
246
+ wsp_ggml_is_contiguous(node->src[2]), node->src[2]->name);
247
+ }
248
+ if (node->src[3]) {
249
+ WSP_GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
250
+ wsp_ggml_is_contiguous(node->src[3]), node->src[3]->name);
251
+ }
240
252
  if (node) {
241
253
  WSP_GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
242
254
  node->name);
@@ -289,6 +301,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
289
301
  {
290
302
  n_fuse = wsp_ggml_metal_op_glu(ctx, idx);
291
303
  } break;
304
+ case WSP_GGML_OP_SUM:
305
+ {
306
+ n_fuse = wsp_ggml_metal_op_sum(ctx, idx);
307
+ } break;
292
308
  case WSP_GGML_OP_SUM_ROWS:
293
309
  case WSP_GGML_OP_MEAN:
294
310
  {
@@ -352,6 +368,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
352
368
  {
353
369
  n_fuse = wsp_ggml_metal_op_conv_transpose_1d(ctx, idx);
354
370
  } break;
371
+ case WSP_GGML_OP_CONV_TRANSPOSE_2D:
372
+ {
373
+ n_fuse = wsp_ggml_metal_op_conv_transpose_2d(ctx, idx);
374
+ } break;
355
375
  case WSP_GGML_OP_UPSCALE:
356
376
  {
357
377
  n_fuse = wsp_ggml_metal_op_upscale(ctx, idx);
@@ -398,6 +418,14 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
398
418
  {
399
419
  n_fuse = wsp_ggml_metal_op_argmax(ctx, idx);
400
420
  } break;
421
+ case WSP_GGML_OP_OPT_STEP_ADAMW:
422
+ {
423
+ n_fuse = wsp_ggml_metal_op_opt_step_adamw(ctx, idx);
424
+ } break;
425
+ case WSP_GGML_OP_OPT_STEP_SGD:
426
+ {
427
+ n_fuse = wsp_ggml_metal_op_opt_step_sgd(ctx, idx);
428
+ } break;
401
429
  default:
402
430
  {
403
431
  WSP_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, wsp_ggml_op_name(node->op));
@@ -577,6 +605,7 @@ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
577
605
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
578
606
 
579
607
  wsp_ggml_metal_kargs_cpy args = {
608
+ /*.nk0 =*/ ne00,
580
609
  /*.ne00 =*/ ne00,
581
610
  /*.ne01 =*/ ne01,
582
611
  /*.ne02 =*/ ne02,
@@ -827,6 +856,43 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
827
856
  return 1;
828
857
  }
829
858
 
859
+ int wsp_ggml_metal_op_sum(wsp_ggml_metal_op_t ctx, int idx) {
860
+ wsp_ggml_tensor * op = ctx->node(idx);
861
+
862
+ wsp_ggml_metal_library_t lib = ctx->lib;
863
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
864
+
865
+ const uint64_t n = (uint64_t) wsp_ggml_nelements(op->src[0]);
866
+
867
+ wsp_ggml_metal_kargs_sum args = {
868
+ /*.np =*/ n,
869
+ };
870
+
871
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_sum(lib, op);
872
+
873
+ int nth = 32; // SIMD width
874
+
875
+ while (nth < (int) n && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
876
+ nth *= 2;
877
+ }
878
+
879
+ nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
880
+ nth = std::min(nth, (int) n);
881
+
882
+ const int nsg = (nth + 31) / 32;
883
+
884
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
885
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
886
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
887
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
888
+
889
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
890
+
891
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
892
+
893
+ return 1;
894
+ }
895
+
830
896
  int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
831
897
  wsp_ggml_tensor * op = ctx->node(idx);
832
898
 
@@ -906,23 +972,31 @@ int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
906
972
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
907
973
 
908
974
  wsp_ggml_metal_kargs_get_rows args = {
909
- /*.ne00 =*/ ne00,
910
- /*.nb01 =*/ nb01,
911
- /*.nb02 =*/ nb02,
912
- /*.ne10 =*/ ne10,
913
- /*.nb10 =*/ nb10,
914
- /*.nb11 =*/ nb11,
915
- /*.nb1 =*/ nb1,
916
- /*.nb2 =*/ nb2,
975
+ /*.ne00t =*/ wsp_ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
976
+ /*.ne00 =*/ ne00,
977
+ /*.nb01 =*/ nb01,
978
+ /*.nb02 =*/ nb02,
979
+ /*.nb03 =*/ nb03,
980
+ /*.ne10 =*/ ne10,
981
+ /*.nb10 =*/ nb10,
982
+ /*.nb11 =*/ nb11,
983
+ /*.nb12 =*/ nb12,
984
+ /*.nb1 =*/ nb1,
985
+ /*.nb2 =*/ nb2,
986
+ /*.nb3 =*/ nb3,
917
987
  };
918
988
 
989
+ const int nth = std::min(args.ne00t, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
990
+
991
+ const int nw0 = (args.ne00t + nth - 1)/nth;
992
+
919
993
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
920
994
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
921
995
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
922
996
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
923
997
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
924
998
 
925
- wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
999
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
926
1000
 
927
1001
  return 1;
928
1002
  }
@@ -1117,7 +1191,7 @@ int wsp_ggml_metal_op_ssm_conv(wsp_ggml_metal_op_t ctx, int idx) {
1117
1191
  wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1118
1192
  wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1119
1193
  wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
1120
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
1194
+ wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
1121
1195
 
1122
1196
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1123
1197
 
@@ -1172,25 +1246,36 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
1172
1246
  /*.n_seq_tokens =*/ n_seq_tokens,
1173
1247
  /*.n_seqs =*/ n_seqs,
1174
1248
  /*.s_off =*/ wsp_ggml_nelements(op->src[1]) * sizeof(float),
1249
+ /*.nb00 =*/ nb00,
1175
1250
  /*.nb01 =*/ nb01,
1176
1251
  /*.nb02 =*/ nb02,
1177
1252
  /*.nb03 =*/ nb03,
1253
+ /*.nb10 =*/ nb10,
1178
1254
  /*.nb11 =*/ nb11,
1179
1255
  /*.nb12 =*/ nb12,
1256
+ /*.ns12 =*/ nb12/nb10,
1180
1257
  /*.nb13 =*/ nb13,
1258
+ /*.nb20 =*/ nb20,
1181
1259
  /*.nb21 =*/ nb21,
1260
+ /*.ns21 =*/ nb21/nb20,
1182
1261
  /*.nb22 =*/ nb22,
1262
+ /*.ne30 =*/ ne30,
1183
1263
  /*.nb31 =*/ nb31,
1184
1264
  /*.nb41 =*/ nb41,
1185
1265
  /*.nb42 =*/ nb42,
1266
+ /*.ns42 =*/ nb42/nb40,
1186
1267
  /*.nb43 =*/ nb43,
1187
1268
  /*.nb51 =*/ nb51,
1188
1269
  /*.nb52 =*/ nb52,
1270
+ /*.ns52 =*/ nb52/nb50,
1189
1271
  /*.nb53 =*/ nb53,
1272
+ /*.nb0 =*/ nb0,
1190
1273
  };
1191
1274
 
1192
1275
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_ssm_scan(lib, op);
1193
1276
 
1277
+ WSP_GGML_ASSERT(d_state <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1278
+
1194
1279
  const size_t sms = wsp_ggml_metal_pipeline_get_smem(pipeline);
1195
1280
 
1196
1281
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -1206,13 +1291,7 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
1206
1291
 
1207
1292
  wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
1208
1293
 
1209
- if (ne30 == 1) {
1210
- // Mamba-2
1211
- wsp_ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1212
- } else {
1213
- WSP_GGML_ASSERT(d_inner == 1);
1214
- wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
1215
- }
1294
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1216
1295
 
1217
1296
  return 1;
1218
1297
  }
@@ -1273,26 +1352,23 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
1273
1352
 
1274
1353
  WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(op->src[0]->type) == 0);
1275
1354
 
1276
- // TODO: support
1277
- //const int32_t nk00 = ne00/wsp_ggml_blck_size(op->type);
1278
- const int32_t nk00 = ne00;
1279
-
1280
- int nth = 32; // SIMD width
1281
-
1282
- while (nth < nk00 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1283
- nth *= 2;
1355
+ int64_t nk0 = ne00;
1356
+ if (wsp_ggml_is_quantized(op->src[0]->type)) {
1357
+ nk0 = ne00/16;
1358
+ } else if (wsp_ggml_is_quantized(op->type)) {
1359
+ nk0 = ne00/wsp_ggml_blck_size(op->type);
1284
1360
  }
1285
1361
 
1286
- nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1362
+ int nth = std::min<int>(nk0, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1287
1363
 
1288
1364
  // when rows are small, we can batch them together in a single threadgroup
1289
1365
  int nrptg = 1;
1290
1366
 
1291
1367
  // TODO: relax this constraint in the future
1292
1368
  if (wsp_ggml_blck_size(op->src[0]->type) == 1 && wsp_ggml_blck_size(op->type) == 1) {
1293
- if (nth > nk00) {
1294
- nrptg = (nth + nk00 - 1)/nk00;
1295
- nth = nk00;
1369
+ if (nth > nk0) {
1370
+ nrptg = (nth + nk0 - 1)/nk0;
1371
+ nth = nk0;
1296
1372
 
1297
1373
  if (nrptg*nth > wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1298
1374
  nrptg--;
@@ -1300,10 +1376,11 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
1300
1376
  }
1301
1377
  }
1302
1378
 
1303
- nth = std::min(nth, nk00);
1379
+ nth = std::min<int>(nth, nk0);
1304
1380
 
1305
1381
  wsp_ggml_metal_kargs_cpy args = {
1306
- /*.ne00 =*/ nk00,
1382
+ /*.nk0 =*/ nk0,
1383
+ /*.ne00 =*/ ne00,
1307
1384
  /*.ne01 =*/ ne01,
1308
1385
  /*.ne02 =*/ ne02,
1309
1386
  /*.ne03 =*/ ne03,
@@ -1321,12 +1398,14 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
1321
1398
  /*.nb3 =*/ nb3,
1322
1399
  };
1323
1400
 
1401
+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
1402
+
1324
1403
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
1325
1404
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1326
1405
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
1327
1406
  wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
1328
1407
 
1329
- wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
1408
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1330
1409
 
1331
1410
  return 1;
1332
1411
  }
@@ -1520,9 +1599,8 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
1520
1599
  !wsp_ggml_is_transposed(op->src[1]) &&
1521
1600
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1522
1601
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1523
- props_dev->has_simdgroup_mm && ne00 >= 64 &&
1524
- (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
1525
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1602
+ props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
1603
+ //WSP_GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1526
1604
 
1527
1605
  // some Metal matrix data types require aligned pointers
1528
1606
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1875,20 +1953,107 @@ bool wsp_ggml_metal_op_flash_attn_ext_use_vec(const wsp_ggml_tensor * op) {
1875
1953
  return (ne01 < 20) && (ne00 % 32 == 0);
1876
1954
  }
1877
1955
 
1956
+ size_t wsp_ggml_metal_op_flash_attn_ext_extra_pad(const wsp_ggml_tensor * op) {
1957
+ assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
1958
+
1959
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1960
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1961
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1962
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1963
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1964
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1965
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
1966
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
1967
+
1968
+ size_t res = 0;
1969
+
1970
+ const bool has_mask = op->src[3] != nullptr;
1971
+
1972
+ if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
1973
+ const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
1974
+
1975
+ if (has_kvpad) {
1976
+ res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
1977
+ nb11*ne12*ne13 +
1978
+ nb21*ne22*ne23 +
1979
+ (has_mask ? wsp_ggml_type_size(WSP_GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
1980
+ }
1981
+ } else {
1982
+ const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
1983
+
1984
+ if (has_kvpad) {
1985
+ res += OP_FLASH_ATTN_EXT_NCPSG*(
1986
+ nb11*ne12*ne13 +
1987
+ nb21*ne22*ne23 +
1988
+ (has_mask ? wsp_ggml_type_size(WSP_GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
1989
+ }
1990
+ }
1991
+
1992
+ return res;
1993
+ }
1994
+
1995
+ size_t wsp_ggml_metal_op_flash_attn_ext_extra_blk(const wsp_ggml_tensor * op) {
1996
+ assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
1997
+
1998
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1999
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2000
+ //WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2001
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2002
+ //WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2003
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2004
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2005
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2006
+
2007
+ size_t res = 0;
2008
+
2009
+ const bool has_mask = op->src[3] != nullptr;
2010
+
2011
+ if (!has_mask) {
2012
+ return res;
2013
+ }
2014
+
2015
+ const bool is_vec = wsp_ggml_metal_op_flash_attn_ext_use_vec(op);
2016
+
2017
+ // this optimization is not useful for the vector kernels
2018
+ if (is_vec) {
2019
+ return res;
2020
+ }
2021
+
2022
+ const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
2023
+ const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
2024
+
2025
+ const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
2026
+ const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
2027
+
2028
+ res += WSP_GGML_PAD(wsp_ggml_type_size(WSP_GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
2029
+
2030
+ return res;
2031
+ }
2032
+
1878
2033
  size_t wsp_ggml_metal_op_flash_attn_ext_extra_tmp(const wsp_ggml_tensor * op) {
1879
2034
  assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
1880
2035
 
1881
- const int64_t nwg = 32;
2036
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2037
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2038
+ //WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2039
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2040
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2041
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2042
+ //WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2043
+ //WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2044
+
2045
+ size_t res = 0;
1882
2046
 
1883
- const int64_t ne01 = op->src[0]->ne[1];
1884
- const int64_t ne02 = op->src[0]->ne[2];
1885
- const int64_t ne03 = op->src[0]->ne[3];
1886
- const int64_t ne20 = op->src[2]->ne[0];
2047
+ if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
2048
+ const int64_t nwg = 32;
1887
2049
 
1888
- // temp buffer for writing the results from each workgroup
1889
- // - ne20: the size of the Value head
1890
- // - + 2: the S and M values for each intermediate result
1891
- return wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
2050
+ // temp buffer for writing the results from each workgroup
2051
+ // - ne20: the size of the Value head
2052
+ // - + 2: the S and M values for each intermediate result
2053
+ res += wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
2054
+ }
2055
+
2056
+ return res;
1892
2057
  }
1893
2058
 
1894
2059
  int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
@@ -1910,8 +2075,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
1910
2075
  WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1911
2076
  WSP_GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
1912
2077
 
1913
- WSP_GGML_ASSERT(ne00 % 4 == 0);
1914
- WSP_GGML_ASSERT(ne11 % 32 == 0);
2078
+ WSP_GGML_ASSERT(ne00 % 4 == 0);
1915
2079
 
1916
2080
  WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
1917
2081
  WSP_GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@@ -1921,8 +2085,8 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
1921
2085
  WSP_GGML_ASSERT(ne12 == ne22);
1922
2086
 
1923
2087
  WSP_GGML_ASSERT(!op->src[3] || op->src[3]->type == WSP_GGML_TYPE_F16);
1924
- WSP_GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= WSP_GGML_PAD(op->src[0]->ne[1], 8) &&
1925
- "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2088
+ WSP_GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
2089
+ "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
1926
2090
 
1927
2091
  float scale;
1928
2092
  float max_bias;
@@ -1949,15 +2113,111 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
1949
2113
 
1950
2114
  WSP_GGML_ASSERT(ne01 < 65536);
1951
2115
 
2116
+ wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
2117
+ wsp_ggml_metal_buffer_id bid_src1 = wsp_ggml_metal_get_buffer_id(op->src[1]);
2118
+ wsp_ggml_metal_buffer_id bid_src2 = wsp_ggml_metal_get_buffer_id(op->src[2]);
2119
+ wsp_ggml_metal_buffer_id bid_src3 = has_mask ? wsp_ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
2120
+ wsp_ggml_metal_buffer_id bid_src4 = has_sinks ? wsp_ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
2121
+
2122
+ wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
2123
+
2124
+ wsp_ggml_metal_buffer_id bid_pad = bid_dst;
2125
+ bid_pad.offs += wsp_ggml_nbytes(op);
2126
+
2127
+ wsp_ggml_metal_buffer_id bid_blk = bid_pad;
2128
+ bid_blk.offs += wsp_ggml_metal_op_flash_attn_ext_extra_pad(op);
2129
+
2130
+ wsp_ggml_metal_buffer_id bid_tmp = bid_blk;
2131
+ bid_tmp.offs += wsp_ggml_metal_op_flash_attn_ext_extra_blk(op);
2132
+
1952
2133
  if (!wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
1953
2134
  // half8x8 kernel
1954
- const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
1955
- const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
2135
+ const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
2136
+ const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
1956
2137
 
1957
2138
  WSP_GGML_ASSERT(nqptg <= 32);
1958
2139
  WSP_GGML_ASSERT(nqptg % 8 == 0);
1959
2140
  WSP_GGML_ASSERT(ncpsg % 32 == 0);
1960
2141
 
2142
+ bool need_sync = false;
2143
+
2144
+ const bool has_kvpad = ne11 % ncpsg != 0;
2145
+
2146
+ if (has_kvpad) {
2147
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2148
+
2149
+ wsp_ggml_metal_kargs_flash_attn_ext_pad args0 = {
2150
+ /*.ne11 =*/ne11,
2151
+ /*.ne_12_2 =*/ne12,
2152
+ /*.ne_12_3 =*/ne13,
2153
+ /*.nb11 =*/nb11,
2154
+ /*.nb12 =*/nb12,
2155
+ /*.nb13 =*/nb13,
2156
+ /*.nb21 =*/nb21,
2157
+ /*.nb22 =*/nb22,
2158
+ /*.nb23 =*/nb23,
2159
+ /*.ne31 =*/ne31,
2160
+ /*.ne32 =*/ne32,
2161
+ /*.ne33 =*/ne33,
2162
+ /*.nb31 =*/nb31,
2163
+ /*.nb32 =*/nb32,
2164
+ /*.nb33 =*/nb33,
2165
+ };
2166
+
2167
+ wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2168
+
2169
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2170
+ wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2171
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2172
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2173
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2174
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2175
+
2176
+ assert(ne12 == ne22);
2177
+ assert(ne13 == ne23);
2178
+
2179
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2180
+
2181
+ need_sync = true;
2182
+ } else {
2183
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
2184
+ }
2185
+
2186
+ if (has_mask) {
2187
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
2188
+
2189
+ wsp_ggml_metal_kargs_flash_attn_ext_blk args0 = {
2190
+ /*.ne01 =*/ ne01,
2191
+ /*.ne30 =*/ ne30,
2192
+ /*.ne31 =*/ ne31,
2193
+ /*.ne32 =*/ ne32,
2194
+ /*.ne33 =*/ ne33,
2195
+ /*.nb31 =*/ nb31,
2196
+ /*.nb32 =*/ nb32,
2197
+ /*.nb33 =*/ nb33,
2198
+ };
2199
+
2200
+ wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
2201
+
2202
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2203
+ wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2204
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
2205
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
2206
+
2207
+ const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
2208
+ const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
2209
+
2210
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
2211
+
2212
+ need_sync = true;
2213
+ } else {
2214
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
2215
+ }
2216
+
2217
+ if (need_sync) {
2218
+ wsp_ggml_metal_op_concurrency_reset(ctx);
2219
+ }
2220
+
1961
2221
  const int is_q = wsp_ggml_is_quantized(op->src[1]->type) ? 1 : 0;
1962
2222
 
1963
2223
  // 2*(2*ncpsg)
@@ -2007,6 +2267,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2007
2267
  /*.nb21 =*/ nb21,
2008
2268
  /*.nb22 =*/ nb22,
2009
2269
  /*.nb23 =*/ nb23,
2270
+ /*.ne31 =*/ ne31,
2010
2271
  /*.ne32 =*/ ne32,
2011
2272
  /*.ne33 =*/ ne33,
2012
2273
  /*.nb31 =*/ nb31,
@@ -2023,24 +2284,18 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2023
2284
  /*.logit_softcap =*/ logit_softcap,
2024
2285
  };
2025
2286
 
2026
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
2287
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
2027
2288
 
2028
2289
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2029
2290
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2030
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
2031
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
2032
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), 3);
2033
- if (op->src[3]) {
2034
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[3]), 4);
2035
- } else {
2036
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 4);
2037
- }
2038
- if (op->src[4]) {
2039
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[4]), 5);
2040
- } else {
2041
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 5);
2042
- }
2043
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 6);
2291
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2292
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2293
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2294
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2295
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2296
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
2297
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
2298
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
2044
2299
 
2045
2300
  wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2046
2301
 
@@ -2048,14 +2303,62 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2048
2303
  #undef FATTN_SMEM
2049
2304
  } else {
2050
2305
  // half4x4 kernel
2051
- const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
2052
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2053
- const int64_t nkpsg = 1*ncpsg;
2306
+ const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
2307
+ const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
2308
+ const int nkpsg = 1*ncpsg;
2054
2309
 
2055
2310
  WSP_GGML_ASSERT(nqptg <= 32);
2056
2311
  WSP_GGML_ASSERT(nqptg % 1 == 0);
2057
2312
  WSP_GGML_ASSERT(ncpsg % 32 == 0);
2058
2313
 
2314
+ bool need_sync = false;
2315
+
2316
+ const bool has_kvpad = ne11 % ncpsg != 0;
2317
+
2318
+ if (has_kvpad) {
2319
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2320
+
2321
+ wsp_ggml_metal_kargs_flash_attn_ext_pad args0 = {
2322
+ /*.ne11 =*/ne11,
2323
+ /*.ne_12_2 =*/ne12,
2324
+ /*.ne_12_3 =*/ne13,
2325
+ /*.nb11 =*/nb11,
2326
+ /*.nb12 =*/nb12,
2327
+ /*.nb13 =*/nb13,
2328
+ /*.nb21 =*/nb21,
2329
+ /*.nb22 =*/nb22,
2330
+ /*.nb23 =*/nb23,
2331
+ /*.ne31 =*/ne31,
2332
+ /*.ne32 =*/ne32,
2333
+ /*.ne33 =*/ne33,
2334
+ /*.nb31 =*/nb31,
2335
+ /*.nb32 =*/nb32,
2336
+ /*.nb33 =*/nb33,
2337
+ };
2338
+
2339
+ wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2340
+
2341
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
2342
+ wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2343
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2344
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2345
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2346
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2347
+
2348
+ assert(ne12 == ne22);
2349
+ assert(ne13 == ne23);
2350
+
2351
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2352
+
2353
+ need_sync = true;
2354
+ } else {
2355
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
2356
+ }
2357
+
2358
+ if (need_sync) {
2359
+ wsp_ggml_metal_op_concurrency_reset(ctx);
2360
+ }
2361
+
2059
2362
  // ne00 + 2*ncpsg*(nsg)
2060
2363
  // for each query, we load it as f16 in shared memory (ne00)
2061
2364
  // and store the soft_max values and the mask
@@ -2120,6 +2423,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2120
2423
  /*.nb21 =*/ nb21,
2121
2424
  /*.nb22 =*/ nb22,
2122
2425
  /*.nb23 =*/ nb23,
2426
+ /*.ne31 =*/ ne31,
2123
2427
  /*.ne32 =*/ ne32,
2124
2428
  /*.ne33 =*/ ne33,
2125
2429
  /*.nb31 =*/ nb31,
@@ -2136,25 +2440,17 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2136
2440
  /*.logit_softcap =*/ logit_softcap,
2137
2441
  };
2138
2442
 
2139
- wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
2443
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
2140
2444
 
2141
2445
  WSP_GGML_ASSERT(nsg*32 <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2142
2446
 
2143
2447
  wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
2144
2448
  wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2145
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
2146
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
2147
- wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), 3);
2148
- if (op->src[3]) {
2149
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[3]), 4);
2150
- } else {
2151
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 4);
2152
- }
2153
- if (op->src[4]) {
2154
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[4]), 5);
2155
- } else {
2156
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 5);
2157
- }
2449
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2450
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2451
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2452
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2453
+ wsp_ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2158
2454
 
2159
2455
  const size_t smem = FATTN_SMEM(nsg);
2160
2456
 
@@ -2162,23 +2458,25 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
2162
2458
  WSP_GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
2163
2459
 
2164
2460
  if (nwg == 1) {
2461
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
2462
+
2165
2463
  // using 1 workgroup -> write the result directly into dst
2166
- wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 6);
2464
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2465
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
2167
2466
 
2168
2467
  wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2169
2468
 
2170
2469
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
2171
2470
  } else {
2172
2471
  // sanity checks
2472
+ assert(wsp_ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
2473
+
2173
2474
  WSP_GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
2174
2475
  WSP_GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
2175
2476
 
2176
- wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
2177
-
2178
2477
  // write the results from each workgroup into a temp buffer
2179
- wsp_ggml_metal_buffer_id bid_tmp = bid_dst;
2180
- bid_tmp.offs += wsp_ggml_nbytes(op);
2181
- wsp_ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
2478
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2479
+ wsp_ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
2182
2480
 
2183
2481
  wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2184
2482
  wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
@@ -2688,6 +2986,7 @@ int wsp_ggml_metal_op_rope(wsp_ggml_metal_op_t ctx, int idx) {
2688
2986
  /* sect_1 =*/ sect_1,
2689
2987
  /* sect_2 =*/ sect_2,
2690
2988
  /* sect_3 =*/ sect_3,
2989
+ /* src2 =*/ op->src[2] != nullptr,
2691
2990
  };
2692
2991
 
2693
2992
  wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_rope(lib, op);
@@ -2823,6 +3122,62 @@ int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
2823
3122
  return 1;
2824
3123
  }
2825
3124
 
3125
+ int wsp_ggml_metal_op_conv_transpose_2d(wsp_ggml_metal_op_t ctx, int idx) {
3126
+ wsp_ggml_tensor * op = ctx->node(idx);
3127
+
3128
+ wsp_ggml_metal_library_t lib = ctx->lib;
3129
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3130
+
3131
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3132
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3133
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3134
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3135
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3136
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3137
+
3138
+ const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3139
+
3140
+ const int32_t IC = op->src[1]->ne[2];
3141
+ const int32_t IH = op->src[1]->ne[1];
3142
+ const int32_t IW = op->src[1]->ne[0];
3143
+
3144
+ const int32_t KH = op->src[0]->ne[1];
3145
+ const int32_t KW = op->src[0]->ne[0];
3146
+
3147
+ const int32_t OW = op->ne[0];
3148
+ const int32_t OH = op->ne[1];
3149
+ const int32_t OC = op->ne[2];
3150
+
3151
+ wsp_ggml_metal_kargs_conv_transpose_2d args = {
3152
+ /*.IC =*/ IC,
3153
+ /*.IH =*/ IH,
3154
+ /*.IW =*/ IW,
3155
+ /*.KH =*/ KH,
3156
+ /*.KW =*/ KW,
3157
+ /*.OC =*/ OC,
3158
+ /*.s0 =*/ s0,
3159
+ /*.nb0 =*/ nb0,
3160
+ /*.nb1 =*/ nb1,
3161
+ /*.nb2 =*/ nb2,
3162
+ };
3163
+
3164
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
3165
+
3166
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3167
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3168
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
3169
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
3170
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
3171
+
3172
+ // Metal requires buffer size to be multiple of 16 bytes
3173
+ const size_t smem = WSP_GGML_PAD(KW * KH * sizeof(float), 16);
3174
+ wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3175
+
3176
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
3177
+
3178
+ return 1;
3179
+ }
3180
+
2826
3181
  int wsp_ggml_metal_op_upscale(wsp_ggml_metal_op_t ctx, int idx) {
2827
3182
  wsp_ggml_tensor * op = ctx->node(idx);
2828
3183
 
@@ -3156,3 +3511,73 @@ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
3156
3511
 
3157
3512
  return 1;
3158
3513
  }
3514
+
3515
+ int wsp_ggml_metal_op_opt_step_adamw(wsp_ggml_metal_op_t ctx, int idx) {
3516
+ wsp_ggml_tensor * op = ctx->node(idx);
3517
+
3518
+ wsp_ggml_metal_library_t lib = ctx->lib;
3519
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3520
+
3521
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3522
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3523
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3524
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3525
+
3526
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
3527
+
3528
+ const int64_t np = wsp_ggml_nelements(op->src[0]);
3529
+ wsp_ggml_metal_kargs_opt_step_adamw args = {
3530
+ /*.np =*/ np,
3531
+ };
3532
+
3533
+ int ida = 0;
3534
+
3535
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3536
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
3537
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), ida++);
3538
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), ida++);
3539
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), ida++);
3540
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[3]), ida++);
3541
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[4]), ida++);
3542
+
3543
+ const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3544
+ const int64_t n = (np + nth - 1) / nth;
3545
+
3546
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
3547
+
3548
+ return 1;
3549
+ }
3550
+
3551
+ int wsp_ggml_metal_op_opt_step_sgd(wsp_ggml_metal_op_t ctx, int idx) {
3552
+ wsp_ggml_tensor * op = ctx->node(idx);
3553
+
3554
+ wsp_ggml_metal_library_t lib = ctx->lib;
3555
+ wsp_ggml_metal_encoder_t enc = ctx->enc;
3556
+
3557
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3558
+ WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3559
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3560
+ WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3561
+
3562
+ wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
3563
+
3564
+ const int64_t np = wsp_ggml_nelements(op->src[0]);
3565
+ wsp_ggml_metal_kargs_opt_step_sgd args = {
3566
+ /*.np =*/ np,
3567
+ };
3568
+
3569
+ int ida = 0;
3570
+
3571
+ wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
3572
+ wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
3573
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), ida++);
3574
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), ida++);
3575
+ wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), ida++);
3576
+
3577
+ const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3578
+ const int64_t n = (np + nth - 1) / nth;
3579
+
3580
+ wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
3581
+
3582
+ return 1;
3583
+ }