whisper.rn 0.5.1 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/jni.cpp +12 -3
- package/cpp/ggml-alloc.c +49 -18
- package/cpp/ggml-backend-impl.h +0 -3
- package/cpp/ggml-backend-reg.cpp +8 -0
- package/cpp/ggml-backend.cpp +0 -2
- package/cpp/ggml-backend.h +2 -0
- package/cpp/ggml-cpu/amx/amx.cpp +1 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
- package/cpp/ggml-cpu/ggml-cpu-impl.h +4 -2
- package/cpp/ggml-cpu/ggml-cpu.c +67 -24
- package/cpp/ggml-cpu/ops.cpp +489 -364
- package/cpp/ggml-cpu/ops.h +4 -4
- package/cpp/ggml-cpu/repack.cpp +143 -29
- package/cpp/ggml-cpu/simd-mappings.h +25 -25
- package/cpp/ggml-cpu/unary-ops.cpp +151 -0
- package/cpp/ggml-cpu/unary-ops.h +7 -0
- package/cpp/ggml-cpu/vec.cpp +83 -0
- package/cpp/ggml-cpu/vec.h +20 -8
- package/cpp/ggml-impl.h +67 -2
- package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
- package/cpp/ggml-metal/ggml-metal-context.m +5 -6
- package/cpp/ggml-metal/ggml-metal-device.cpp +300 -14
- package/cpp/ggml-metal/ggml-metal-device.h +26 -1
- package/cpp/ggml-metal/ggml-metal-device.m +243 -28
- package/cpp/ggml-metal/ggml-metal-impl.h +177 -9
- package/cpp/ggml-metal/ggml-metal-ops.cpp +843 -157
- package/cpp/ggml-metal/ggml-metal-ops.h +8 -0
- package/cpp/ggml-metal/ggml-metal.cpp +8 -3
- package/cpp/ggml-metal/ggml-metal.metal +12436 -0
- package/cpp/ggml.c +317 -4
- package/cpp/ggml.h +139 -0
- package/cpp/jsi/RNWhisperJSI.cpp +7 -2
- package/cpp/rn-whisper.h +1 -0
- package/cpp/whisper.cpp +8 -2
- package/ios/RNWhisperContext.mm +3 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +2 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +2 -0
- package/src/version.json +1 -1
- package/whisper-rn.podspec +1 -1
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
|
@@ -10,6 +10,8 @@
|
|
|
10
10
|
|
|
11
11
|
#include <cassert>
|
|
12
12
|
#include <algorithm>
|
|
13
|
+
#include <limits>
|
|
14
|
+
#include <cmath>
|
|
13
15
|
|
|
14
16
|
static wsp_ggml_metal_buffer_id wsp_ggml_metal_get_buffer_id(const wsp_ggml_tensor * t) {
|
|
15
17
|
if (!t) {
|
|
@@ -226,6 +228,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
226
228
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
|
|
227
229
|
WSP_GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
|
|
228
230
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
|
|
231
|
+
WSP_GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
|
|
232
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
|
|
233
|
+
WSP_GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
|
|
234
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
|
|
229
235
|
WSP_GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
|
|
230
236
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
|
|
231
237
|
|
|
@@ -237,6 +243,14 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
237
243
|
WSP_GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
|
|
238
244
|
wsp_ggml_is_contiguous(node->src[1]), node->src[1]->name);
|
|
239
245
|
}
|
|
246
|
+
if (node->src[2]) {
|
|
247
|
+
WSP_GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
|
|
248
|
+
wsp_ggml_is_contiguous(node->src[2]), node->src[2]->name);
|
|
249
|
+
}
|
|
250
|
+
if (node->src[3]) {
|
|
251
|
+
WSP_GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
|
|
252
|
+
wsp_ggml_is_contiguous(node->src[3]), node->src[3]->name);
|
|
253
|
+
}
|
|
240
254
|
if (node) {
|
|
241
255
|
WSP_GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
|
|
242
256
|
node->name);
|
|
@@ -289,11 +303,19 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
289
303
|
{
|
|
290
304
|
n_fuse = wsp_ggml_metal_op_glu(ctx, idx);
|
|
291
305
|
} break;
|
|
306
|
+
case WSP_GGML_OP_SUM:
|
|
307
|
+
{
|
|
308
|
+
n_fuse = wsp_ggml_metal_op_sum(ctx, idx);
|
|
309
|
+
} break;
|
|
292
310
|
case WSP_GGML_OP_SUM_ROWS:
|
|
293
311
|
case WSP_GGML_OP_MEAN:
|
|
294
312
|
{
|
|
295
313
|
n_fuse = wsp_ggml_metal_op_sum_rows(ctx, idx);
|
|
296
314
|
} break;
|
|
315
|
+
case WSP_GGML_OP_CUMSUM:
|
|
316
|
+
{
|
|
317
|
+
n_fuse = wsp_ggml_metal_op_cumsum(ctx, idx);
|
|
318
|
+
} break;
|
|
297
319
|
case WSP_GGML_OP_SOFT_MAX:
|
|
298
320
|
{
|
|
299
321
|
n_fuse = wsp_ggml_metal_op_soft_max(ctx, idx);
|
|
@@ -348,10 +370,18 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
348
370
|
{
|
|
349
371
|
n_fuse = wsp_ggml_metal_op_im2col(ctx, idx);
|
|
350
372
|
} break;
|
|
373
|
+
case WSP_GGML_OP_CONV_2D:
|
|
374
|
+
{
|
|
375
|
+
n_fuse = wsp_ggml_metal_op_conv_2d(ctx, idx);
|
|
376
|
+
} break;
|
|
351
377
|
case WSP_GGML_OP_CONV_TRANSPOSE_1D:
|
|
352
378
|
{
|
|
353
379
|
n_fuse = wsp_ggml_metal_op_conv_transpose_1d(ctx, idx);
|
|
354
380
|
} break;
|
|
381
|
+
case WSP_GGML_OP_CONV_TRANSPOSE_2D:
|
|
382
|
+
{
|
|
383
|
+
n_fuse = wsp_ggml_metal_op_conv_transpose_2d(ctx, idx);
|
|
384
|
+
} break;
|
|
355
385
|
case WSP_GGML_OP_UPSCALE:
|
|
356
386
|
{
|
|
357
387
|
n_fuse = wsp_ggml_metal_op_upscale(ctx, idx);
|
|
@@ -398,6 +428,14 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
398
428
|
{
|
|
399
429
|
n_fuse = wsp_ggml_metal_op_argmax(ctx, idx);
|
|
400
430
|
} break;
|
|
431
|
+
case WSP_GGML_OP_OPT_STEP_ADAMW:
|
|
432
|
+
{
|
|
433
|
+
n_fuse = wsp_ggml_metal_op_opt_step_adamw(ctx, idx);
|
|
434
|
+
} break;
|
|
435
|
+
case WSP_GGML_OP_OPT_STEP_SGD:
|
|
436
|
+
{
|
|
437
|
+
n_fuse = wsp_ggml_metal_op_opt_step_sgd(ctx, idx);
|
|
438
|
+
} break;
|
|
401
439
|
default:
|
|
402
440
|
{
|
|
403
441
|
WSP_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, wsp_ggml_op_name(node->op));
|
|
@@ -506,7 +544,7 @@ int wsp_ggml_metal_op_repeat(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
506
544
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
507
545
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
508
546
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
509
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
547
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
510
548
|
|
|
511
549
|
wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_repeat(lib, op->type);
|
|
512
550
|
|
|
@@ -552,7 +590,7 @@ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
552
590
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
553
591
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
554
592
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
555
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
593
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
556
594
|
|
|
557
595
|
WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
|
|
558
596
|
WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
|
|
@@ -577,6 +615,7 @@ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
577
615
|
wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
|
578
616
|
|
|
579
617
|
wsp_ggml_metal_kargs_cpy args = {
|
|
618
|
+
/*.nk0 =*/ ne00,
|
|
580
619
|
/*.ne00 =*/ ne00,
|
|
581
620
|
/*.ne01 =*/ ne01,
|
|
582
621
|
/*.ne02 =*/ ne02,
|
|
@@ -660,7 +699,7 @@ int wsp_ggml_metal_op_scale(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
660
699
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
661
700
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
662
701
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
663
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
702
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
664
703
|
|
|
665
704
|
float scale;
|
|
666
705
|
float bias;
|
|
@@ -699,7 +738,7 @@ int wsp_ggml_metal_op_clamp(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
699
738
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
700
739
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
701
740
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
702
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
741
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
703
742
|
|
|
704
743
|
float min;
|
|
705
744
|
float max;
|
|
@@ -738,7 +777,7 @@ int wsp_ggml_metal_op_unary(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
738
777
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
739
778
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
740
779
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
741
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
780
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
742
781
|
|
|
743
782
|
int64_t n = wsp_ggml_nelements(op);
|
|
744
783
|
|
|
@@ -768,7 +807,7 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
768
807
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
769
808
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
770
809
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
771
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
810
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
772
811
|
|
|
773
812
|
if (op->src[1]) {
|
|
774
813
|
WSP_GGML_ASSERT(wsp_ggml_are_same_shape(op->src[0], op->src[1]));
|
|
@@ -800,18 +839,6 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
800
839
|
|
|
801
840
|
const int32_t nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
|
|
802
841
|
|
|
803
|
-
//[encoder setComputePipelineState:pipeline];
|
|
804
|
-
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
805
|
-
//if (src1) {
|
|
806
|
-
// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
807
|
-
//} else {
|
|
808
|
-
// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
809
|
-
//}
|
|
810
|
-
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
811
|
-
//[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
812
|
-
|
|
813
|
-
//[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
814
|
-
|
|
815
842
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
816
843
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
817
844
|
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
@@ -827,6 +854,43 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
827
854
|
return 1;
|
|
828
855
|
}
|
|
829
856
|
|
|
857
|
+
int wsp_ggml_metal_op_sum(wsp_ggml_metal_op_t ctx, int idx) {
|
|
858
|
+
wsp_ggml_tensor * op = ctx->node(idx);
|
|
859
|
+
|
|
860
|
+
wsp_ggml_metal_library_t lib = ctx->lib;
|
|
861
|
+
wsp_ggml_metal_encoder_t enc = ctx->enc;
|
|
862
|
+
|
|
863
|
+
const uint64_t n = (uint64_t) wsp_ggml_nelements(op->src[0]);
|
|
864
|
+
|
|
865
|
+
wsp_ggml_metal_kargs_sum args = {
|
|
866
|
+
/*.np =*/ n,
|
|
867
|
+
};
|
|
868
|
+
|
|
869
|
+
wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_sum(lib, op);
|
|
870
|
+
|
|
871
|
+
int nth = 32; // SIMD width
|
|
872
|
+
|
|
873
|
+
while (nth < (int) n && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
874
|
+
nth *= 2;
|
|
875
|
+
}
|
|
876
|
+
|
|
877
|
+
nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
878
|
+
nth = std::min(nth, (int) n);
|
|
879
|
+
|
|
880
|
+
const int nsg = (nth + 31) / 32;
|
|
881
|
+
|
|
882
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
883
|
+
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
884
|
+
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
885
|
+
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
|
|
886
|
+
|
|
887
|
+
wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
|
|
888
|
+
|
|
889
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
|
|
890
|
+
|
|
891
|
+
return 1;
|
|
892
|
+
}
|
|
893
|
+
|
|
830
894
|
int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
|
|
831
895
|
wsp_ggml_tensor * op = ctx->node(idx);
|
|
832
896
|
|
|
@@ -836,7 +900,7 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
836
900
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
837
901
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
838
902
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
839
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
903
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
840
904
|
|
|
841
905
|
wsp_ggml_metal_kargs_sum_rows args = {
|
|
842
906
|
/*.ne00 =*/ ne00,
|
|
@@ -870,14 +934,6 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
870
934
|
|
|
871
935
|
const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
|
|
872
936
|
|
|
873
|
-
//[encoder setComputePipelineState:pipeline];
|
|
874
|
-
//[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
875
|
-
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
876
|
-
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
877
|
-
//[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
878
|
-
|
|
879
|
-
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
880
|
-
|
|
881
937
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
882
938
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
883
939
|
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
@@ -890,6 +946,149 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
890
946
|
return 1;
|
|
891
947
|
}
|
|
892
948
|
|
|
949
|
+
int wsp_ggml_metal_op_cumsum(wsp_ggml_metal_op_t ctx, int idx) {
|
|
950
|
+
wsp_ggml_tensor * op = ctx->node(idx);
|
|
951
|
+
|
|
952
|
+
wsp_ggml_metal_library_t lib = ctx->lib;
|
|
953
|
+
wsp_ggml_metal_encoder_t enc = ctx->enc;
|
|
954
|
+
|
|
955
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
|
|
956
|
+
|
|
957
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
958
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
959
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
960
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
961
|
+
|
|
962
|
+
wsp_ggml_metal_pipeline_t pipeline_blk = wsp_ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
|
|
963
|
+
|
|
964
|
+
int nth = 1;
|
|
965
|
+
while (nth < ne00 && 2*nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
|
|
966
|
+
nth *= 2;
|
|
967
|
+
}
|
|
968
|
+
|
|
969
|
+
WSP_GGML_ASSERT(ne00 <= nth*nth);
|
|
970
|
+
|
|
971
|
+
const int64_t net0 = (ne00 + nth - 1) / nth;
|
|
972
|
+
const int64_t net1 = ne01;
|
|
973
|
+
const int64_t net2 = ne02;
|
|
974
|
+
const int64_t net3 = ne03;
|
|
975
|
+
|
|
976
|
+
const uint64_t nbt0 = sizeof(float);
|
|
977
|
+
const uint64_t nbt1 = net0*nbt0;
|
|
978
|
+
const uint64_t nbt2 = net1*nbt1;
|
|
979
|
+
const uint64_t nbt3 = net2*nbt2;
|
|
980
|
+
|
|
981
|
+
const size_t smem = WSP_GGML_PAD(32*sizeof(float), 16);
|
|
982
|
+
|
|
983
|
+
wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
|
|
984
|
+
wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
|
|
985
|
+
|
|
986
|
+
wsp_ggml_metal_buffer_id bid_tmp = bid_dst;
|
|
987
|
+
bid_tmp.offs += wsp_ggml_nbytes(op);
|
|
988
|
+
|
|
989
|
+
{
|
|
990
|
+
wsp_ggml_metal_kargs_cumsum_blk args = {
|
|
991
|
+
/*.ne00 =*/ ne00,
|
|
992
|
+
/*.ne01 =*/ ne01,
|
|
993
|
+
/*.ne02 =*/ ne02,
|
|
994
|
+
/*.ne03 =*/ ne03,
|
|
995
|
+
/*.nb00 =*/ nb00,
|
|
996
|
+
/*.nb01 =*/ nb01,
|
|
997
|
+
/*.nb02 =*/ nb02,
|
|
998
|
+
/*.nb03 =*/ nb03,
|
|
999
|
+
/*.net0 =*/ net0,
|
|
1000
|
+
/*.net1 =*/ net1,
|
|
1001
|
+
/*.net2 =*/ net2,
|
|
1002
|
+
/*.net3 =*/ net3,
|
|
1003
|
+
/*.nbt0 =*/ nbt0,
|
|
1004
|
+
/*.nbt1 =*/ nbt1,
|
|
1005
|
+
/*.nbt2 =*/ nbt2,
|
|
1006
|
+
/*.nbt3 =*/ nbt3,
|
|
1007
|
+
/*.outb =*/ ne00 > nth,
|
|
1008
|
+
};
|
|
1009
|
+
|
|
1010
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
|
|
1011
|
+
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1012
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
1013
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
|
|
1014
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
|
1015
|
+
|
|
1016
|
+
wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1017
|
+
|
|
1018
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
|
|
1019
|
+
}
|
|
1020
|
+
|
|
1021
|
+
if (ne00 > nth) {
|
|
1022
|
+
wsp_ggml_metal_op_concurrency_reset(ctx);
|
|
1023
|
+
|
|
1024
|
+
{
|
|
1025
|
+
wsp_ggml_metal_kargs_cumsum_blk args = {
|
|
1026
|
+
/*.ne00 =*/ net0,
|
|
1027
|
+
/*.ne01 =*/ net1,
|
|
1028
|
+
/*.ne02 =*/ net2,
|
|
1029
|
+
/*.ne03 =*/ net3,
|
|
1030
|
+
/*.nb00 =*/ nbt0,
|
|
1031
|
+
/*.nb01 =*/ nbt1,
|
|
1032
|
+
/*.nb02 =*/ nbt2,
|
|
1033
|
+
/*.nb03 =*/ nbt3,
|
|
1034
|
+
/*.net0 =*/ net0,
|
|
1035
|
+
/*.net1 =*/ net1,
|
|
1036
|
+
/*.net2 =*/ net2,
|
|
1037
|
+
/*.net3 =*/ net3,
|
|
1038
|
+
/*.nbt0 =*/ nbt0,
|
|
1039
|
+
/*.nbt1 =*/ nbt1,
|
|
1040
|
+
/*.nbt2 =*/ nbt2,
|
|
1041
|
+
/*.nbt3 =*/ nbt3,
|
|
1042
|
+
/*.outb =*/ false,
|
|
1043
|
+
};
|
|
1044
|
+
|
|
1045
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
|
|
1046
|
+
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1047
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
|
|
1048
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
|
|
1049
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
|
1050
|
+
|
|
1051
|
+
wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1052
|
+
|
|
1053
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
|
|
1054
|
+
}
|
|
1055
|
+
|
|
1056
|
+
wsp_ggml_metal_op_concurrency_reset(ctx);
|
|
1057
|
+
|
|
1058
|
+
{
|
|
1059
|
+
wsp_ggml_metal_pipeline_t pipeline_add = wsp_ggml_metal_library_get_pipeline_cumsum_add(lib, op);
|
|
1060
|
+
|
|
1061
|
+
wsp_ggml_metal_kargs_cumsum_add args = {
|
|
1062
|
+
/*.ne00 =*/ ne00,
|
|
1063
|
+
/*.ne01 =*/ ne01,
|
|
1064
|
+
/*.ne02 =*/ ne02,
|
|
1065
|
+
/*.ne03 =*/ ne03,
|
|
1066
|
+
/*.nb00 =*/ nb00,
|
|
1067
|
+
/*.nb01 =*/ nb01,
|
|
1068
|
+
/*.nb02 =*/ nb02,
|
|
1069
|
+
/*.nb03 =*/ nb03,
|
|
1070
|
+
/*.net0 =*/ net0,
|
|
1071
|
+
/*.net1 =*/ net1,
|
|
1072
|
+
/*.net2 =*/ net2,
|
|
1073
|
+
/*.net3 =*/ net3,
|
|
1074
|
+
/*.nbt0 =*/ nbt0,
|
|
1075
|
+
/*.nbt1 =*/ nbt1,
|
|
1076
|
+
/*.nbt2 =*/ nbt2,
|
|
1077
|
+
/*.nbt3 =*/ nbt3,
|
|
1078
|
+
};
|
|
1079
|
+
|
|
1080
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_add);
|
|
1081
|
+
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1082
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
|
|
1083
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
1084
|
+
|
|
1085
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
|
|
1086
|
+
}
|
|
1087
|
+
}
|
|
1088
|
+
|
|
1089
|
+
return 1;
|
|
1090
|
+
}
|
|
1091
|
+
|
|
893
1092
|
int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
|
|
894
1093
|
wsp_ggml_tensor * op = ctx->node(idx);
|
|
895
1094
|
|
|
@@ -901,28 +1100,36 @@ int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
901
1100
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
902
1101
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
903
1102
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
904
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
1103
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
905
1104
|
|
|
906
1105
|
wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
|
907
1106
|
|
|
908
1107
|
wsp_ggml_metal_kargs_get_rows args = {
|
|
909
|
-
/*.
|
|
910
|
-
/*.
|
|
911
|
-
/*.
|
|
912
|
-
/*.
|
|
913
|
-
/*.
|
|
914
|
-
/*.
|
|
915
|
-
/*.
|
|
916
|
-
/*.
|
|
1108
|
+
/*.ne00t =*/ wsp_ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
|
|
1109
|
+
/*.ne00 =*/ ne00,
|
|
1110
|
+
/*.nb01 =*/ nb01,
|
|
1111
|
+
/*.nb02 =*/ nb02,
|
|
1112
|
+
/*.nb03 =*/ nb03,
|
|
1113
|
+
/*.ne10 =*/ ne10,
|
|
1114
|
+
/*.nb10 =*/ nb10,
|
|
1115
|
+
/*.nb11 =*/ nb11,
|
|
1116
|
+
/*.nb12 =*/ nb12,
|
|
1117
|
+
/*.nb1 =*/ nb1,
|
|
1118
|
+
/*.nb2 =*/ nb2,
|
|
1119
|
+
/*.nb3 =*/ nb3,
|
|
917
1120
|
};
|
|
918
1121
|
|
|
1122
|
+
const int nth = std::min(args.ne00t, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1123
|
+
|
|
1124
|
+
const int nw0 = (args.ne00t + nth - 1)/nth;
|
|
1125
|
+
|
|
919
1126
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
920
1127
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
921
1128
|
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
922
1129
|
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
923
1130
|
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
|
|
924
1131
|
|
|
925
|
-
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12,
|
|
1132
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
|
|
926
1133
|
|
|
927
1134
|
return 1;
|
|
928
1135
|
}
|
|
@@ -938,7 +1145,7 @@ int wsp_ggml_metal_op_set_rows(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
938
1145
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
939
1146
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
940
1147
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
941
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
1148
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
942
1149
|
|
|
943
1150
|
wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
|
|
944
1151
|
|
|
@@ -1002,7 +1209,7 @@ int wsp_ggml_metal_op_soft_max(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1002
1209
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
1003
1210
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
1004
1211
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1005
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
1212
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1006
1213
|
|
|
1007
1214
|
float scale;
|
|
1008
1215
|
float max_bias;
|
|
@@ -1090,7 +1297,7 @@ int wsp_ggml_metal_op_ssm_conv(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1090
1297
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1091
1298
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1092
1299
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1093
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
1300
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1094
1301
|
|
|
1095
1302
|
wsp_ggml_metal_kargs_ssm_conv args = {
|
|
1096
1303
|
/*.ne00 =*/ ne00,
|
|
@@ -1117,7 +1324,7 @@ int wsp_ggml_metal_op_ssm_conv(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1117
1324
|
wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
1118
1325
|
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1119
1326
|
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1120
|
-
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op),
|
|
1327
|
+
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
|
|
1121
1328
|
|
|
1122
1329
|
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
|
|
1123
1330
|
|
|
@@ -1145,7 +1352,7 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1145
1352
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
|
|
1146
1353
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
|
|
1147
1354
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1148
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
1355
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1149
1356
|
|
|
1150
1357
|
const wsp_ggml_tensor * src3 = op->src[3];
|
|
1151
1358
|
const wsp_ggml_tensor * src4 = op->src[4];
|
|
@@ -1172,25 +1379,36 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1172
1379
|
/*.n_seq_tokens =*/ n_seq_tokens,
|
|
1173
1380
|
/*.n_seqs =*/ n_seqs,
|
|
1174
1381
|
/*.s_off =*/ wsp_ggml_nelements(op->src[1]) * sizeof(float),
|
|
1382
|
+
/*.nb00 =*/ nb00,
|
|
1175
1383
|
/*.nb01 =*/ nb01,
|
|
1176
1384
|
/*.nb02 =*/ nb02,
|
|
1177
1385
|
/*.nb03 =*/ nb03,
|
|
1386
|
+
/*.nb10 =*/ nb10,
|
|
1178
1387
|
/*.nb11 =*/ nb11,
|
|
1179
1388
|
/*.nb12 =*/ nb12,
|
|
1389
|
+
/*.ns12 =*/ nb12/nb10,
|
|
1180
1390
|
/*.nb13 =*/ nb13,
|
|
1391
|
+
/*.nb20 =*/ nb20,
|
|
1181
1392
|
/*.nb21 =*/ nb21,
|
|
1393
|
+
/*.ns21 =*/ nb21/nb20,
|
|
1182
1394
|
/*.nb22 =*/ nb22,
|
|
1395
|
+
/*.ne30 =*/ ne30,
|
|
1183
1396
|
/*.nb31 =*/ nb31,
|
|
1184
1397
|
/*.nb41 =*/ nb41,
|
|
1185
1398
|
/*.nb42 =*/ nb42,
|
|
1399
|
+
/*.ns42 =*/ nb42/nb40,
|
|
1186
1400
|
/*.nb43 =*/ nb43,
|
|
1187
1401
|
/*.nb51 =*/ nb51,
|
|
1188
1402
|
/*.nb52 =*/ nb52,
|
|
1403
|
+
/*.ns52 =*/ nb52/nb50,
|
|
1189
1404
|
/*.nb53 =*/ nb53,
|
|
1405
|
+
/*.nb0 =*/ nb0,
|
|
1190
1406
|
};
|
|
1191
1407
|
|
|
1192
1408
|
wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_ssm_scan(lib, op);
|
|
1193
1409
|
|
|
1410
|
+
WSP_GGML_ASSERT(d_state <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1411
|
+
|
|
1194
1412
|
const size_t sms = wsp_ggml_metal_pipeline_get_smem(pipeline);
|
|
1195
1413
|
|
|
1196
1414
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
@@ -1206,13 +1424,7 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1206
1424
|
|
|
1207
1425
|
wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
|
|
1208
1426
|
|
|
1209
|
-
|
|
1210
|
-
// Mamba-2
|
|
1211
|
-
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
|
1212
|
-
} else {
|
|
1213
|
-
WSP_GGML_ASSERT(d_inner == 1);
|
|
1214
|
-
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
|
|
1215
|
-
}
|
|
1427
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
|
1216
1428
|
|
|
1217
1429
|
return 1;
|
|
1218
1430
|
}
|
|
@@ -1226,7 +1438,7 @@ int wsp_ggml_metal_op_rwkv(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1226
1438
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1227
1439
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1228
1440
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1229
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
1441
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1230
1442
|
|
|
1231
1443
|
const int64_t B = op->op == WSP_GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
|
|
1232
1444
|
const int64_t T = op->src[0]->ne[2];
|
|
@@ -1267,32 +1479,29 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1267
1479
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1268
1480
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1269
1481
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1270
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
1482
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1271
1483
|
|
|
1272
1484
|
wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
|
1273
1485
|
|
|
1274
1486
|
WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(op->src[0]->type) == 0);
|
|
1275
1487
|
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
while (nth < nk00 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
1283
|
-
nth *= 2;
|
|
1488
|
+
int64_t nk0 = ne00;
|
|
1489
|
+
if (wsp_ggml_is_quantized(op->src[0]->type)) {
|
|
1490
|
+
nk0 = ne00/16;
|
|
1491
|
+
} else if (wsp_ggml_is_quantized(op->type)) {
|
|
1492
|
+
nk0 = ne00/wsp_ggml_blck_size(op->type);
|
|
1284
1493
|
}
|
|
1285
1494
|
|
|
1286
|
-
nth = std::min(
|
|
1495
|
+
int nth = std::min<int>(nk0, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1287
1496
|
|
|
1288
1497
|
// when rows are small, we can batch them together in a single threadgroup
|
|
1289
1498
|
int nrptg = 1;
|
|
1290
1499
|
|
|
1291
1500
|
// TODO: relax this constraint in the future
|
|
1292
1501
|
if (wsp_ggml_blck_size(op->src[0]->type) == 1 && wsp_ggml_blck_size(op->type) == 1) {
|
|
1293
|
-
if (nth >
|
|
1294
|
-
nrptg = (nth +
|
|
1295
|
-
nth =
|
|
1502
|
+
if (nth > nk0) {
|
|
1503
|
+
nrptg = (nth + nk0 - 1)/nk0;
|
|
1504
|
+
nth = nk0;
|
|
1296
1505
|
|
|
1297
1506
|
if (nrptg*nth > wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
1298
1507
|
nrptg--;
|
|
@@ -1300,10 +1509,11 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1300
1509
|
}
|
|
1301
1510
|
}
|
|
1302
1511
|
|
|
1303
|
-
nth = std::min(nth,
|
|
1512
|
+
nth = std::min<int>(nth, nk0);
|
|
1304
1513
|
|
|
1305
1514
|
wsp_ggml_metal_kargs_cpy args = {
|
|
1306
|
-
/*.
|
|
1515
|
+
/*.nk0 =*/ nk0,
|
|
1516
|
+
/*.ne00 =*/ ne00,
|
|
1307
1517
|
/*.ne01 =*/ ne01,
|
|
1308
1518
|
/*.ne02 =*/ ne02,
|
|
1309
1519
|
/*.ne03 =*/ ne03,
|
|
@@ -1321,12 +1531,14 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1321
1531
|
/*.nb3 =*/ nb3,
|
|
1322
1532
|
};
|
|
1323
1533
|
|
|
1534
|
+
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
|
|
1535
|
+
|
|
1324
1536
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1325
1537
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1326
1538
|
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1327
1539
|
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
|
|
1328
1540
|
|
|
1329
|
-
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
|
|
1541
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
|
|
1330
1542
|
|
|
1331
1543
|
return 1;
|
|
1332
1544
|
}
|
|
@@ -1340,7 +1552,7 @@ int wsp_ggml_metal_op_pool_2d(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1340
1552
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1341
1553
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1342
1554
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1343
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
1555
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1344
1556
|
|
|
1345
1557
|
const int32_t * opts = op->op_params;
|
|
1346
1558
|
wsp_ggml_op_pool op_pool = (wsp_ggml_op_pool) opts[0];
|
|
@@ -1404,7 +1616,7 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1404
1616
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1405
1617
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1406
1618
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1407
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
1619
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1408
1620
|
|
|
1409
1621
|
WSP_GGML_ASSERT(ne00 == ne10);
|
|
1410
1622
|
|
|
@@ -1520,9 +1732,8 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1520
1732
|
!wsp_ggml_is_transposed(op->src[1]) &&
|
|
1521
1733
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
1522
1734
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
|
1523
|
-
props_dev->has_simdgroup_mm && ne00 >= 64 &&
|
|
1524
|
-
(
|
|
1525
|
-
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1735
|
+
props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
|
|
1736
|
+
//WSP_GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1526
1737
|
|
|
1527
1738
|
// some Metal matrix data types require aligned pointers
|
|
1528
1739
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
|
@@ -1646,7 +1857,7 @@ int wsp_ggml_metal_op_mul_mat_id(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1646
1857
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
1647
1858
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
1648
1859
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1649
|
-
WSP_GGML_TENSOR_LOCALS(
|
|
1860
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1650
1861
|
|
|
1651
1862
|
// src2 = ids
|
|
1652
1863
|
WSP_GGML_ASSERT(op->src[2]->type == WSP_GGML_TYPE_I32);
|
|
@@ -1875,20 +2086,114 @@ bool wsp_ggml_metal_op_flash_attn_ext_use_vec(const wsp_ggml_tensor * op) {
|
|
|
1875
2086
|
return (ne01 < 20) && (ne00 % 32 == 0);
|
|
1876
2087
|
}
|
|
1877
2088
|
|
|
2089
|
+
size_t wsp_ggml_metal_op_flash_attn_ext_extra_pad(const wsp_ggml_tensor * op) {
|
|
2090
|
+
assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
|
|
2091
|
+
|
|
2092
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2093
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2094
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2095
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2096
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
2097
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
2098
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
|
2099
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
|
2100
|
+
|
|
2101
|
+
size_t res = 0;
|
|
2102
|
+
|
|
2103
|
+
const bool has_mask = op->src[3] != nullptr;
|
|
2104
|
+
|
|
2105
|
+
if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
|
2106
|
+
// note: always reserve the padding space to avoid graph reallocations
|
|
2107
|
+
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
|
|
2108
|
+
const bool has_kvpad = true;
|
|
2109
|
+
|
|
2110
|
+
if (has_kvpad) {
|
|
2111
|
+
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
|
|
2112
|
+
nb11*ne12*ne13 +
|
|
2113
|
+
nb21*ne22*ne23 +
|
|
2114
|
+
(has_mask ? wsp_ggml_type_size(WSP_GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
|
2115
|
+
}
|
|
2116
|
+
} else {
|
|
2117
|
+
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
|
|
2118
|
+
const bool has_kvpad = true;
|
|
2119
|
+
|
|
2120
|
+
if (has_kvpad) {
|
|
2121
|
+
res += OP_FLASH_ATTN_EXT_NCPSG*(
|
|
2122
|
+
nb11*ne12*ne13 +
|
|
2123
|
+
nb21*ne22*ne23 +
|
|
2124
|
+
(has_mask ? wsp_ggml_type_size(WSP_GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
|
2125
|
+
}
|
|
2126
|
+
}
|
|
2127
|
+
|
|
2128
|
+
return res;
|
|
2129
|
+
}
|
|
2130
|
+
|
|
2131
|
+
size_t wsp_ggml_metal_op_flash_attn_ext_extra_blk(const wsp_ggml_tensor * op) {
|
|
2132
|
+
assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
|
|
2133
|
+
|
|
2134
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2135
|
+
//WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2136
|
+
//WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2137
|
+
//WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2138
|
+
//WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
2139
|
+
//WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
2140
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
|
2141
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
|
2142
|
+
|
|
2143
|
+
size_t res = 0;
|
|
2144
|
+
|
|
2145
|
+
const bool has_mask = op->src[3] != nullptr;
|
|
2146
|
+
|
|
2147
|
+
if (!has_mask) {
|
|
2148
|
+
return res;
|
|
2149
|
+
}
|
|
2150
|
+
|
|
2151
|
+
const bool is_vec = wsp_ggml_metal_op_flash_attn_ext_use_vec(op);
|
|
2152
|
+
|
|
2153
|
+
// this optimization is not useful for the vector kernels
|
|
2154
|
+
// note: always reserve the blk buffer to avoid graph reallocations
|
|
2155
|
+
//if (is_vec) {
|
|
2156
|
+
// return res;
|
|
2157
|
+
//}
|
|
2158
|
+
|
|
2159
|
+
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
|
|
2160
|
+
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
|
|
2161
|
+
|
|
2162
|
+
const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
|
|
2163
|
+
const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
|
|
2164
|
+
|
|
2165
|
+
res += WSP_GGML_PAD(wsp_ggml_type_size(WSP_GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
|
|
2166
|
+
|
|
2167
|
+
return res;
|
|
2168
|
+
}
|
|
2169
|
+
|
|
1878
2170
|
size_t wsp_ggml_metal_op_flash_attn_ext_extra_tmp(const wsp_ggml_tensor * op) {
|
|
1879
2171
|
assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
|
|
1880
2172
|
|
|
1881
|
-
|
|
2173
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2174
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2175
|
+
//WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2176
|
+
//WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2177
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
2178
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
2179
|
+
//WSP_GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
|
2180
|
+
//WSP_GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
|
2181
|
+
|
|
2182
|
+
size_t res = 0;
|
|
2183
|
+
|
|
2184
|
+
// note: always reserve the temp buffer to avoid graph reallocations
|
|
2185
|
+
//if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
|
2186
|
+
if (true) {
|
|
2187
|
+
const int64_t nwg = 32;
|
|
2188
|
+
const int64_t ne01_max = std::min(ne01, 32);
|
|
1882
2189
|
|
|
1883
|
-
|
|
1884
|
-
|
|
1885
|
-
|
|
1886
|
-
|
|
2190
|
+
// temp buffer for writing the results from each workgroup
|
|
2191
|
+
// - ne20: the size of the Value head
|
|
2192
|
+
// - + 2: the S and M values for each intermediate result
|
|
2193
|
+
res += wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
|
|
2194
|
+
}
|
|
1887
2195
|
|
|
1888
|
-
|
|
1889
|
-
// - ne20: the size of the Value head
|
|
1890
|
-
// - + 2: the S and M values for each intermediate result
|
|
1891
|
-
return wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
|
|
2196
|
+
return res;
|
|
1892
2197
|
}
|
|
1893
2198
|
|
|
1894
2199
|
int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
@@ -1910,8 +2215,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1910
2215
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1911
2216
|
WSP_GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
|
|
1912
2217
|
|
|
1913
|
-
WSP_GGML_ASSERT(ne00 % 4
|
|
1914
|
-
WSP_GGML_ASSERT(ne11 % 32 == 0);
|
|
2218
|
+
WSP_GGML_ASSERT(ne00 % 4 == 0);
|
|
1915
2219
|
|
|
1916
2220
|
WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
|
|
1917
2221
|
WSP_GGML_ASSERT(op->src[1]->type == op->src[2]->type);
|
|
@@ -1921,8 +2225,8 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1921
2225
|
WSP_GGML_ASSERT(ne12 == ne22);
|
|
1922
2226
|
|
|
1923
2227
|
WSP_GGML_ASSERT(!op->src[3] || op->src[3]->type == WSP_GGML_TYPE_F16);
|
|
1924
|
-
WSP_GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >=
|
|
1925
|
-
"the Flash-Attention Metal kernel requires the mask to be
|
|
2228
|
+
WSP_GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
|
|
2229
|
+
"the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
|
|
1926
2230
|
|
|
1927
2231
|
float scale;
|
|
1928
2232
|
float max_bias;
|
|
@@ -1949,15 +2253,107 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1949
2253
|
|
|
1950
2254
|
WSP_GGML_ASSERT(ne01 < 65536);
|
|
1951
2255
|
|
|
2256
|
+
wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
|
|
2257
|
+
wsp_ggml_metal_buffer_id bid_src1 = wsp_ggml_metal_get_buffer_id(op->src[1]);
|
|
2258
|
+
wsp_ggml_metal_buffer_id bid_src2 = wsp_ggml_metal_get_buffer_id(op->src[2]);
|
|
2259
|
+
wsp_ggml_metal_buffer_id bid_src3 = has_mask ? wsp_ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
|
|
2260
|
+
wsp_ggml_metal_buffer_id bid_src4 = has_sinks ? wsp_ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
|
|
2261
|
+
|
|
2262
|
+
wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
|
|
2263
|
+
|
|
2264
|
+
wsp_ggml_metal_buffer_id bid_pad = bid_dst;
|
|
2265
|
+
bid_pad.offs += wsp_ggml_nbytes(op);
|
|
2266
|
+
|
|
2267
|
+
wsp_ggml_metal_buffer_id bid_blk = bid_pad;
|
|
2268
|
+
bid_blk.offs += wsp_ggml_metal_op_flash_attn_ext_extra_pad(op);
|
|
2269
|
+
|
|
2270
|
+
wsp_ggml_metal_buffer_id bid_tmp = bid_blk;
|
|
2271
|
+
bid_tmp.offs += wsp_ggml_metal_op_flash_attn_ext_extra_blk(op);
|
|
2272
|
+
|
|
1952
2273
|
if (!wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
|
1953
2274
|
// half8x8 kernel
|
|
1954
|
-
const
|
|
1955
|
-
const
|
|
2275
|
+
const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
|
|
2276
|
+
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
|
|
1956
2277
|
|
|
1957
2278
|
WSP_GGML_ASSERT(nqptg <= 32);
|
|
1958
2279
|
WSP_GGML_ASSERT(nqptg % 8 == 0);
|
|
1959
2280
|
WSP_GGML_ASSERT(ncpsg % 32 == 0);
|
|
1960
2281
|
|
|
2282
|
+
bool need_sync = false;
|
|
2283
|
+
|
|
2284
|
+
const bool has_kvpad = ne11 % ncpsg != 0;
|
|
2285
|
+
|
|
2286
|
+
if (has_kvpad) {
|
|
2287
|
+
assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
|
|
2288
|
+
|
|
2289
|
+
wsp_ggml_metal_kargs_flash_attn_ext_pad args0 = {
|
|
2290
|
+
/*.ne11 =*/ne11,
|
|
2291
|
+
/*.ne_12_2 =*/ne12,
|
|
2292
|
+
/*.ne_12_3 =*/ne13,
|
|
2293
|
+
/*.nb11 =*/nb11,
|
|
2294
|
+
/*.nb12 =*/nb12,
|
|
2295
|
+
/*.nb13 =*/nb13,
|
|
2296
|
+
/*.nb21 =*/nb21,
|
|
2297
|
+
/*.nb22 =*/nb22,
|
|
2298
|
+
/*.nb23 =*/nb23,
|
|
2299
|
+
/*.ne31 =*/ne31,
|
|
2300
|
+
/*.ne32 =*/ne32,
|
|
2301
|
+
/*.ne33 =*/ne33,
|
|
2302
|
+
/*.nb31 =*/nb31,
|
|
2303
|
+
/*.nb32 =*/nb32,
|
|
2304
|
+
/*.nb33 =*/nb33,
|
|
2305
|
+
};
|
|
2306
|
+
|
|
2307
|
+
wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
|
2308
|
+
|
|
2309
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2310
|
+
wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
2311
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
|
2312
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
|
|
2313
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
|
|
2314
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
|
2315
|
+
|
|
2316
|
+
assert(ne12 == ne22);
|
|
2317
|
+
assert(ne13 == ne23);
|
|
2318
|
+
|
|
2319
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
|
2320
|
+
|
|
2321
|
+
need_sync = true;
|
|
2322
|
+
}
|
|
2323
|
+
|
|
2324
|
+
if (has_mask) {
|
|
2325
|
+
assert(wsp_ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
|
|
2326
|
+
|
|
2327
|
+
wsp_ggml_metal_kargs_flash_attn_ext_blk args0 = {
|
|
2328
|
+
/*.ne01 =*/ ne01,
|
|
2329
|
+
/*.ne30 =*/ ne30,
|
|
2330
|
+
/*.ne31 =*/ ne31,
|
|
2331
|
+
/*.ne32 =*/ ne32,
|
|
2332
|
+
/*.ne33 =*/ ne33,
|
|
2333
|
+
/*.nb31 =*/ nb31,
|
|
2334
|
+
/*.nb32 =*/ nb32,
|
|
2335
|
+
/*.nb33 =*/ nb33,
|
|
2336
|
+
};
|
|
2337
|
+
|
|
2338
|
+
wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
|
|
2339
|
+
|
|
2340
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2341
|
+
wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
2342
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
|
|
2343
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
|
|
2344
|
+
|
|
2345
|
+
const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
|
|
2346
|
+
const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
|
|
2347
|
+
|
|
2348
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
|
|
2349
|
+
|
|
2350
|
+
need_sync = true;
|
|
2351
|
+
}
|
|
2352
|
+
|
|
2353
|
+
if (need_sync) {
|
|
2354
|
+
wsp_ggml_metal_op_concurrency_reset(ctx);
|
|
2355
|
+
}
|
|
2356
|
+
|
|
1961
2357
|
const int is_q = wsp_ggml_is_quantized(op->src[1]->type) ? 1 : 0;
|
|
1962
2358
|
|
|
1963
2359
|
// 2*(2*ncpsg)
|
|
@@ -2007,6 +2403,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2007
2403
|
/*.nb21 =*/ nb21,
|
|
2008
2404
|
/*.nb22 =*/ nb22,
|
|
2009
2405
|
/*.nb23 =*/ nb23,
|
|
2406
|
+
/*.ne31 =*/ ne31,
|
|
2010
2407
|
/*.ne32 =*/ ne32,
|
|
2011
2408
|
/*.ne33 =*/ ne33,
|
|
2012
2409
|
/*.nb31 =*/ nb31,
|
|
@@ -2023,24 +2420,18 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2023
2420
|
/*.logit_softcap =*/ logit_softcap,
|
|
2024
2421
|
};
|
|
2025
2422
|
|
|
2026
|
-
wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
|
|
2423
|
+
wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
|
|
2027
2424
|
|
|
2028
2425
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2029
2426
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
2030
|
-
wsp_ggml_metal_encoder_set_buffer (enc,
|
|
2031
|
-
wsp_ggml_metal_encoder_set_buffer (enc,
|
|
2032
|
-
wsp_ggml_metal_encoder_set_buffer (enc,
|
|
2033
|
-
|
|
2034
|
-
|
|
2035
|
-
|
|
2036
|
-
|
|
2037
|
-
|
|
2038
|
-
if (op->src[4]) {
|
|
2039
|
-
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[4]), 5);
|
|
2040
|
-
} else {
|
|
2041
|
-
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 5);
|
|
2042
|
-
}
|
|
2043
|
-
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 6);
|
|
2427
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
2428
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
|
2429
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
|
|
2430
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
|
|
2431
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
|
|
2432
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
|
|
2433
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
|
|
2434
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
|
|
2044
2435
|
|
|
2045
2436
|
wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2046
2437
|
|
|
@@ -2048,14 +2439,60 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2048
2439
|
#undef FATTN_SMEM
|
|
2049
2440
|
} else {
|
|
2050
2441
|
// half4x4 kernel
|
|
2051
|
-
const
|
|
2052
|
-
const
|
|
2053
|
-
const
|
|
2442
|
+
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
|
|
2443
|
+
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
|
|
2444
|
+
const int nkpsg = 1*ncpsg;
|
|
2054
2445
|
|
|
2055
2446
|
WSP_GGML_ASSERT(nqptg <= 32);
|
|
2056
2447
|
WSP_GGML_ASSERT(nqptg % 1 == 0);
|
|
2057
2448
|
WSP_GGML_ASSERT(ncpsg % 32 == 0);
|
|
2058
2449
|
|
|
2450
|
+
bool need_sync = false;
|
|
2451
|
+
|
|
2452
|
+
const bool has_kvpad = ne11 % ncpsg != 0;
|
|
2453
|
+
|
|
2454
|
+
if (has_kvpad) {
|
|
2455
|
+
assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
|
|
2456
|
+
|
|
2457
|
+
wsp_ggml_metal_kargs_flash_attn_ext_pad args0 = {
|
|
2458
|
+
/*.ne11 =*/ne11,
|
|
2459
|
+
/*.ne_12_2 =*/ne12,
|
|
2460
|
+
/*.ne_12_3 =*/ne13,
|
|
2461
|
+
/*.nb11 =*/nb11,
|
|
2462
|
+
/*.nb12 =*/nb12,
|
|
2463
|
+
/*.nb13 =*/nb13,
|
|
2464
|
+
/*.nb21 =*/nb21,
|
|
2465
|
+
/*.nb22 =*/nb22,
|
|
2466
|
+
/*.nb23 =*/nb23,
|
|
2467
|
+
/*.ne31 =*/ne31,
|
|
2468
|
+
/*.ne32 =*/ne32,
|
|
2469
|
+
/*.ne33 =*/ne33,
|
|
2470
|
+
/*.nb31 =*/nb31,
|
|
2471
|
+
/*.nb32 =*/nb32,
|
|
2472
|
+
/*.nb33 =*/nb33,
|
|
2473
|
+
};
|
|
2474
|
+
|
|
2475
|
+
wsp_ggml_metal_pipeline_t pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
|
2476
|
+
|
|
2477
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2478
|
+
wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
2479
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
|
2480
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
|
|
2481
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
|
|
2482
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
|
2483
|
+
|
|
2484
|
+
assert(ne12 == ne22);
|
|
2485
|
+
assert(ne13 == ne23);
|
|
2486
|
+
|
|
2487
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
|
2488
|
+
|
|
2489
|
+
need_sync = true;
|
|
2490
|
+
}
|
|
2491
|
+
|
|
2492
|
+
if (need_sync) {
|
|
2493
|
+
wsp_ggml_metal_op_concurrency_reset(ctx);
|
|
2494
|
+
}
|
|
2495
|
+
|
|
2059
2496
|
// ne00 + 2*ncpsg*(nsg)
|
|
2060
2497
|
// for each query, we load it as f16 in shared memory (ne00)
|
|
2061
2498
|
// and store the soft_max values and the mask
|
|
@@ -2120,6 +2557,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2120
2557
|
/*.nb21 =*/ nb21,
|
|
2121
2558
|
/*.nb22 =*/ nb22,
|
|
2122
2559
|
/*.nb23 =*/ nb23,
|
|
2560
|
+
/*.ne31 =*/ ne31,
|
|
2123
2561
|
/*.ne32 =*/ ne32,
|
|
2124
2562
|
/*.ne33 =*/ ne33,
|
|
2125
2563
|
/*.nb31 =*/ nb31,
|
|
@@ -2136,25 +2574,17 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
     /*.logit_softcap =*/ logit_softcap,
     };

-    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);

     WSP_GGML_ASSERT(nsg*32 <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

     wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
     wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
-    wsp_ggml_metal_encoder_set_buffer (enc,
-    wsp_ggml_metal_encoder_set_buffer (enc,
-    wsp_ggml_metal_encoder_set_buffer (enc,
-
-
-    } else {
-        wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 4);
-    }
-    if (op->src[4]) {
-        wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[4]), 5);
-    } else {
-        wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 5);
-    }
+    wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+    wsp_ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
+    wsp_ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
+    wsp_ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
+    wsp_ggml_metal_encoder_set_buffer (enc, bid_src4, 5);

     const size_t smem = FATTN_SMEM(nsg);

@@ -2162,23 +2592,25 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);

     if (nwg == 1) {
+        assert(wsp_ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
+
         // using 1 workgroup -> write the result directly into dst
-        wsp_ggml_metal_encoder_set_buffer(enc,
+        wsp_ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+        wsp_ggml_metal_encoder_set_buffer(enc, bid_dst, 7);

         wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

         wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
     } else {
         // sanity checks
+        assert(wsp_ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
+
         WSP_GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
         WSP_GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));

-        wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
-
         // write the results from each workgroup into a temp buffer
-
-        bid_tmp
-        wsp_ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
+        wsp_ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+        wsp_ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);

         wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
         wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
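The two new asserts make the buffer layout explicit: the pad scratch buffer now sits at index 6 in both branches, and index 7 holds either `dst` (single workgroup) or the temp region (multiple workgroups). A sketch of the decision the asserts encode, grounded in the hunk above:

```cpp
// Illustrative only: choosing between the two flash-attention output paths.
// Assumes the helper returns the extra temp-buffer size needed when several
// workgroups cooperate on a row, as the asserts in the hunk imply.
const size_t extra = wsp_ggml_metal_op_flash_attn_ext_extra_tmp(op);

if (nwg == 1) {
    // single workgroup: result goes straight to dst, no temp storage needed
    assert(extra == 0);
} else {
    // multiple workgroups: each writes partial results into a temp region,
    // which a follow-up reduction pass folds into the final dst
    assert(extra != 0);
}
```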
@@ -2385,7 +2817,7 @@ int wsp_ggml_metal_op_l2_norm(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
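This one-line hunk repeats across many ops below: the destination-stride locals are now declared through the macro with `uint64_t` (the removed line is truncated in this rendering, so the old declaration is not visible). A simplified sketch of what the macro expands to, assuming it mirrors upstream ggml's `GGML_TENSOR_LOCALS`:

```cpp
#include <cstdint>

// Hypothetical stand-in for the real tensor type, just to make this compile.
struct tensor_like { uint64_t nb[4]; int32_t ne[4]; };

// Roughly what WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb) expands to:
// one local per dimension, named prefix0..prefix3.
void example(const tensor_like * op) {
    const uint64_t nb0 = op->nb[0]; // byte stride, dim 0
    const uint64_t nb1 = op->nb[1]; // byte stride, dim 1
    const uint64_t nb2 = op->nb[2]; // byte stride, dim 2
    const uint64_t nb3 = op->nb[3]; // byte stride, dim 3
    (void)nb0; (void)nb1; (void)nb2; (void)nb3;
}
```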
@@ -2433,7 +2865,7 @@ int wsp_ggml_metal_op_group_norm(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     const int32_t ngrp = ((const int32_t *) op->op_params)[0];

@@ -2488,7 +2920,7 @@ int wsp_ggml_metal_op_norm(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
@@ -2624,7 +3056,7 @@ int wsp_ggml_metal_op_rope(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     // make sure we have one or more position id(ne10) per token(ne02)
     WSP_GGML_ASSERT(ne10 % ne02 == 0);
@@ -2688,6 +3120,7 @@ int wsp_ggml_metal_op_rope(wsp_ggml_metal_op_t ctx, int idx) {
         /* sect_1 =*/ sect_1,
         /* sect_2 =*/ sect_2,
         /* sect_3 =*/ sect_3,
+        /* src2 =*/ op->src[2] != nullptr,
     };

     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_rope(lib, op);
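The new `src2` field lets the shader branch on whether the optional third source is present (in upstream ggml, rope's third input carries optional frequency factors). The host side still has to bind a valid buffer at that argument slot either way; a sketch of the usual pattern, where `bid_src2` is a hypothetical name for this op:

```cpp
// Sketch (assumes the surrounding op-encode context from this file).
// Pass presence as a flag, and bind a harmless fallback buffer when the
// optional tensor is absent so the argument table stays fully populated.
const bool has_src2 = op->src[2] != nullptr;

wsp_ggml_metal_buffer_id bid_src2 = has_src2
    ? wsp_ggml_metal_get_buffer_id(op->src[2])
    : wsp_ggml_metal_get_buffer_id(op->src[0]); // fallback: any valid buffer
```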
@@ -2717,7 +3150,7 @@ int wsp_ggml_metal_op_im2col(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
     const int32_t s1 = ((const int32_t *)(op->op_params))[1];
@@ -2778,6 +3211,84 @@ int wsp_ggml_metal_op_im2col(wsp_ggml_metal_op_t ctx, int idx) {
     return 1;
 }

+int wsp_ggml_metal_op_conv_2d(wsp_ggml_metal_op_t ctx, int idx) {
+    wsp_ggml_tensor * op = ctx->node(idx);
+
+    wsp_ggml_metal_library_t lib = ctx->lib;
+    wsp_ggml_metal_encoder_t enc = ctx->enc;
+
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
+    WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
+
+    const int32_t s0 = ((const int32_t *) op->op_params)[0];
+    const int32_t s1 = ((const int32_t *) op->op_params)[1];
+    const int32_t p0 = ((const int32_t *) op->op_params)[2];
+    const int32_t p1 = ((const int32_t *) op->op_params)[3];
+    const int32_t d0 = ((const int32_t *) op->op_params)[4];
+    const int32_t d1 = ((const int32_t *) op->op_params)[5];
+
+    wsp_ggml_metal_kargs_conv_2d args = {
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.nb10 =*/ nb10,
+        /*.nb11 =*/ nb11,
+        /*.nb12 =*/ nb12,
+        /*.nb13 =*/ nb13,
+        /*.nb0 =*/ nb0,
+        /*.nb1 =*/ nb1,
+        /*.nb2 =*/ nb2,
+        /*.nb3 =*/ nb3,
+        /*.IW =*/ ne10,
+        /*.IH =*/ ne11,
+        /*.KW =*/ ne00,
+        /*.KH =*/ ne01,
+        /*.IC =*/ ne02,
+        /*.OC =*/ ne03,
+        /*.OW =*/ ne0,
+        /*.OH =*/ ne1,
+        /*.N =*/ ne3,
+        /*.s0 =*/ s0,
+        /*.s1 =*/ s1,
+        /*.p0 =*/ p0,
+        /*.p1 =*/ p1,
+        /*.d0 =*/ d0,
+        /*.d1 =*/ d1,
+    };
+
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_2d(lib, op);
+
+    int nth = wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
+    nth = std::min(nth, 256);
+    nth = std::max(nth, 1);
+
+    const uint64_t n_out = wsp_ggml_nelements(op);
+
+    uint64_t tg = (n_out + nth - 1)/nth;
+    tg = std::max<uint64_t>(tg, 1);
+    tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
+
+    wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
+    wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
+
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
+
+    return 1;
+}
+
 int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_tensor * op = ctx->node(idx);

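The new conv_2d path launches one thread per output element, clamping the threadgroup size to 256 and the grid size to `INT_MAX`. For reference, a small sketch of the arithmetic behind it, using the standard dilated-convolution output formula (the kernel itself lives in ggml-metal.metal; these helpers are illustrative only):

```cpp
#include <cstdint>

// Illustrative: expected output width/height for a conv_2d like the one above,
// given input size, kernel size, stride s, padding p and dilation d.
static int64_t conv_out_size(int64_t in, int64_t k, int64_t s, int64_t p, int64_t d) {
    return (in + 2*p - d*(k - 1) - 1)/s + 1;
}

// Flat dispatch: one thread per output element, nth threads per threadgroup.
static uint64_t n_threadgroups(uint64_t n_out, int nth) {
    return (n_out + nth - 1)/nth; // ceil division, as in the hunk above
}
```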
@@ -2789,7 +3300,7 @@ int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     const int32_t s0 = ((const int32_t *)(op->op_params))[0];

@@ -2823,6 +3334,62 @@ int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
     return 1;
 }

+int wsp_ggml_metal_op_conv_transpose_2d(wsp_ggml_metal_op_t ctx, int idx) {
+    wsp_ggml_tensor * op = ctx->node(idx);
+
+    wsp_ggml_metal_library_t lib = ctx->lib;
+    wsp_ggml_metal_encoder_t enc = ctx->enc;
+
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+
+    const int32_t IC = op->src[1]->ne[2];
+    const int32_t IH = op->src[1]->ne[1];
+    const int32_t IW = op->src[1]->ne[0];
+
+    const int32_t KH = op->src[0]->ne[1];
+    const int32_t KW = op->src[0]->ne[0];
+
+    const int32_t OW = op->ne[0];
+    const int32_t OH = op->ne[1];
+    const int32_t OC = op->ne[2];
+
+    wsp_ggml_metal_kargs_conv_transpose_2d args = {
+        /*.IC =*/ IC,
+        /*.IH =*/ IH,
+        /*.IW =*/ IW,
+        /*.KH =*/ KH,
+        /*.KW =*/ KW,
+        /*.OC =*/ OC,
+        /*.s0 =*/ s0,
+        /*.nb0 =*/ nb0,
+        /*.nb1 =*/ nb1,
+        /*.nb2 =*/ nb2,
+    };
+
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
+
+    wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
+    wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
+
+    // Metal requires buffer size to be multiple of 16 bytes
+    const size_t smem = WSP_GGML_PAD(KW * KH * sizeof(float), 16);
+    wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
+
+    return 1;
+}
+
 int wsp_ggml_metal_op_upscale(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_tensor * op = ctx->node(idx);

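The threadgroup-memory size is rounded up with `WSP_GGML_PAD` because `setThreadgroupMemoryLength:` requires a multiple of 16 bytes. Assuming the macro matches upstream ggml's `GGML_PAD`, it is plain ceil-to-multiple:

```cpp
// Sketch of the padding macro, assuming it mirrors upstream ggml's GGML_PAD.
#define WSP_GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

// e.g. a 3x3 float kernel: 9*4 = 36 bytes -> padded to 48 bytes
static_assert(WSP_GGML_PAD(36, 16) == 48, "rounds up to the next multiple of 16");
```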
@@ -2832,7 +3399,7 @@ int wsp_ggml_metal_op_upscale(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     const float sf0 = (float)ne0/op->src[0]->ne[0];
     const float sf1 = (float)ne1/op->src[0]->ne[1];
@@ -2885,7 +3452,7 @@ int wsp_ggml_metal_op_pad(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     wsp_ggml_metal_kargs_pad args = {
         /*.ne00 =*/ ne00,
@@ -2929,7 +3496,7 @@ int wsp_ggml_metal_op_pad_reflect_1d(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     wsp_ggml_metal_kargs_pad_reflect_1d args = {
         /*.ne00 =*/ ne00,
@@ -2973,7 +3540,7 @@ int wsp_ggml_metal_op_arange(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_metal_encoder_t enc = ctx->enc;

     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     float start;
     float step;
@@ -2991,12 +3558,6 @@ int wsp_ggml_metal_op_arange(wsp_ggml_metal_op_t ctx, int idx) {

     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_arange(lib, op);

-    //[encoder setComputePipelineState:pipeline];
-    //[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
-    //[encoder setBytes:&args length:sizeof(args) atIndex:1];
-
-    //[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
     wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
     wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
     wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 1);
@@ -3015,7 +3576,7 @@ int wsp_ggml_metal_op_timestep_embedding(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     const int dim = op->op_params[0];
     const int max_period = op->op_params[1];
@@ -3049,7 +3610,7 @@ int wsp_ggml_metal_op_argmax(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     wsp_ggml_metal_kargs_argmax args = {
         /*.ne00 = */ ne00,
@@ -3085,38 +3646,93 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_metal_library_t lib = ctx->lib;
     wsp_ggml_metal_encoder_t enc = ctx->enc;

+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
+
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_argsort(lib, op);

     // bitonic sort requires the number of elements to be power of 2
-
-    while (
-
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
     }

-
-
-    const int64_t nrows = wsp_ggml_nrows(op->src[0]);
+    const int npr = (ne00 + nth - 1)/nth;

     // Metal kernels require the buffer size to be multiple of 16 bytes
     // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
-    const size_t smem = WSP_GGML_PAD(
+    const size_t smem = WSP_GGML_PAD(nth*sizeof(int32_t), 16);
+
+    wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
+    wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
+
+    wsp_ggml_metal_buffer_id bid_tmp = bid_dst;
+    bid_tmp.offs += wsp_ggml_nbytes(op);
+
+    if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
+        std::swap(bid_dst, bid_tmp);
+    }

     wsp_ggml_metal_kargs_argsort args = {
-        /*.
-        /*.
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
     };

     wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
     wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
-    wsp_ggml_metal_encoder_set_buffer (enc,
-    wsp_ggml_metal_encoder_set_buffer (enc,
+    wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+    wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);

     wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

-    wsp_ggml_metal_encoder_dispatch_threadgroups(enc,
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
+
+    wsp_ggml_metal_pipeline_t pipeline_merge = wsp_ggml_metal_library_get_pipeline_argsort_merge(lib, op);
+
+    int len = nth;
+
+    while (len < ne00) {
+        wsp_ggml_metal_op_concurrency_reset(ctx);
+
+        wsp_ggml_metal_kargs_argsort_merge args_merge = {
+            .ne00 = ne00,
+            .ne01 = ne01,
+            .ne02 = ne02,
+            .ne03 = ne03,
+            .nb00 = nb00,
+            .nb01 = nb01,
+            .nb02 = nb02,
+            .nb03 = nb03,
+            .len = len,
+        };
+
+        // merges per row
+        const int nm = (ne00 + 2*len - 1) / (2*len);
+
+        const int nth = std::min(512, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
+
+        wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+        wsp_ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
+        wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+        wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+        wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
+
+        wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
+
+        std::swap(bid_dst, bid_tmp);
+
+        len <<= 1;
+    }

     return 1;
 }
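The rewritten argsort first bitonic-sorts runs of `nth` elements inside each threadgroup, then performs ceil(log2(npr)) merge passes that double the sorted run length while ping-ponging between `dst` and a temp region placed right after it; the up-front parity swap guarantees the final pass lands in the real `dst`. A standalone sketch of that parity argument (illustrative values, not the package's code):

```cpp
#include <cmath>
#include <cstdio>

// Each merge pass doubles the sorted run length and swaps the ping-pong
// buffers, so where the result ends up depends only on the pass count.
int main() {
    const int ne00 = 10000; // row length (hypothetical)
    const int nth  = 1024;  // elements bitonic-sorted per threadgroup

    const int npr    = (ne00 + nth - 1)/nth;             // sorted runs per row
    const int passes = (int) std::ceil(std::log2(npr));  // merge passes needed

    // After an odd number of passes the data sits in the *other* buffer,
    // hence the std::swap(bid_dst, bid_tmp) before encoding when odd.
    std::printf("runs=%d passes=%d result lands in dst %s\n",
                npr, passes, passes % 2 == 0 ? "directly" : "after the swap");
    return 0;
}
```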
@@ -3130,7 +3746,7 @@ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

     float slope;
     memcpy(&slope, op->op_params, sizeof(float));
@@ -3156,3 +3772,73 @@ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {

     return 1;
 }
+
+int wsp_ggml_metal_op_opt_step_adamw(wsp_ggml_metal_op_t ctx, int idx) {
+    wsp_ggml_tensor * op = ctx->node(idx);
+
+    wsp_ggml_metal_library_t lib = ctx->lib;
+    wsp_ggml_metal_encoder_t enc = ctx->enc;
+
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
+
+    const int64_t np = wsp_ggml_nelements(op->src[0]);
+    wsp_ggml_metal_kargs_opt_step_adamw args = {
+        /*.np =*/ np,
+    };
+
+    int ida = 0;
+
+    wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
+    wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), ida++);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), ida++);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), ida++);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[3]), ida++);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[4]), ida++);
+
+    const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+    const int64_t n = (np + nth - 1) / nth;
+
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
+
+    return 1;
+}
+
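The new ADAMW op binds five buffers (parameters, gradients, the two moment tensors, and the per-step hyperparameters) and launches one thread per element. For reference, a textbook scalar AdamW step; the Metal kernel's exact factoring of the bias-correction terms may differ:

```cpp
#include <cmath>

// Scalar AdamW reference (illustrative). Buffers mirror what the op binds:
// x = params, g = grads, m/v = first/second moments; t is the 1-based step.
static void adamw_step_ref(float * x, const float * g, float * m, float * v,
                           long np, long t, float alpha, float beta1,
                           float beta2, float eps, float wd) {
    for (long i = 0; i < np; ++i) {
        m[i] = beta1*m[i] + (1.0f - beta1)*g[i];
        v[i] = beta2*v[i] + (1.0f - beta2)*g[i]*g[i];

        const float mhat = m[i]/(1.0f - std::pow(beta1, (float) t));
        const float vhat = v[i]/(1.0f - std::pow(beta2, (float) t));

        // decoupled weight decay, then the Adam update
        x[i] = x[i]*(1.0f - alpha*wd) - alpha*mhat/(std::sqrt(vhat) + eps);
    }
}
```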
+int wsp_ggml_metal_op_opt_step_sgd(wsp_ggml_metal_op_t ctx, int idx) {
+    wsp_ggml_tensor * op = ctx->node(idx);
+
+    wsp_ggml_metal_library_t lib = ctx->lib;
+    wsp_ggml_metal_encoder_t enc = ctx->enc;
+
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
+
+    const int64_t np = wsp_ggml_nelements(op->src[0]);
+    wsp_ggml_metal_kargs_opt_step_sgd args = {
+        /*.np =*/ np,
+    };
+
+    int ida = 0;
+
+    wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
+    wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), ida++);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), ida++);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[2]), ida++);
+
+    const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+    const int64_t n = (np + nth - 1) / nth;
+
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
+
+    return 1;
+}