whisper.rn 0.5.4 → 0.5.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +5 -0
- package/android/src/main/jni.cpp +13 -0
- package/cpp/ggml-alloc.c +78 -26
- package/cpp/ggml-alloc.h +9 -0
- package/cpp/ggml-backend-impl.h +1 -1
- package/cpp/ggml-backend-reg.cpp +19 -3
- package/cpp/ggml-backend.cpp +72 -20
- package/cpp/ggml-backend.h +2 -1
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +1004 -0
- package/cpp/ggml-cpu/arch/x86/repack.cpp +6 -6
- package/cpp/ggml-cpu/arch-fallback.h +50 -2
- package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
- package/cpp/ggml-cpu/ggml-cpu.c +139 -58
- package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/ggml-cpu/ops.cpp +170 -18
- package/cpp/ggml-cpu/ops.h +1 -0
- package/cpp/ggml-cpu/repack.cpp +531 -5
- package/cpp/ggml-cpu/repack.h +14 -0
- package/cpp/ggml-cpu/simd-mappings.h +16 -18
- package/cpp/ggml-cpu/vec.cpp +41 -1
- package/cpp/ggml-cpu/vec.h +241 -138
- package/cpp/ggml-cpu.h +1 -0
- package/cpp/ggml-impl.h +0 -4
- package/cpp/ggml-metal/ggml-metal-context.m +26 -16
- package/cpp/ggml-metal/ggml-metal-device.cpp +452 -371
- package/cpp/ggml-metal/ggml-metal-device.h +87 -65
- package/cpp/ggml-metal/ggml-metal-device.m +263 -104
- package/cpp/ggml-metal/ggml-metal-impl.h +58 -4
- package/cpp/ggml-metal/ggml-metal-ops.cpp +415 -98
- package/cpp/ggml-metal/ggml-metal-ops.h +4 -0
- package/cpp/ggml-metal/ggml-metal.cpp +6 -5
- package/cpp/ggml-metal/ggml-metal.metal +404 -34
- package/cpp/ggml.c +110 -31
- package/cpp/ggml.h +51 -12
- package/cpp/jsi/RNWhisperJSI.cpp +1 -0
- package/cpp/whisper.cpp +16 -3
- package/ios/CMakeLists.txt +21 -1
- package/ios/RNWhisperContext.mm +5 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/jest-mock.js +2 -0
- package/lib/commonjs/jest-mock.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/jest-mock.js +2 -0
- package/lib/module/jest-mock.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +1 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +1 -0
- package/src/jest-mock.ts +2 -0
- package/src/version.json +1 -1
|
@@ -221,7 +221,7 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
221
221
|
}
|
|
222
222
|
|
|
223
223
|
if (ctx->debug_graph > 0) {
|
|
224
|
-
WSP_GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, wsp_ggml_op_name(node->op), is_concurrent ? "(concurrent)" : "");
|
|
224
|
+
WSP_GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, wsp_ggml_op_name(node->op), wsp_ggml_get_name(node), is_concurrent ? "(concurrent)" : "");
|
|
225
225
|
}
|
|
226
226
|
if (ctx->debug_graph > 1) {
|
|
227
227
|
WSP_GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
|
|
@@ -286,6 +286,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
286
286
|
{
|
|
287
287
|
n_fuse = wsp_ggml_metal_op_scale(ctx, idx);
|
|
288
288
|
} break;
|
|
289
|
+
case WSP_GGML_OP_FILL:
|
|
290
|
+
{
|
|
291
|
+
n_fuse = wsp_ggml_metal_op_fill(ctx, idx);
|
|
292
|
+
} break;
|
|
289
293
|
case WSP_GGML_OP_CLAMP:
|
|
290
294
|
{
|
|
291
295
|
n_fuse = wsp_ggml_metal_op_clamp(ctx, idx);
|
|
@@ -406,10 +410,18 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
406
410
|
{
|
|
407
411
|
n_fuse = wsp_ggml_metal_op_argsort(ctx, idx);
|
|
408
412
|
} break;
|
|
413
|
+
case WSP_GGML_OP_TOP_K:
|
|
414
|
+
{
|
|
415
|
+
n_fuse = wsp_ggml_metal_op_top_k(ctx, idx);
|
|
416
|
+
} break;
|
|
409
417
|
case WSP_GGML_OP_LEAKY_RELU:
|
|
410
418
|
{
|
|
411
419
|
n_fuse = wsp_ggml_metal_op_leaky_relu(ctx, idx);
|
|
412
420
|
} break;
|
|
421
|
+
case WSP_GGML_OP_TRI:
|
|
422
|
+
{
|
|
423
|
+
n_fuse = wsp_ggml_metal_op_tri(ctx, idx);
|
|
424
|
+
} break;
|
|
413
425
|
case WSP_GGML_OP_FLASH_ATTN_EXT:
|
|
414
426
|
{
|
|
415
427
|
n_fuse = wsp_ggml_metal_op_flash_attn_ext(ctx, idx);
|
|
@@ -436,7 +448,11 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
436
448
|
{
|
|
437
449
|
n_fuse = wsp_ggml_metal_op_opt_step_sgd(ctx, idx);
|
|
438
450
|
} break;
|
|
439
|
-
|
|
451
|
+
case WSP_GGML_OP_COUNT_EQUAL:
|
|
452
|
+
{
|
|
453
|
+
n_fuse = wsp_ggml_metal_op_count_equal(ctx, idx);
|
|
454
|
+
} break;
|
|
455
|
+
default:
|
|
440
456
|
{
|
|
441
457
|
WSP_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, wsp_ggml_op_name(node->op));
|
|
442
458
|
WSP_GGML_ABORT("fatal error");
|
|
@@ -520,7 +536,7 @@ int wsp_ggml_metal_op_concat(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
520
536
|
/*.dim =*/ dim,
|
|
521
537
|
};
|
|
522
538
|
|
|
523
|
-
|
|
539
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_base(lib, WSP_GGML_OP_CONCAT);
|
|
524
540
|
|
|
525
541
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
526
542
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -546,7 +562,7 @@ int wsp_ggml_metal_op_repeat(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
546
562
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
547
563
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
548
564
|
|
|
549
|
-
|
|
565
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_repeat(lib, op->type);
|
|
550
566
|
|
|
551
567
|
wsp_ggml_metal_kargs_repeat args = {
|
|
552
568
|
/*.ne00 =*/ ne00,
|
|
@@ -612,7 +628,7 @@ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
612
628
|
// TODO: make a simpler cpy_bytes kernel
|
|
613
629
|
|
|
614
630
|
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[WSP_GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
|
|
615
|
-
|
|
631
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
|
616
632
|
|
|
617
633
|
wsp_ggml_metal_kargs_cpy args = {
|
|
618
634
|
/*.nk0 =*/ ne00,
|
|
@@ -675,7 +691,7 @@ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
675
691
|
/*.o1 =*/ { 0 },
|
|
676
692
|
};
|
|
677
693
|
|
|
678
|
-
|
|
694
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_bin(lib, WSP_GGML_OP_ADD, 1, false);
|
|
679
695
|
|
|
680
696
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
681
697
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -717,7 +733,42 @@ int wsp_ggml_metal_op_scale(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
717
733
|
n /= 4;
|
|
718
734
|
}
|
|
719
735
|
|
|
720
|
-
|
|
736
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
|
|
737
|
+
|
|
738
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
739
|
+
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
740
|
+
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
741
|
+
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
|
|
742
|
+
|
|
743
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
744
|
+
|
|
745
|
+
return 1;
|
|
746
|
+
}
|
|
747
|
+
|
|
748
|
+
int wsp_ggml_metal_op_fill(wsp_ggml_metal_op_t ctx, int idx) {
|
|
749
|
+
wsp_ggml_tensor * op = ctx->node(idx);
|
|
750
|
+
|
|
751
|
+
wsp_ggml_metal_library_t lib = ctx->lib;
|
|
752
|
+
wsp_ggml_metal_encoder_t enc = ctx->enc;
|
|
753
|
+
|
|
754
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
755
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
756
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
757
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
758
|
+
|
|
759
|
+
const float val = wsp_ggml_get_op_params_f32(op, 0);
|
|
760
|
+
|
|
761
|
+
wsp_ggml_metal_kargs_fill args = {
|
|
762
|
+
/*.val =*/ val
|
|
763
|
+
};
|
|
764
|
+
|
|
765
|
+
int64_t n = wsp_ggml_nelements(op);
|
|
766
|
+
|
|
767
|
+
if (n % 4 == 0) {
|
|
768
|
+
n /= 4;
|
|
769
|
+
}
|
|
770
|
+
|
|
771
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
|
|
721
772
|
|
|
722
773
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
723
774
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -756,7 +807,7 @@ int wsp_ggml_metal_op_clamp(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
756
807
|
n /= 4;
|
|
757
808
|
}
|
|
758
809
|
|
|
759
|
-
|
|
810
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
|
|
760
811
|
|
|
761
812
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
762
813
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -785,7 +836,7 @@ int wsp_ggml_metal_op_unary(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
785
836
|
n /= 4;
|
|
786
837
|
}
|
|
787
838
|
|
|
788
|
-
|
|
839
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
|
|
789
840
|
|
|
790
841
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
791
842
|
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 0);
|
|
@@ -813,7 +864,7 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
813
864
|
WSP_GGML_ASSERT(wsp_ggml_are_same_shape(op->src[0], op->src[1]));
|
|
814
865
|
}
|
|
815
866
|
|
|
816
|
-
|
|
867
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_glu(lib, op);
|
|
817
868
|
|
|
818
869
|
const int32_t swp = wsp_ggml_get_op_params_i32(op, 1);
|
|
819
870
|
const float alpha = wsp_ggml_get_op_params_f32(op, 2);
|
|
@@ -866,7 +917,7 @@ int wsp_ggml_metal_op_sum(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
866
917
|
/*.np =*/ n,
|
|
867
918
|
};
|
|
868
919
|
|
|
869
|
-
|
|
920
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_sum(lib, op);
|
|
870
921
|
|
|
871
922
|
int nth = 32; // SIMD width
|
|
872
923
|
|
|
@@ -921,7 +972,7 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
921
972
|
/*.nb3 =*/ nb3,
|
|
922
973
|
};
|
|
923
974
|
|
|
924
|
-
|
|
975
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
|
925
976
|
|
|
926
977
|
int nth = 32; // SIMD width
|
|
927
978
|
|
|
@@ -932,7 +983,7 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
932
983
|
nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
933
984
|
nth = std::min(nth, ne00);
|
|
934
985
|
|
|
935
|
-
const size_t smem =
|
|
986
|
+
const size_t smem = pipeline.smem;
|
|
936
987
|
|
|
937
988
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
938
989
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -959,7 +1010,7 @@ int wsp_ggml_metal_op_cumsum(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
959
1010
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
960
1011
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
961
1012
|
|
|
962
|
-
|
|
1013
|
+
auto pipeline_blk = wsp_ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
|
|
963
1014
|
|
|
964
1015
|
int nth = 1;
|
|
965
1016
|
while (nth < ne00 && 2*nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
|
|
@@ -1056,7 +1107,7 @@ int wsp_ggml_metal_op_cumsum(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1056
1107
|
wsp_ggml_metal_op_concurrency_reset(ctx);
|
|
1057
1108
|
|
|
1058
1109
|
{
|
|
1059
|
-
|
|
1110
|
+
auto pipeline_add = wsp_ggml_metal_library_get_pipeline_cumsum_add(lib, op);
|
|
1060
1111
|
|
|
1061
1112
|
wsp_ggml_metal_kargs_cumsum_add args = {
|
|
1062
1113
|
/*.ne00 =*/ ne00,
|
|
@@ -1102,7 +1153,7 @@ int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1102
1153
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1103
1154
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1104
1155
|
|
|
1105
|
-
|
|
1156
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
|
1106
1157
|
|
|
1107
1158
|
wsp_ggml_metal_kargs_get_rows args = {
|
|
1108
1159
|
/*.ne00t =*/ wsp_ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
|
|
@@ -1147,7 +1198,7 @@ int wsp_ggml_metal_op_set_rows(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1147
1198
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1148
1199
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1149
1200
|
|
|
1150
|
-
|
|
1201
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
|
|
1151
1202
|
|
|
1152
1203
|
const int32_t nk0 = ne0/wsp_ggml_blck_size(op->type);
|
|
1153
1204
|
|
|
@@ -1248,7 +1299,7 @@ int wsp_ggml_metal_op_soft_max(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1248
1299
|
/*.n_head_log2 =*/ n_head_log2,
|
|
1249
1300
|
};
|
|
1250
1301
|
|
|
1251
|
-
|
|
1302
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_soft_max(lib, op);
|
|
1252
1303
|
|
|
1253
1304
|
int nth = 32; // SIMD width
|
|
1254
1305
|
|
|
@@ -1262,7 +1313,7 @@ int wsp_ggml_metal_op_soft_max(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1262
1313
|
}
|
|
1263
1314
|
}
|
|
1264
1315
|
|
|
1265
|
-
const size_t smem =
|
|
1316
|
+
const size_t smem = pipeline.smem;
|
|
1266
1317
|
|
|
1267
1318
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1268
1319
|
wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
@@ -1318,15 +1369,43 @@ int wsp_ggml_metal_op_ssm_conv(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1318
1369
|
/*.nb2 =*/ nb2,
|
|
1319
1370
|
};
|
|
1320
1371
|
|
|
1321
|
-
|
|
1372
|
+
// Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
|
|
1373
|
+
const bool use_batched = (ne1 > 1);
|
|
1322
1374
|
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1326
|
-
|
|
1327
|
-
|
|
1375
|
+
if (use_batched) {
|
|
1376
|
+
// Determine the smallest power of 2 that's >= ne1, but <= 256
|
|
1377
|
+
int BATCH_SIZE;
|
|
1378
|
+
if (ne1 > 128) BATCH_SIZE = 256;
|
|
1379
|
+
else if (ne1 > 64 ) BATCH_SIZE = 128;
|
|
1380
|
+
else if (ne1 > 32 ) BATCH_SIZE = 64;
|
|
1381
|
+
else if (ne1 > 16 ) BATCH_SIZE = 32;
|
|
1382
|
+
else if (ne1 > 8 ) BATCH_SIZE = 16;
|
|
1383
|
+
else if (ne1 > 4 ) BATCH_SIZE = 8;
|
|
1384
|
+
else BATCH_SIZE = 2;
|
|
1385
|
+
|
|
1386
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
|
|
1387
|
+
|
|
1388
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1389
|
+
wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
1390
|
+
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1391
|
+
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1392
|
+
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
|
|
1328
1393
|
|
|
1329
|
-
|
|
1394
|
+
// Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
|
|
1395
|
+
// Each threadgroup has BATCH_SIZE threads, each handling one token
|
|
1396
|
+
const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
|
|
1397
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
|
|
1398
|
+
} else {
|
|
1399
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_ssm_conv(lib, op);
|
|
1400
|
+
|
|
1401
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1402
|
+
wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
1403
|
+
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1404
|
+
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1405
|
+
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
|
|
1406
|
+
|
|
1407
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
|
|
1408
|
+
}
|
|
1330
1409
|
|
|
1331
1410
|
return 1;
|
|
1332
1411
|
}
|
|
@@ -1405,11 +1484,11 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1405
1484
|
/*.nb0 =*/ nb0,
|
|
1406
1485
|
};
|
|
1407
1486
|
|
|
1408
|
-
|
|
1487
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_ssm_scan(lib, op);
|
|
1409
1488
|
|
|
1410
1489
|
WSP_GGML_ASSERT(d_state <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1411
1490
|
|
|
1412
|
-
const size_t
|
|
1491
|
+
const size_t smem = pipeline.smem;
|
|
1413
1492
|
|
|
1414
1493
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1415
1494
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -1422,7 +1501,7 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1422
1501
|
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[6]), 7);
|
|
1423
1502
|
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 8);
|
|
1424
1503
|
|
|
1425
|
-
wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc,
|
|
1504
|
+
wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1426
1505
|
|
|
1427
1506
|
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
|
1428
1507
|
|
|
@@ -1445,7 +1524,7 @@ int wsp_ggml_metal_op_rwkv(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1445
1524
|
const int64_t C = op->ne[0];
|
|
1446
1525
|
const int64_t H = op->src[0]->ne[1];
|
|
1447
1526
|
|
|
1448
|
-
|
|
1527
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_rwkv(lib, op);
|
|
1449
1528
|
|
|
1450
1529
|
int ida = 0;
|
|
1451
1530
|
|
|
@@ -1481,7 +1560,7 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1481
1560
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1482
1561
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1483
1562
|
|
|
1484
|
-
|
|
1563
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
|
1485
1564
|
|
|
1486
1565
|
WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(op->src[0]->type) == 0);
|
|
1487
1566
|
|
|
@@ -1588,7 +1667,7 @@ int wsp_ggml_metal_op_pool_2d(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1588
1667
|
/* .np = */ np
|
|
1589
1668
|
};
|
|
1590
1669
|
|
|
1591
|
-
|
|
1670
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
|
|
1592
1671
|
|
|
1593
1672
|
const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
|
|
1594
1673
|
const int ntg = (np + nth - 1) / nth;
|
|
@@ -1697,7 +1776,7 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1697
1776
|
WSP_GGML_ABORT("unsupported ne11");
|
|
1698
1777
|
};
|
|
1699
1778
|
|
|
1700
|
-
|
|
1779
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
|
|
1701
1780
|
|
|
1702
1781
|
wsp_ggml_metal_kargs_mul_mv_ext args = {
|
|
1703
1782
|
/*.ne00 =*/ ne00,
|
|
@@ -1744,7 +1823,7 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1744
1823
|
// default: break;
|
|
1745
1824
|
//}
|
|
1746
1825
|
|
|
1747
|
-
|
|
1826
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_mul_mm(lib, op);
|
|
1748
1827
|
|
|
1749
1828
|
wsp_ggml_metal_kargs_mul_mm args = {
|
|
1750
1829
|
/*.ne00 =*/ ne00,
|
|
@@ -1769,18 +1848,18 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1769
1848
|
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1770
1849
|
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
|
|
1771
1850
|
|
|
1772
|
-
const size_t smem =
|
|
1851
|
+
const size_t smem = pipeline.smem;
|
|
1773
1852
|
|
|
1774
1853
|
wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1775
1854
|
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
|
|
1776
1855
|
} else {
|
|
1777
|
-
|
|
1856
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_mul_mv(lib, op);
|
|
1778
1857
|
|
|
1779
|
-
const int nr0 =
|
|
1780
|
-
const int nr1 =
|
|
1781
|
-
const int nsg =
|
|
1858
|
+
const int nr0 = pipeline.nr0;
|
|
1859
|
+
const int nr1 = pipeline.nr1;
|
|
1860
|
+
const int nsg = pipeline.nsg;
|
|
1782
1861
|
|
|
1783
|
-
const size_t smem =
|
|
1862
|
+
const size_t smem = pipeline.smem;
|
|
1784
1863
|
|
|
1785
1864
|
wsp_ggml_metal_kargs_mul_mv args = {
|
|
1786
1865
|
/*.ne00 =*/ ne00,
|
|
@@ -1911,9 +1990,9 @@ int wsp_ggml_metal_op_mul_mat_id(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1911
1990
|
nb21,
|
|
1912
1991
|
};
|
|
1913
1992
|
|
|
1914
|
-
|
|
1993
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
|
|
1915
1994
|
|
|
1916
|
-
const size_t smem =
|
|
1995
|
+
const size_t smem = pipeline.smem;
|
|
1917
1996
|
|
|
1918
1997
|
WSP_GGML_ASSERT(ne02 <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1919
1998
|
|
|
@@ -1934,7 +2013,7 @@ int wsp_ggml_metal_op_mul_mat_id(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1934
2013
|
wsp_ggml_metal_op_concurrency_reset(ctx);
|
|
1935
2014
|
|
|
1936
2015
|
{
|
|
1937
|
-
|
|
2016
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
|
|
1938
2017
|
|
|
1939
2018
|
wsp_ggml_metal_kargs_mul_mm_id args = {
|
|
1940
2019
|
/*.ne00 =*/ ne00,
|
|
@@ -1963,20 +2042,20 @@ int wsp_ggml_metal_op_mul_mat_id(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
1963
2042
|
wsp_ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
|
|
1964
2043
|
wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
|
|
1965
2044
|
|
|
1966
|
-
const size_t smem =
|
|
2045
|
+
const size_t smem = pipeline.smem;
|
|
1967
2046
|
|
|
1968
2047
|
wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1969
2048
|
|
|
1970
2049
|
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
|
|
1971
2050
|
}
|
|
1972
2051
|
} else {
|
|
1973
|
-
|
|
2052
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
|
|
1974
2053
|
|
|
1975
|
-
const int nr0 =
|
|
1976
|
-
const int nr1 =
|
|
1977
|
-
const int nsg =
|
|
2054
|
+
const int nr0 = pipeline.nr0;
|
|
2055
|
+
const int nr1 = pipeline.nr1;
|
|
2056
|
+
const int nsg = pipeline.nsg;
|
|
1978
2057
|
|
|
1979
|
-
const size_t smem =
|
|
2058
|
+
const size_t smem = pipeline.smem;
|
|
1980
2059
|
|
|
1981
2060
|
wsp_ggml_metal_kargs_mul_mv_id args = {
|
|
1982
2061
|
/*.nei0 =*/ ne20,
|
|
@@ -2060,7 +2139,7 @@ int wsp_ggml_metal_op_add_id(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2060
2139
|
/*.nb21 =*/ nb21,
|
|
2061
2140
|
};
|
|
2062
2141
|
|
|
2063
|
-
|
|
2142
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_base(lib, WSP_GGML_OP_ADD_ID);
|
|
2064
2143
|
|
|
2065
2144
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2066
2145
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -2102,7 +2181,11 @@ size_t wsp_ggml_metal_op_flash_attn_ext_extra_pad(const wsp_ggml_tensor * op) {
|
|
|
2102
2181
|
|
|
2103
2182
|
const bool has_mask = op->src[3] != nullptr;
|
|
2104
2183
|
|
|
2105
|
-
|
|
2184
|
+
// note: the non-vec kernel requires more extra memory, so always reserve for it
|
|
2185
|
+
WSP_GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
|
|
2186
|
+
|
|
2187
|
+
//if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
|
2188
|
+
if (false) {
|
|
2106
2189
|
// note: always reserve the padding space to avoid graph reallocations
|
|
2107
2190
|
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
|
|
2108
2191
|
const bool has_kvpad = true;
|
|
@@ -2304,7 +2387,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2304
2387
|
/*.nb33 =*/nb33,
|
|
2305
2388
|
};
|
|
2306
2389
|
|
|
2307
|
-
|
|
2390
|
+
auto pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
|
2308
2391
|
|
|
2309
2392
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2310
2393
|
wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
@@ -2335,7 +2418,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2335
2418
|
/*.nb33 =*/ nb33,
|
|
2336
2419
|
};
|
|
2337
2420
|
|
|
2338
|
-
|
|
2421
|
+
auto pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
|
|
2339
2422
|
|
|
2340
2423
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2341
2424
|
wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
@@ -2420,7 +2503,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2420
2503
|
/*.logit_softcap =*/ logit_softcap,
|
|
2421
2504
|
};
|
|
2422
2505
|
|
|
2423
|
-
|
|
2506
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
|
|
2424
2507
|
|
|
2425
2508
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2426
2509
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -2472,7 +2555,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2472
2555
|
/*.nb33 =*/nb33,
|
|
2473
2556
|
};
|
|
2474
2557
|
|
|
2475
|
-
|
|
2558
|
+
auto pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
|
2476
2559
|
|
|
2477
2560
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2478
2561
|
wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
@@ -2574,7 +2657,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2574
2657
|
/*.logit_softcap =*/ logit_softcap,
|
|
2575
2658
|
};
|
|
2576
2659
|
|
|
2577
|
-
|
|
2660
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
|
|
2578
2661
|
|
|
2579
2662
|
WSP_GGML_ASSERT(nsg*32 <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2580
2663
|
|
|
@@ -2626,7 +2709,7 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2626
2709
|
nrows,
|
|
2627
2710
|
};
|
|
2628
2711
|
|
|
2629
|
-
|
|
2712
|
+
auto pipeline0 = wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
|
|
2630
2713
|
|
|
2631
2714
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2632
2715
|
wsp_ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
@@ -2758,7 +2841,7 @@ int wsp_ggml_metal_op_bin(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2758
2841
|
// the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
|
|
2759
2842
|
bid_src1.offs = 0;
|
|
2760
2843
|
|
|
2761
|
-
|
|
2844
|
+
struct wsp_ggml_metal_pipeline_with_params pipeline;
|
|
2762
2845
|
|
|
2763
2846
|
if (wsp_ggml_nelements(op->src[1]) == ne10 && wsp_ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
2764
2847
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
|
|
@@ -2831,7 +2914,7 @@ int wsp_ggml_metal_op_l2_norm(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2831
2914
|
/*.eps =*/ eps,
|
|
2832
2915
|
};
|
|
2833
2916
|
|
|
2834
|
-
|
|
2917
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
|
2835
2918
|
|
|
2836
2919
|
while (nth < ne00/4 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
2837
2920
|
nth *= 2;
|
|
@@ -2840,7 +2923,7 @@ int wsp_ggml_metal_op_l2_norm(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2840
2923
|
nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2841
2924
|
nth = std::min(nth, ne00/4);
|
|
2842
2925
|
|
|
2843
|
-
const size_t smem =
|
|
2926
|
+
const size_t smem = pipeline.smem;
|
|
2844
2927
|
|
|
2845
2928
|
const int64_t nrows = wsp_ggml_nrows(op->src[0]);
|
|
2846
2929
|
|
|
@@ -2883,7 +2966,7 @@ int wsp_ggml_metal_op_group_norm(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2883
2966
|
/*.eps =*/ eps,
|
|
2884
2967
|
};
|
|
2885
2968
|
|
|
2886
|
-
|
|
2969
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_group_norm(lib, op);
|
|
2887
2970
|
|
|
2888
2971
|
int nth = 32; // SIMD width
|
|
2889
2972
|
//while (nth < ne00/4 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
@@ -2893,7 +2976,7 @@ int wsp_ggml_metal_op_group_norm(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
2893
2976
|
//nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2894
2977
|
//nth = std::min(nth, ne00/4);
|
|
2895
2978
|
|
|
2896
|
-
const size_t smem =
|
|
2979
|
+
const size_t smem = pipeline.smem;
|
|
2897
2980
|
|
|
2898
2981
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2899
2982
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -3018,7 +3101,7 @@ int wsp_ggml_metal_op_norm(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3018
3101
|
}
|
|
3019
3102
|
}
|
|
3020
3103
|
|
|
3021
|
-
|
|
3104
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
|
|
3022
3105
|
|
|
3023
3106
|
int nth = 32; // SIMD width
|
|
3024
3107
|
|
|
@@ -3029,7 +3112,7 @@ int wsp_ggml_metal_op_norm(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3029
3112
|
nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
3030
3113
|
nth = std::min(nth, args.ne00_t);
|
|
3031
3114
|
|
|
3032
|
-
const size_t smem =
|
|
3115
|
+
const size_t smem = pipeline.smem;
|
|
3033
3116
|
|
|
3034
3117
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3035
3118
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -3123,7 +3206,7 @@ int wsp_ggml_metal_op_rope(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3123
3206
|
/* src2 =*/ op->src[2] != nullptr,
|
|
3124
3207
|
};
|
|
3125
3208
|
|
|
3126
|
-
|
|
3209
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_rope(lib, op);
|
|
3127
3210
|
|
|
3128
3211
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3129
3212
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -3195,7 +3278,7 @@ int wsp_ggml_metal_op_im2col(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3195
3278
|
/*.KHW =*/ KH * KW,
|
|
3196
3279
|
};
|
|
3197
3280
|
|
|
3198
|
-
|
|
3281
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_im2col(lib, op);
|
|
3199
3282
|
|
|
3200
3283
|
WSP_GGML_ASSERT(KH*KW <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
3201
3284
|
|
|
@@ -3266,7 +3349,7 @@ int wsp_ggml_metal_op_conv_2d(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3266
3349
|
/*.d1 =*/ d1,
|
|
3267
3350
|
};
|
|
3268
3351
|
|
|
3269
|
-
|
|
3352
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_conv_2d(lib, op);
|
|
3270
3353
|
|
|
3271
3354
|
int nth = wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
|
|
3272
3355
|
nth = std::min(nth, 256);
|
|
@@ -3321,7 +3404,7 @@ int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3321
3404
|
/*.nb1 =*/ nb1,
|
|
3322
3405
|
};
|
|
3323
3406
|
|
|
3324
|
-
|
|
3407
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
|
|
3325
3408
|
|
|
3326
3409
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3327
3410
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -3373,7 +3456,7 @@ int wsp_ggml_metal_op_conv_transpose_2d(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3373
3456
|
/*.nb2 =*/ nb2,
|
|
3374
3457
|
};
|
|
3375
3458
|
|
|
3376
|
-
|
|
3459
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
|
|
3377
3460
|
|
|
3378
3461
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3379
3462
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -3429,7 +3512,7 @@ int wsp_ggml_metal_op_upscale(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3429
3512
|
/*.sf3 =*/ sf3
|
|
3430
3513
|
};
|
|
3431
3514
|
|
|
3432
|
-
|
|
3515
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_upscale(lib, op);
|
|
3433
3516
|
|
|
3434
3517
|
const int nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
|
3435
3518
|
|
|
@@ -3473,7 +3556,7 @@ int wsp_ggml_metal_op_pad(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3473
3556
|
/*.nb3 =*/ nb3
|
|
3474
3557
|
};
|
|
3475
3558
|
|
|
3476
|
-
|
|
3559
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_pad(lib, op);
|
|
3477
3560
|
|
|
3478
3561
|
const int nth = std::min(1024, ne0);
|
|
3479
3562
|
|
|
@@ -3519,7 +3602,7 @@ int wsp_ggml_metal_op_pad_reflect_1d(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3519
3602
|
/*.p1 =*/ ((const int32_t *)(op->op_params))[1]
|
|
3520
3603
|
};
|
|
3521
3604
|
|
|
3522
|
-
|
|
3605
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
|
|
3523
3606
|
|
|
3524
3607
|
const int nth = std::min(1024, ne0);
|
|
3525
3608
|
|
|
@@ -3556,7 +3639,7 @@ int wsp_ggml_metal_op_arange(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3556
3639
|
|
|
3557
3640
|
const int nth = std::min(1024, ne0);
|
|
3558
3641
|
|
|
3559
|
-
|
|
3642
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_arange(lib, op);
|
|
3560
3643
|
|
|
3561
3644
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3562
3645
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -3587,7 +3670,7 @@ int wsp_ggml_metal_op_timestep_embedding(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3587
3670
|
/*.max_period =*/ max_period,
|
|
3588
3671
|
};
|
|
3589
3672
|
|
|
3590
|
-
|
|
3673
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
|
|
3591
3674
|
|
|
3592
3675
|
const int nth = std::max(1, std::min(1024, dim/2));
|
|
3593
3676
|
|
|
@@ -3617,7 +3700,7 @@ int wsp_ggml_metal_op_argmax(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3617
3700
|
/*.nb01 = */ nb01,
|
|
3618
3701
|
};
|
|
3619
3702
|
|
|
3620
|
-
|
|
3703
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_argmax(lib, op);
|
|
3621
3704
|
|
|
3622
3705
|
const int64_t nrows = wsp_ggml_nrows(op->src[0]);
|
|
3623
3706
|
|
|
@@ -3626,7 +3709,7 @@ int wsp_ggml_metal_op_argmax(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3626
3709
|
nth *= 2;
|
|
3627
3710
|
}
|
|
3628
3711
|
|
|
3629
|
-
const size_t smem =
|
|
3712
|
+
const size_t smem = pipeline.smem;
|
|
3630
3713
|
|
|
3631
3714
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3632
3715
|
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -3653,7 +3736,7 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3653
3736
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3654
3737
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3655
3738
|
|
|
3656
|
-
|
|
3739
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_argsort(lib, op);
|
|
3657
3740
|
|
|
3658
3741
|
// bitonic sort requires the number of elements to be power of 2
|
|
3659
3742
|
int nth = 1;
|
|
@@ -3678,14 +3761,19 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3678
3761
|
}
|
|
3679
3762
|
|
|
3680
3763
|
wsp_ggml_metal_kargs_argsort args = {
|
|
3681
|
-
/*.ne00
|
|
3682
|
-
/*.ne01
|
|
3683
|
-
/*.ne02
|
|
3684
|
-
/*.ne03
|
|
3685
|
-
/*.nb00
|
|
3686
|
-
/*.nb01
|
|
3687
|
-
/*.nb02
|
|
3688
|
-
/*.nb03
|
|
3764
|
+
/*.ne00 =*/ ne00,
|
|
3765
|
+
/*.ne01 =*/ ne01,
|
|
3766
|
+
/*.ne02 =*/ ne02,
|
|
3767
|
+
/*.ne03 =*/ ne03,
|
|
3768
|
+
/*.nb00 =*/ nb00,
|
|
3769
|
+
/*.nb01 =*/ nb01,
|
|
3770
|
+
/*.nb02 =*/ nb02,
|
|
3771
|
+
/*.nb03 =*/ nb03,
|
|
3772
|
+
/*.ne0 =*/ ne0,
|
|
3773
|
+
/*.ne1 =*/ ne1,
|
|
3774
|
+
/*.ne2 =*/ ne2,
|
|
3775
|
+
/*.ne3 =*/ ne3,
|
|
3776
|
+
/*.top_k =*/ nth,
|
|
3689
3777
|
};
|
|
3690
3778
|
|
|
3691
3779
|
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
@@ -3697,7 +3785,7 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3697
3785
|
|
|
3698
3786
|
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
|
3699
3787
|
|
|
3700
|
-
|
|
3788
|
+
auto pipeline_merge = wsp_ggml_metal_library_get_pipeline_argsort_merge(lib, op);
|
|
3701
3789
|
|
|
3702
3790
|
int len = nth;
|
|
3703
3791
|
|
|
@@ -3705,15 +3793,20 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3705
3793
|
wsp_ggml_metal_op_concurrency_reset(ctx);
|
|
3706
3794
|
|
|
3707
3795
|
wsp_ggml_metal_kargs_argsort_merge args_merge = {
|
|
3708
|
-
|
|
3709
|
-
|
|
3710
|
-
|
|
3711
|
-
|
|
3712
|
-
|
|
3713
|
-
|
|
3714
|
-
|
|
3715
|
-
|
|
3716
|
-
|
|
3796
|
+
/*.ne00 =*/ ne00,
|
|
3797
|
+
/*.ne01 =*/ ne01,
|
|
3798
|
+
/*.ne02 =*/ ne02,
|
|
3799
|
+
/*.ne03 =*/ ne03,
|
|
3800
|
+
/*.nb00 =*/ nb00,
|
|
3801
|
+
/*.nb01 =*/ nb01,
|
|
3802
|
+
/*.nb02 =*/ nb02,
|
|
3803
|
+
/*.nb03 =*/ nb03,
|
|
3804
|
+
/*.ne0 =*/ ne0,
|
|
3805
|
+
/*.ne1 =*/ ne1,
|
|
3806
|
+
/*.ne2 =*/ ne2,
|
|
3807
|
+
/*.ne3 =*/ ne3,
|
|
3808
|
+
/*.top_k =*/ ne00,
|
|
3809
|
+
/*.len =*/ len,
|
|
3717
3810
|
};
|
|
3718
3811
|
|
|
3719
3812
|
// merges per row
|
|
@@ -3737,6 +3830,118 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3737
3830
|
return 1;
|
|
3738
3831
|
}
|
|
3739
3832
|
|
|
3833
|
+
int wsp_ggml_metal_op_top_k(wsp_ggml_metal_op_t ctx, int idx) {
|
|
3834
|
+
wsp_ggml_tensor * op = ctx->node(idx);
|
|
3835
|
+
|
|
3836
|
+
wsp_ggml_metal_library_t lib = ctx->lib;
|
|
3837
|
+
wsp_ggml_metal_encoder_t enc = ctx->enc;
|
|
3838
|
+
|
|
3839
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
|
|
3840
|
+
|
|
3841
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3842
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3843
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3844
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3845
|
+
|
|
3846
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_top_k(lib, op);
|
|
3847
|
+
|
|
3848
|
+
// bitonic sort requires the number of elements to be power of 2
|
|
3849
|
+
int nth = 1;
|
|
3850
|
+
while (nth < ne00 && 2*nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
3851
|
+
nth *= 2;
|
|
3852
|
+
}
|
|
3853
|
+
|
|
3854
|
+
// blocks per row
|
|
3855
|
+
const int npr = (ne00 + nth - 1)/nth;
|
|
3856
|
+
|
|
3857
|
+
const size_t smem = WSP_GGML_PAD(nth*sizeof(int32_t), 16);
|
|
3858
|
+
|
|
3859
|
+
wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
|
|
3860
|
+
wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
|
|
3861
|
+
|
|
3862
|
+
wsp_ggml_metal_buffer_id bid_tmp = bid_dst;
|
|
3863
|
+
bid_tmp.offs += sizeof(int32_t)*wsp_ggml_nelements(op->src[0]);
|
|
3864
|
+
|
|
3865
|
+
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
|
|
3866
|
+
std::swap(bid_dst, bid_tmp);
|
|
3867
|
+
}
|
|
3868
|
+
|
|
3869
|
+
const int top_k = ne0;
|
|
3870
|
+
|
|
3871
|
+
wsp_ggml_metal_kargs_argsort args = {
|
|
3872
|
+
/*.ne00 =*/ ne00,
|
|
3873
|
+
/*.ne01 =*/ ne01,
|
|
3874
|
+
/*.ne02 =*/ ne02,
|
|
3875
|
+
/*.ne03 =*/ ne03,
|
|
3876
|
+
/*.nb00 =*/ nb00,
|
|
3877
|
+
/*.nb01 =*/ nb01,
|
|
3878
|
+
/*.nb02 =*/ nb02,
|
|
3879
|
+
/*.nb03 =*/ nb03,
|
|
3880
|
+
/*.ne0 =*/ ne0,
|
|
3881
|
+
/*.ne1 =*/ ne1,
|
|
3882
|
+
/*.ne2 =*/ ne2,
|
|
3883
|
+
/*.ne3 =*/ ne3,
|
|
3884
|
+
/*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
|
|
3885
|
+
};
|
|
3886
|
+
|
|
3887
|
+
if (npr > 1) {
|
|
3888
|
+
args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
|
|
3889
|
+
}
|
|
3890
|
+
|
|
3891
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3892
|
+
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3893
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
3894
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
3895
|
+
|
|
3896
|
+
wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
3897
|
+
|
|
3898
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
|
3899
|
+
|
|
3900
|
+
auto pipeline_merge = wsp_ggml_metal_library_get_pipeline_top_k_merge(lib, op);
|
|
3901
|
+
|
|
3902
|
+
int len = args.top_k;
|
|
3903
|
+
|
|
3904
|
+
while (len < args.ne0) {
|
|
3905
|
+
wsp_ggml_metal_op_concurrency_reset(ctx);
|
|
3906
|
+
|
|
3907
|
+
// merges per row
|
|
3908
|
+
const int nm = (args.ne0 + 2*len - 1) / (2*len);
|
|
3909
|
+
|
|
3910
|
+
const int nth = std::min(512, std::min(len, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
|
|
3911
|
+
|
|
3912
|
+
wsp_ggml_metal_kargs_argsort_merge args_merge = {
|
|
3913
|
+
/*.ne00 =*/ ne00,
|
|
3914
|
+
/*.ne01 =*/ ne01,
|
|
3915
|
+
/*.ne02 =*/ ne02,
|
|
3916
|
+
/*.ne03 =*/ ne03,
|
|
3917
|
+
/*.nb00 =*/ nb00,
|
|
3918
|
+
/*.nb01 =*/ nb01,
|
|
3919
|
+
/*.nb02 =*/ nb02,
|
|
3920
|
+
/*.nb03 =*/ nb03,
|
|
3921
|
+
/*.ne0 =*/ args.ne0,
|
|
3922
|
+
/*.ne1 =*/ ne1,
|
|
3923
|
+
/*.ne2 =*/ ne2,
|
|
3924
|
+
/*.ne3 =*/ ne3,
|
|
3925
|
+
/*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
|
|
3926
|
+
/*.len =*/ len,
|
|
3927
|
+
};
|
|
3928
|
+
|
|
3929
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
|
|
3930
|
+
wsp_ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
|
|
3931
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
3932
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
3933
|
+
wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
|
3934
|
+
|
|
3935
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
|
|
3936
|
+
|
|
3937
|
+
std::swap(bid_dst, bid_tmp);
|
|
3938
|
+
|
|
3939
|
+
len <<= 1;
|
|
3940
|
+
}
|
|
3941
|
+
|
|
3942
|
+
return 1;
|
|
3943
|
+
}
|
|
3944
|
+
|
|
3740
3945
|
int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
|
|
3741
3946
|
wsp_ggml_tensor * op = ctx->node(idx);
|
|
3742
3947
|
|
|
@@ -3755,7 +3960,7 @@ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3755
3960
|
/*.slope =*/ slope
|
|
3756
3961
|
};
|
|
3757
3962
|
|
|
3758
|
-
|
|
3963
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_unary(lib, op);
|
|
3759
3964
|
|
|
3760
3965
|
int64_t n = wsp_ggml_nelements(op);
|
|
3761
3966
|
|
|
@@ -3773,6 +3978,57 @@ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3773
3978
|
return 1;
|
|
3774
3979
|
}
|
|
3775
3980
|
|
|
3981
|
+
int wsp_ggml_metal_op_tri(wsp_ggml_metal_op_t ctx, int idx) {
|
|
3982
|
+
wsp_ggml_tensor * op = ctx->node(idx);
|
|
3983
|
+
|
|
3984
|
+
wsp_ggml_metal_library_t lib = ctx->lib;
|
|
3985
|
+
wsp_ggml_metal_encoder_t enc = ctx->enc;
|
|
3986
|
+
|
|
3987
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3988
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3989
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3990
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3991
|
+
|
|
3992
|
+
wsp_ggml_metal_kargs_tri args = {
|
|
3993
|
+
/*.ne00 =*/ ne00,
|
|
3994
|
+
/*.ne01 =*/ ne01,
|
|
3995
|
+
/*.ne02 =*/ ne02,
|
|
3996
|
+
/*.ne03 =*/ ne03,
|
|
3997
|
+
/*.nb00 =*/ nb00,
|
|
3998
|
+
/*.nb01 =*/ nb01,
|
|
3999
|
+
/*.nb02 =*/ nb02,
|
|
4000
|
+
/*.nb03 =*/ nb03,
|
|
4001
|
+
/*.ne0 =*/ ne0,
|
|
4002
|
+
/*.ne1 =*/ ne1,
|
|
4003
|
+
/*.ne2 =*/ ne2,
|
|
4004
|
+
/*.ne3 =*/ ne3,
|
|
4005
|
+
/*.nb0 =*/ nb0,
|
|
4006
|
+
/*.nb1 =*/ nb1,
|
|
4007
|
+
/*.nb2 =*/ nb2,
|
|
4008
|
+
/*.nb3 =*/ nb3,
|
|
4009
|
+
};
|
|
4010
|
+
|
|
4011
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_tri(lib, op);
|
|
4012
|
+
|
|
4013
|
+
int nth = 32; // SIMD width
|
|
4014
|
+
|
|
4015
|
+
while (nth < ne00 && nth < wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
4016
|
+
nth *= 2;
|
|
4017
|
+
}
|
|
4018
|
+
|
|
4019
|
+
nth = std::min(nth, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
4020
|
+
nth = std::min(nth, ne00);
|
|
4021
|
+
|
|
4022
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
4023
|
+
wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
4024
|
+
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
4025
|
+
wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 2);
|
|
4026
|
+
|
|
4027
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
4028
|
+
|
|
4029
|
+
return 1;
|
|
4030
|
+
}
|
|
4031
|
+
|
|
3776
4032
|
int wsp_ggml_metal_op_opt_step_adamw(wsp_ggml_metal_op_t ctx, int idx) {
|
|
3777
4033
|
wsp_ggml_tensor * op = ctx->node(idx);
|
|
3778
4034
|
|
|
@@ -3784,7 +4040,7 @@ int wsp_ggml_metal_op_opt_step_adamw(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3784
4040
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3785
4041
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3786
4042
|
|
|
3787
|
-
|
|
4043
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
|
3788
4044
|
|
|
3789
4045
|
const int64_t np = wsp_ggml_nelements(op->src[0]);
|
|
3790
4046
|
wsp_ggml_metal_kargs_opt_step_adamw args = {
|
|
@@ -3820,7 +4076,7 @@ int wsp_ggml_metal_op_opt_step_sgd(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3820
4076
|
WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3821
4077
|
WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3822
4078
|
|
|
3823
|
-
|
|
4079
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
|
3824
4080
|
|
|
3825
4081
|
const int64_t np = wsp_ggml_nelements(op->src[0]);
|
|
3826
4082
|
wsp_ggml_metal_kargs_opt_step_sgd args = {
|
|
@@ -3842,3 +4098,64 @@ int wsp_ggml_metal_op_opt_step_sgd(wsp_ggml_metal_op_t ctx, int idx) {
|
|
|
3842
4098
|
|
|
3843
4099
|
return 1;
|
|
3844
4100
|
}
|
|
4101
|
+
|
|
4102
|
+
int wsp_ggml_metal_op_count_equal(wsp_ggml_metal_op_t ctx, int idx) {
|
|
4103
|
+
wsp_ggml_tensor * op = ctx->node(idx);
|
|
4104
|
+
|
|
4105
|
+
wsp_ggml_metal_library_t lib = ctx->lib;
|
|
4106
|
+
wsp_ggml_metal_encoder_t enc = ctx->enc;
|
|
4107
|
+
|
|
4108
|
+
WSP_GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
|
|
4109
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
4110
|
+
WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
4111
|
+
|
|
4112
|
+
{
|
|
4113
|
+
wsp_ggml_metal_kargs_memset args = { /*.val =*/ 0 };
|
|
4114
|
+
|
|
4115
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_memset(lib, op);
|
|
4116
|
+
|
|
4117
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
4118
|
+
wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
4119
|
+
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 1);
|
|
4120
|
+
|
|
4121
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
|
|
4122
|
+
}
|
|
4123
|
+
|
|
4124
|
+
wsp_ggml_metal_op_concurrency_reset(ctx);
|
|
4125
|
+
|
|
4126
|
+
{
|
|
4127
|
+
wsp_ggml_metal_kargs_count_equal args = {
|
|
4128
|
+
/*.ne00 =*/ ne00,
|
|
4129
|
+
/*.ne01 =*/ ne01,
|
|
4130
|
+
/*.ne02 =*/ ne02,
|
|
4131
|
+
/*.ne03 =*/ ne03,
|
|
4132
|
+
/*.nb00 =*/ nb00,
|
|
4133
|
+
/*.nb01 =*/ nb01,
|
|
4134
|
+
/*.nb02 =*/ nb02,
|
|
4135
|
+
/*.nb03 =*/ nb03,
|
|
4136
|
+
/*.nb10 =*/ nb10,
|
|
4137
|
+
/*.nb11 =*/ nb11,
|
|
4138
|
+
/*.nb12 =*/ nb12,
|
|
4139
|
+
/*.nb13 =*/ nb13,
|
|
4140
|
+
};
|
|
4141
|
+
|
|
4142
|
+
auto pipeline = wsp_ggml_metal_library_get_pipeline_count_equal(lib, op);
|
|
4143
|
+
|
|
4144
|
+
const size_t smem = pipeline.smem;
|
|
4145
|
+
|
|
4146
|
+
const int nth = 32*pipeline.nsg;
|
|
4147
|
+
|
|
4148
|
+
WSP_GGML_ASSERT(nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
4149
|
+
|
|
4150
|
+
wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
4151
|
+
wsp_ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
4152
|
+
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
4153
|
+
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
4154
|
+
wsp_ggml_metal_encoder_set_buffer(enc, wsp_ggml_metal_get_buffer_id(op), 3);
|
|
4155
|
+
|
|
4156
|
+
wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
4157
|
+
wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
4158
|
+
}
|
|
4159
|
+
|
|
4160
|
+
return 1;
|
|
4161
|
+
}
|