whisper.rn 0.5.1 → 0.5.2
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- package/android/src/main/jni.cpp +12 -3
- package/cpp/ggml-alloc.c +38 -14
- package/cpp/ggml-backend-impl.h +0 -3
- package/cpp/ggml-backend.h +2 -0
- package/cpp/ggml-cpu/amx/amx.cpp +1 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
- package/cpp/ggml-cpu/ggml-cpu.c +17 -3
- package/cpp/ggml-cpu/ops.cpp +33 -17
- package/cpp/ggml-cpu/unary-ops.cpp +135 -0
- package/cpp/ggml-cpu/unary-ops.h +5 -0
- package/cpp/ggml-cpu/vec.cpp +66 -0
- package/cpp/ggml-cpu/vec.h +10 -8
- package/cpp/ggml-impl.h +51 -2
- package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
- package/cpp/ggml-metal/ggml-metal-device.cpp +199 -10
- package/cpp/ggml-metal/ggml-metal-device.h +18 -0
- package/cpp/ggml-metal/ggml-metal-device.m +27 -14
- package/cpp/ggml-metal/ggml-metal-impl.h +87 -7
- package/cpp/ggml-metal/ggml-metal-ops.cpp +513 -88
- package/cpp/ggml-metal/ggml-metal-ops.h +6 -0
- package/cpp/ggml-metal/ggml-metal.cpp +3 -3
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +166 -2
- package/cpp/ggml.h +66 -0
- package/cpp/jsi/RNWhisperJSI.cpp +7 -2
- package/cpp/rn-whisper.h +1 -0
- package/cpp/whisper.cpp +4 -2
- package/ios/RNWhisperContext.mm +3 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +51 -2
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +66 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +51 -2
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +66 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +51 -2
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +66 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +51 -2
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +66 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +2 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +2 -0
- package/src/version.json +1 -1
--- a/package/cpp/ggml-metal/ggml-metal-ops.cpp
+++ b/package/cpp/ggml-metal/ggml-metal-ops.cpp
@@ -226,6 +226,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
+    WSP_GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
+    WSP_GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
     WSP_GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);

@@ -237,6 +241,14 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
         WSP_GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
                 wsp_ggml_is_contiguous(node->src[1]), node->src[1]->name);
     }
+    if (node->src[2]) {
+        WSP_GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
+                wsp_ggml_is_contiguous(node->src[2]), node->src[2]->name);
+    }
+    if (node->src[3]) {
+        WSP_GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
+                wsp_ggml_is_contiguous(node->src[3]), node->src[3]->name);
+    }
     if (node) {
         WSP_GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
                 node->name);
@@ -289,6 +301,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
                 {
                     n_fuse = wsp_ggml_metal_op_glu(ctx, idx);
                 } break;
+            case WSP_GGML_OP_SUM:
+                {
+                    n_fuse = wsp_ggml_metal_op_sum(ctx, idx);
+                } break;
             case WSP_GGML_OP_SUM_ROWS:
             case WSP_GGML_OP_MEAN:
                 {
@@ -352,6 +368,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
                 {
                     n_fuse = wsp_ggml_metal_op_conv_transpose_1d(ctx, idx);
                 } break;
+            case WSP_GGML_OP_CONV_TRANSPOSE_2D:
+                {
+                    n_fuse = wsp_ggml_metal_op_conv_transpose_2d(ctx, idx);
+                } break;
             case WSP_GGML_OP_UPSCALE:
                 {
                     n_fuse = wsp_ggml_metal_op_upscale(ctx, idx);
@@ -398,6 +418,14 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
                 {
                     n_fuse = wsp_ggml_metal_op_argmax(ctx, idx);
                 } break;
+            case WSP_GGML_OP_OPT_STEP_ADAMW:
+                {
+                    n_fuse = wsp_ggml_metal_op_opt_step_adamw(ctx, idx);
+                } break;
+            case WSP_GGML_OP_OPT_STEP_SGD:
+                {
+                    n_fuse = wsp_ggml_metal_op_opt_step_sgd(ctx, idx);
+                } break;
             default:
                 {
                     WSP_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, wsp_ggml_op_name(node->op));
@@ -577,6 +605,7 @@ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);

     wsp_ggml_metal_kargs_cpy args = {
+        /*.nk0  =*/ ne00,
         /*.ne00 =*/ ne00,
         /*.ne01 =*/ ne01,
         /*.ne02 =*/ ne02,
@@ -827,6 +856,43 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
     return 1;
 }

+int wsp_ggml_metal_op_sum(wsp_ggml_metal_op_t ctx, int idx) {
+    wsp_ggml_tensor * op = ctx->node(idx);
+
+    wsp_ggml_metal_library_t lib = ctx->lib;
+    wsp_ggml_metal_encoder_t enc = ctx->enc;
+
+    const uint64_t n = (uint64_t) wsp_ggml_nelements(op->src[0]);
+
+    wsp_ggml_metal_kargs_sum args = {
+        /*.np =*/ n,
+    };
+
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_sum(lib, op);
+
+    int nth = 32; // SIMD width
+
+    while (nth < (int) n && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
+    }
+
+    nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+    nth = std::min(nth, (int) n);
+
+    const int nsg = (nth + 31) / 32;
+
+    wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
+    wsp_ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op), 2);
+
+    wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
+
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
+
+    return 1;
+}
+
 int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_tensor * op = ctx->node(idx);

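The new `wsp_ggml_metal_op_sum` encoder reduces the whole tensor with a single threadgroup, doubling the thread count from one SIMD group until it covers the element count or hits the pipeline limit. A standalone sketch of that sizing logic, with `max_threads` as a stand-in for `wsp_ggml_metal_pipeline_max_theads_per_threadgroup`:

```cpp
// Sketch of the threadgroup-size selection used by the sum kernel above:
// start at one SIMD group (32 threads) and double until the element count
// or the pipeline limit is covered. Inputs are hypothetical.
#include <algorithm>
#include <cstdio>

static int pick_nth(int n, int max_threads) {
    int nth = 32; // SIMD width
    while (nth < n && nth < max_threads) {
        nth *= 2;
    }
    nth = std::min(nth, max_threads);
    nth = std::min(nth, n);
    return nth;
}

int main() {
    // e.g. 1000 elements on a pipeline allowing 1024 threads -> 1000 threads,
    // grouped into (1000 + 31)/32 = 32 simdgroups, one float of smem each
    const int nth = pick_nth(1000, 1024);
    const int nsg = (nth + 31) / 32;
    std::printf("nth = %d, nsg = %d\n", nth, nsg);
    return 0;
}
```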
@@ -906,23 +972,31 @@ int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);

     wsp_ggml_metal_kargs_get_rows args = {
-        /*.
-        /*.
-        /*.
-        /*.
-        /*.
-        /*.
-        /*.
-        /*.
+        /*.ne00t =*/ wsp_ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
+        /*.ne00  =*/ ne00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne10  =*/ ne10,
+        /*.nb10  =*/ nb10,
+        /*.nb11  =*/ nb11,
+        /*.nb12  =*/ nb12,
+        /*.nb1   =*/ nb1,
+        /*.nb2   =*/ nb2,
+        /*.nb3   =*/ nb3,
     };

+    const int nth = std::min(args.ne00t, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+    const int nw0 = (args.ne00t + nth - 1)/nth;
+
     wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
     wsp_ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
     wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
     wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op), 3);

-    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12,
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);

     return 1;
 }
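In the reworked `get_rows`, a row is measured in `ne00t` work items (quantized types are processed in 16-element blocks) and split across `nw0` threadgroups per row. A minimal sketch of the decomposition, with hypothetical sizes and `max_threads` standing in for the pipeline limit:

```cpp
// Sketch of the get_rows grid decomposition: ne00t work items per row
// (ne00/16 for quantized types), covered by nw0 threadgroups of nth threads.
#include <algorithm>
#include <cstdio>

int main() {
    const bool quantized   = true;
    const int  ne00        = 4096;                     // row length in elements
    const int  ne00t       = quantized ? ne00/16 : ne00;
    const int  max_threads = 1024;

    const int nth = std::min(ne00t, max_threads);
    const int nw0 = (ne00t + nth - 1)/nth;             // ceil-div: groups per row

    // the dispatch then launches nw0*ne10 x ne11 x ne12 threadgroups
    std::printf("ne00t = %d, nth = %d, nw0 = %d\n", ne00t, nth, nw0);
    return 0;
}
```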
@@ -1117,7 +1191,7 @@ int wsp_ggml_metal_op_ssm_conv(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
     wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
     wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
-    wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op),
+    wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);

     wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);

@@ -1172,25 +1246,36 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.s_off        =*/ wsp_ggml_nelements(op->src[1]) * sizeof(float),
+        /*.nb00         =*/ nb00,
         /*.nb01         =*/ nb01,
         /*.nb02         =*/ nb02,
         /*.nb03         =*/ nb03,
+        /*.nb10         =*/ nb10,
         /*.nb11         =*/ nb11,
         /*.nb12         =*/ nb12,
+        /*.ns12         =*/ nb12/nb10,
         /*.nb13         =*/ nb13,
+        /*.nb20         =*/ nb20,
         /*.nb21         =*/ nb21,
+        /*.ns21         =*/ nb21/nb20,
         /*.nb22         =*/ nb22,
+        /*.ne30         =*/ ne30,
         /*.nb31         =*/ nb31,
         /*.nb41         =*/ nb41,
         /*.nb42         =*/ nb42,
+        /*.ns42         =*/ nb42/nb40,
         /*.nb43         =*/ nb43,
         /*.nb51         =*/ nb51,
         /*.nb52         =*/ nb52,
+        /*.ns52         =*/ nb52/nb50,
         /*.nb53         =*/ nb53,
+        /*.nb0          =*/ nb0,
     };

     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_ssm_scan(lib, op);

+    WSP_GGML_ASSERT(d_state <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
     const size_t sms = wsp_ggml_metal_pipeline_get_smem(pipeline);

     wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -1206,13 +1291,7 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {

     wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);

-
-        // Mamba-2
-        wsp_ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
-    } else {
-        WSP_GGML_ASSERT(d_inner == 1);
-        wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
-    }
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);

     return 1;
 }
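The new `ns12`/`ns21`/`ns42`/`ns52` kargs fields convert ggml's byte strides into element strides by dividing by the innermost byte stride, so the kernel can index elements directly. A small sketch with a hypothetical f32 shape:

```cpp
// Sketch of the byte-stride -> element-stride conversion behind the new
// nsXY fields: ggml nbXY strides are in bytes, so dividing by the
// dimension-0 byte stride yields a stride in elements.
#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical f32 tensor of shape [64, 8, ...]:
    // nb10 = sizeof(float), nb12 = 64*8*sizeof(float)
    const uint64_t nb10 = 4;
    const uint64_t nb12 = 64*8*4;

    const uint64_t ns12 = nb12/nb10; // 512 elements between dim-2 slices
    std::printf("ns12 = %llu\n", (unsigned long long) ns12);
    return 0;
}
```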
@@ -1273,26 +1352,23 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {

     WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(op->src[0]->type) == 0);

-
-
-
-
-
-
-    while (nth < nk00 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
-        nth *= 2;
+    int64_t nk0 = ne00;
+    if (wsp_ggml_is_quantized(op->src[0]->type)) {
+        nk0 = ne00/16;
+    } else if (wsp_ggml_is_quantized(op->type)) {
+        nk0 = ne00/wsp_ggml_blck_size(op->type);
     }

-    nth = std::min(
+    int nth = std::min<int>(nk0, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

     // when rows are small, we can batch them together in a single threadgroup
     int nrptg = 1;

     // TODO: relax this constraint in the future
     if (wsp_ggml_blck_size(op->src[0]->type) == 1 && wsp_ggml_blck_size(op->type) == 1) {
-        if (nth >
-            nrptg = (nth +
-            nth =
+        if (nth > nk0) {
+            nrptg = (nth + nk0 - 1)/nk0;
+            nth = nk0;

             if (nrptg*nth > wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
                 nrptg--;
@@ -1300,10 +1376,11 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
         }
     }

-    nth = std::min(nth,
+    nth = std::min<int>(nth, nk0);

     wsp_ggml_metal_kargs_cpy args = {
-        /*.
+        /*.nk0  =*/ nk0,
+        /*.ne00 =*/ ne00,
         /*.ne01 =*/ ne01,
         /*.ne02 =*/ ne02,
         /*.ne03 =*/ ne03,
@@ -1321,12 +1398,14 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
         /*.nb3  =*/ nb3,
     };

+    const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
+
     wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
     wsp_ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
     wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op), 2);

-    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);

     return 1;
 }
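The copy kernel is now sized in `nk0` work units per row: 16-element blocks when dequantizing, one destination block when quantizing. A simplified sketch of the sizing (it omits the small-row batching `nrptg` branch; all sizes are hypothetical):

```cpp
// Sketch of the reworked cpy sizing: nk0 work units per row, nth threads per
// threadgroup, nw0 threadgroups per row when a row exceeds one threadgroup.
#include <algorithm>
#include <cstdio>

int main() {
    const bool src_quantized = false;
    const bool dst_quantized = true;
    const int  ne00          = 128;  // row length in elements
    const int  dst_blck      = 32;   // hypothetical destination block size
    const int  max_threads   = 1024;

    int nk0 = ne00;                  // work units per row
    if (src_quantized) {
        nk0 = ne00/16;               // dequantize in 16-element blocks
    } else if (dst_quantized) {
        nk0 = ne00/dst_blck;         // one unit per destination block
    }

    const int nth = std::min(nk0, max_threads); // threads per threadgroup
    const int nw0 = (nk0 + nth - 1)/nth;        // threadgroups per row

    std::printf("nk0 = %d, nth = %d, nw0 = %d\n", nk0, nth, nw0);
    return 0;
}
```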
@@ -1520,9 +1599,8 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
         !wsp_ggml_is_transposed(op->src[1]) &&
         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-        props_dev->has_simdgroup_mm && ne00 >= 64 &&
-        (
-        //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+        props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
+        //WSP_GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

         // some Metal matrix data types require aligned pointers
         // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1875,20 +1953,107 @@ bool wsp_ggml_metal_op_flash_attn_ext_use_vec(const wsp_ggml_tensor * op) {
     return (ne01 < 20) && (ne00 % 32 == 0);
 }

+size_t wsp_ggml_metal_op_flash_attn_ext_extra_pad(const wsp_ggml_tensor * op) {
+    assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
+
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;
+
+    const bool has_mask = op->src[3] != nullptr;
+
+    if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
+        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
+
+        if (has_kvpad) {
+            res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
+                nb11*ne12*ne13 +
+                nb21*ne22*ne23 +
+                (has_mask ? wsp_ggml_type_size(WSP_GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+        }
+    } else {
+        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
+
+        if (has_kvpad) {
+            res += OP_FLASH_ATTN_EXT_NCPSG*(
+                nb11*ne12*ne13 +
+                nb21*ne22*ne23 +
+                (has_mask ? wsp_ggml_type_size(WSP_GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+        }
+    }
+
+    return res;
+}
+
+size_t wsp_ggml_metal_op_flash_attn_ext_extra_blk(const wsp_ggml_tensor * op) {
+    assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
+
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    //WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    //WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    //WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    //WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    //WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;
+
+    const bool has_mask = op->src[3] != nullptr;
+
+    if (!has_mask) {
+        return res;
+    }
+
+    const bool is_vec = wsp_ggml_metal_op_flash_attn_ext_use_vec(op);
+
+    // this optimization is not useful for the vector kernels
+    if (is_vec) {
+        return res;
+    }
+
+    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
+    const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
+
+    const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
+    const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
+
+    res += WSP_GGML_PAD(wsp_ggml_type_size(WSP_GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
+
+    return res;
+}
+
 size_t wsp_ggml_metal_op_flash_attn_ext_extra_tmp(const wsp_ggml_tensor * op) {
     assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);

-
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    //WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    //WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    //WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+    //WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;

-
-
-    const int64_t ne03 = op->src[0]->ne[3];
-    const int64_t ne20 = op->src[2]->ne[0];
+    if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
+        const int64_t nwg = 32;

-
-
-
-
+        // temp buffer for writing the results from each workgroup
+        // - ne20: the size of the Value head
+        // - + 2:  the S and M values for each intermediate result
+        res += wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+    }
+
+    return res;
 }

 int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
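The three new helpers size scratch regions appended after `dst`: `extra_pad` reserves padded K/V/mask rows when the KV length is not a multiple of the cache step, `extra_blk` reserves an int8 skip map with one flag per (query tile, KV tile), and `extra_tmp` reserves per-workgroup partials for the vec kernel. A standalone sketch of the first two computations, with stand-in constants for the `OP_FLASH_ATTN_EXT_*` values and hypothetical shapes:

```cpp
// Sketch of the pad and blk scratch sizes for FLASH_ATTN_EXT.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ncpsg = 64;  // cache values per simdgroup (stand-in)
    const int64_t nqptg = 8;   // queries per threadgroup    (stand-in)
    const int64_t ne01 = 512, ne11 = 1000, ne30 = 1000;      // Q rows, KV rows, mask cols
    const int64_t nb11 = 256, nb21 = 256;                    // K/V row sizes in bytes
    const int64_t ne12 = 8, ne13 = 1, ne22 = 8, ne23 = 1;
    const int64_t ne31 = 512, ne32 = 1, ne33 = 1;

    // pad region: ncpsg extra rows of K, V and f16 mask (2 bytes/elem)
    // when ne11 is not a multiple of ncpsg
    size_t pad = 0;
    if (ne11 % ncpsg != 0) {
        pad = (size_t) ncpsg*(nb11*ne12*ne13 + nb21*ne22*ne23 + 2*ne31*ne32*ne33);
    }

    // blk region: one int8 flag per (query tile, KV tile), padded to 32 bytes
    const int64_t nt1 = (ne01 + nqptg - 1)/nqptg;
    const int64_t nt0 = (ne30 + ncpsg - 1)/ncpsg;
    const size_t  blk = ((size_t) nt0*nt1*ne32*ne33 + 31) & ~(size_t)31;

    std::printf("pad = %zu bytes, blk = %zu bytes\n", pad, blk);
    return 0;
}
```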
@@ -1910,8 +2075,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
     WSP_GGML_TENSOR_LOCALS( int32_t, nb, op, nb);

-    WSP_GGML_ASSERT(ne00 % 4
-    WSP_GGML_ASSERT(ne11 % 32 == 0);
+    WSP_GGML_ASSERT(ne00 % 4 == 0);

     WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
     WSP_GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@@ -1921,8 +2085,8 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_ASSERT(ne12 == ne22);

     WSP_GGML_ASSERT(!op->src[3] || op->src[3]->type == WSP_GGML_TYPE_F16);
-    WSP_GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >=
-        "the Flash-Attention Metal kernel requires the mask to be
+    WSP_GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
+            "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");

     float scale;
     float max_bias;
@@ -1949,15 +2113,111 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {

     WSP_GGML_ASSERT(ne01 < 65536);

+    wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
+    wsp_ggml_metal_buffer_id bid_src1 = wsp_ggml_metal_get_buffer_id(op->src[1]);
+    wsp_ggml_metal_buffer_id bid_src2 = wsp_ggml_metal_get_buffer_id(op->src[2]);
+    wsp_ggml_metal_buffer_id bid_src3 = has_mask  ? wsp_ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
+    wsp_ggml_metal_buffer_id bid_src4 = has_sinks ? wsp_ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
+
+    wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
+
+    wsp_ggml_metal_buffer_id bid_pad = bid_dst;
+    bid_pad.offs += wsp_ggml_nbytes(op);
+
+    wsp_ggml_metal_buffer_id bid_blk = bid_pad;
+    bid_blk.offs += wsp_ggml_metal_op_flash_attn_ext_extra_pad(op);
+
+    wsp_ggml_metal_buffer_id bid_tmp = bid_blk;
+    bid_tmp.offs += wsp_ggml_metal_op_flash_attn_ext_extra_blk(op);
+
     if (!wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
         // half8x8 kernel
-        const
-        const
+        const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
+        const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup

         WSP_GGML_ASSERT(nqptg <= 32);
         WSP_GGML_ASSERT(nqptg % 8  == 0);
         WSP_GGML_ASSERT(ncpsg % 32 == 0);

+        bool need_sync = false;
+
+        const bool has_kvpad = ne11 % ncpsg != 0;
+
+        if (has_kvpad) {
+            assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+            wsp_ggml_metal_kargs_flash_attn_ext_pad args0 = {
+                /*.ne11    =*/ne11,
+                /*.ne_12_2 =*/ne12,
+                /*.ne_12_3 =*/ne13,
+                /*.nb11    =*/nb11,
+                /*.nb12    =*/nb12,
+                /*.nb13    =*/nb13,
+                /*.nb21    =*/nb21,
+                /*.nb22    =*/nb22,
+                /*.nb23    =*/nb23,
+                /*.ne31    =*/ne31,
+                /*.ne32    =*/ne32,
+                /*.ne33    =*/ne33,
+                /*.nb31    =*/nb31,
+                /*.nb32    =*/nb32,
+                /*.nb33    =*/nb33,
+            };
+
+            wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+            wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            wsp_ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            wsp_ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+            wsp_ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);
+            wsp_ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);
+            wsp_ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);
+
+            assert(ne12 == ne22);
+            assert(ne13 == ne23);
+
+            wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+            need_sync = true;
+        } else {
+            assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
+        }
+
+        if (has_mask) {
+            assert(wsp_ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
+
+            wsp_ggml_metal_kargs_flash_attn_ext_blk args0 = {
+                /*.ne01 =*/ ne01,
+                /*.ne30 =*/ ne30,
+                /*.ne31 =*/ ne31,
+                /*.ne32 =*/ ne32,
+                /*.ne33 =*/ ne33,
+                /*.nb31 =*/ nb31,
+                /*.nb32 =*/ nb32,
+                /*.nb33 =*/ nb33,
+            };
+
+            wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
+
+            wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            wsp_ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            wsp_ggml_metal_encoder_set_buffer  (enc, bid_src3, 1);
+            wsp_ggml_metal_encoder_set_buffer  (enc, bid_blk,  2);
+
+            const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
+            const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
+
+            wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
+
+            need_sync = true;
+        } else {
+            assert(wsp_ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
+        }
+
+        if (need_sync) {
+            wsp_ggml_metal_op_concurrency_reset(ctx);
+        }
+
         const int is_q = wsp_ggml_is_quantized(op->src[1]->type) ? 1 : 0;

         // 2*(2*ncpsg)
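All flash-attention scratch regions share `dst`'s allocation; each `wsp_ggml_metal_buffer_id` is derived from the previous one by adding the size of the region that precedes it. A minimal sketch of that chaining, with a simplified buffer-id struct and hypothetical sizes:

```cpp
// Sketch of the bid_pad/bid_blk/bid_tmp sub-allocation chaining: all three
// scratch regions live after dst in the same buffer, each offset by the
// size of what precedes it.
#include <cstddef>
#include <cstdio>

struct buffer_id {
    int    buf;   // which Metal buffer
    size_t offs;  // byte offset into it
};

int main() {
    const size_t dst_bytes = 1 << 20; // stand-in for wsp_ggml_nbytes(op)
    const size_t pad_bytes = 4096;    // stand-in for ..._extra_pad(op)
    const size_t blk_bytes = 512;     // stand-in for ..._extra_blk(op)

    buffer_id bid_dst = { /*buf=*/0, /*offs=*/0 };

    buffer_id bid_pad = bid_dst; bid_pad.offs += dst_bytes;
    buffer_id bid_blk = bid_pad; bid_blk.offs += pad_bytes;
    buffer_id bid_tmp = bid_blk; bid_tmp.offs += blk_bytes;

    std::printf("pad@%zu blk@%zu tmp@%zu\n", bid_pad.offs, bid_blk.offs, bid_tmp.offs);
    return 0;
}
```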
@@ -2007,6 +2267,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
         /*.nb21 =*/ nb21,
         /*.nb22 =*/ nb22,
         /*.nb23 =*/ nb23,
+        /*.ne31 =*/ ne31,
         /*.ne32 =*/ ne32,
         /*.ne33 =*/ ne33,
         /*.nb31 =*/ nb31,
@@ -2023,24 +2284,18 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
         /*.logit_softcap =*/ logit_softcap,
     };

-    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);

     wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
     wsp_ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    wsp_ggml_metal_encoder_set_buffer  (enc,
-    wsp_ggml_metal_encoder_set_buffer  (enc,
-    wsp_ggml_metal_encoder_set_buffer  (enc,
-
-
-
-
-
-    if (op->src[4]) {
-        wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[4]), 5);
-    } else {
-        wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 5);
-    }
-    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op), 6);
+    wsp_ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    wsp_ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
+    wsp_ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);
+    wsp_ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
+    wsp_ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);
+    wsp_ggml_metal_encoder_set_buffer  (enc, bid_pad,  6);
+    wsp_ggml_metal_encoder_set_buffer  (enc, bid_blk,  7);
+    wsp_ggml_metal_encoder_set_buffer  (enc, bid_dst,  8);

     wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

@@ -2048,14 +2303,62 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
 #undef FATTN_SMEM
     } else {
         // half4x4 kernel
-        const
-        const
-        const
+        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
+        const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
+        const int nkpsg = 1*ncpsg;

         WSP_GGML_ASSERT(nqptg <= 32);
         WSP_GGML_ASSERT(nqptg % 1  == 0);
         WSP_GGML_ASSERT(ncpsg % 32 == 0);

+        bool need_sync = false;
+
+        const bool has_kvpad = ne11 % ncpsg != 0;
+
+        if (has_kvpad) {
+            assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+            wsp_ggml_metal_kargs_flash_attn_ext_pad args0 = {
+                /*.ne11    =*/ne11,
+                /*.ne_12_2 =*/ne12,
+                /*.ne_12_3 =*/ne13,
+                /*.nb11    =*/nb11,
+                /*.nb12    =*/nb12,
+                /*.nb13    =*/nb13,
+                /*.nb21    =*/nb21,
+                /*.nb22    =*/nb22,
+                /*.nb23    =*/nb23,
+                /*.ne31    =*/ne31,
+                /*.ne32    =*/ne32,
+                /*.ne33    =*/ne33,
+                /*.nb31    =*/nb31,
+                /*.nb32    =*/nb32,
+                /*.nb33    =*/nb33,
+            };
+
+            wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+            wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            wsp_ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            wsp_ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+            wsp_ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);
+            wsp_ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);
+            wsp_ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);
+
+            assert(ne12 == ne22);
+            assert(ne13 == ne23);
+
+            wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+            need_sync = true;
+        } else {
+            assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
+        }
+
+        if (need_sync) {
+            wsp_ggml_metal_op_concurrency_reset(ctx);
+        }
+
         // ne00 + 2*ncpsg*(nsg)
         // for each query, we load it as f16 in shared memory (ne00)
         // and store the soft_max values and the mask
@@ -2120,6 +2423,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
             /*.nb21 =*/ nb21,
             /*.nb22 =*/ nb22,
             /*.nb23 =*/ nb23,
+            /*.ne31 =*/ ne31,
             /*.ne32 =*/ ne32,
             /*.ne33 =*/ ne33,
             /*.nb31 =*/ nb31,
@@ -2136,25 +2440,17 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
             /*.logit_softcap =*/ logit_softcap,
         };

-        wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
+        wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);

         WSP_GGML_ASSERT(nsg*32 <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

         wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
         wsp_ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-        wsp_ggml_metal_encoder_set_buffer  (enc,
-        wsp_ggml_metal_encoder_set_buffer  (enc,
-        wsp_ggml_metal_encoder_set_buffer  (enc,
-
-
-        } else {
-            wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 4);
-        }
-        if (op->src[4]) {
-            wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[4]), 5);
-        } else {
-            wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 5);
-        }
+        wsp_ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        wsp_ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
+        wsp_ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);
+        wsp_ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
+        wsp_ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);

         const size_t smem = FATTN_SMEM(nsg);

@@ -2162,23 +2458,25 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
         WSP_GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);

         if (nwg == 1) {
+            assert(wsp_ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
+
             // using 1 workgroup -> write the result directly into dst
-            wsp_ggml_metal_encoder_set_buffer(enc,
+            wsp_ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+            wsp_ggml_metal_encoder_set_buffer(enc, bid_dst, 7);

             wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

             wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
         } else {
             // sanity checks
+            assert(wsp_ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
+
             WSP_GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
             WSP_GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));

-            wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
-
             // write the results from each workgroup into a temp buffer
-
-            bid_tmp
-            wsp_ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
+            wsp_ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+            wsp_ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);

             wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
             wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
@@ -2688,6 +2986,7 @@ int wsp_ggml_metal_op_rope(wsp_ggml_metal_op_t ctx, int idx) {
         /* sect_1 =*/ sect_1,
        /* sect_2 =*/ sect_2,
         /* sect_3 =*/ sect_3,
+        /* src2   =*/ op->src[2] != nullptr,
     };

     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_rope(lib, op);
@@ -2823,6 +3122,62 @@ int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
     return 1;
 }

+int wsp_ggml_metal_op_conv_transpose_2d(wsp_ggml_metal_op_t ctx, int idx) {
+    wsp_ggml_tensor * op = ctx->node(idx);
+
+    wsp_ggml_metal_library_t lib = ctx->lib;
+    wsp_ggml_metal_encoder_t enc = ctx->enc;
+
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+    WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
+
+    const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+
+    const int32_t IC = op->src[1]->ne[2];
+    const int32_t IH = op->src[1]->ne[1];
+    const int32_t IW = op->src[1]->ne[0];
+
+    const int32_t KH = op->src[0]->ne[1];
+    const int32_t KW = op->src[0]->ne[0];
+
+    const int32_t OW = op->ne[0];
+    const int32_t OH = op->ne[1];
+    const int32_t OC = op->ne[2];
+
+    wsp_ggml_metal_kargs_conv_transpose_2d args = {
+        /*.IC =*/ IC,
+        /*.IH =*/ IH,
+        /*.IW =*/ IW,
+        /*.KH =*/ KH,
+        /*.KW =*/ KW,
+        /*.OC =*/ OC,
+        /*.s0 =*/ s0,
+        /*.nb0 =*/ nb0,
+        /*.nb1 =*/ nb1,
+        /*.nb2 =*/ nb2,
+    };
+
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
+
+    wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
+    wsp_ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op), 3);
+
+    // Metal requires buffer size to be multiple of 16 bytes
+    const size_t smem = WSP_GGML_PAD(KW * KH * sizeof(float), 16);
+    wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
+
+    return 1;
+}
+
 int wsp_ggml_metal_op_upscale(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_tensor * op = ctx->node(idx);

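The dispatch above launches one threadgroup per output pixel per channel, so the output extent follows the standard transposed-convolution relation OW = (IW - 1)*s0 + KW. A quick check with hypothetical sizes:

```cpp
// Sketch of the conv-transpose output-shape relation implied by the new
// kernel's dispatch grid (OW x OH x OC threadgroups). Values are hypothetical.
#include <cstdio>

int main() {
    const int IW = 14, IH = 14; // input spatial size
    const int KW = 4,  KH = 4;  // kernel size
    const int s0 = 2;           // stride

    const int OW = (IW - 1)*s0 + KW; // 30
    const int OH = (IH - 1)*s0 + KH; // 30

    std::printf("output: %d x %d\n", OW, OH);
    return 0;
}
```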
@@ -3156,3 +3511,73 @@ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {

     return 1;
 }
+
+int wsp_ggml_metal_op_opt_step_adamw(wsp_ggml_metal_op_t ctx, int idx) {
+    wsp_ggml_tensor * op = ctx->node(idx);
+
+    wsp_ggml_metal_library_t lib = ctx->lib;
+    wsp_ggml_metal_encoder_t enc = ctx->enc;
+
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+    WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
+
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
+
+    const int64_t np = wsp_ggml_nelements(op->src[0]);
+    wsp_ggml_metal_kargs_opt_step_adamw args = {
+        /*.np =*/ np,
+    };
+
+    int ida = 0;
+
+    wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
+    wsp_ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), ida++);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), ida++);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), ida++);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[3]), ida++);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[4]), ida++);
+
+    const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+    const int64_t n = (np + nth - 1) / nth;
+
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
+
+    return 1;
+}
+
+int wsp_ggml_metal_op_opt_step_sgd(wsp_ggml_metal_op_t ctx, int idx) {
+    wsp_ggml_tensor * op = ctx->node(idx);
+
+    wsp_ggml_metal_library_t lib = ctx->lib;
+    wsp_ggml_metal_encoder_t enc = ctx->enc;
+
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+    WSP_GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
+
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
+
+    const int64_t np = wsp_ggml_nelements(op->src[0]);
+    wsp_ggml_metal_kargs_opt_step_sgd args = {
+        /*.np =*/ np,
+    };
+
+    int ida = 0;
+
+    wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
+    wsp_ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), ida++);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), ida++);
+    wsp_ggml_metal_encoder_set_buffer  (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), ida++);
+
+    const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+    const int64_t n = (np + nth - 1) / nth;
+
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
+
+    return 1;
+}
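Both optimizer encoders flatten the parameter tensor to `np` elements and launch a 1-D grid of ceil(np/nth) threadgroups. A sketch of that grid computation, with a hypothetical 1024-thread pipeline limit:

```cpp
// Sketch of the 1-D grid sizing shared by the opt_step_adamw/opt_step_sgd
// encoders above: np flattened parameters, nth threads per group,
// ceil-div threadgroups so every parameter is covered exactly once.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t np  = 7000000;                             // wsp_ggml_nelements(src0)
    const int     nth = (int) std::min<int64_t>(1024, np);   // capped by pipeline limit

    const int64_t n = (np + nth - 1)/nth;                    // threadgroups
    std::printf("dispatch %lld threadgroups of %d threads\n", (long long) n, nth);
    return 0;
}
```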