whisper.rn 0.5.2 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/ggml-alloc.c +11 -4
- package/cpp/ggml-backend-reg.cpp +8 -0
- package/cpp/ggml-backend.cpp +0 -2
- package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
- package/cpp/ggml-cpu/ggml-cpu-impl.h +3 -1
- package/cpp/ggml-cpu/ggml-cpu.c +50 -21
- package/cpp/ggml-cpu/ops.cpp +458 -349
- package/cpp/ggml-cpu/ops.h +4 -4
- package/cpp/ggml-cpu/repack.cpp +143 -29
- package/cpp/ggml-cpu/simd-mappings.h +25 -25
- package/cpp/ggml-cpu/unary-ops.cpp +16 -0
- package/cpp/ggml-cpu/unary-ops.h +2 -0
- package/cpp/ggml-cpu/vec.cpp +17 -0
- package/cpp/ggml-cpu/vec.h +10 -0
- package/cpp/ggml-impl.h +17 -1
- package/cpp/ggml-metal/ggml-metal-context.m +5 -6
- package/cpp/ggml-metal/ggml-metal-device.cpp +101 -4
- package/cpp/ggml-metal/ggml-metal-device.h +8 -1
- package/cpp/ggml-metal/ggml-metal-device.m +216 -14
- package/cpp/ggml-metal/ggml-metal-impl.h +90 -2
- package/cpp/ggml-metal/ggml-metal-ops.cpp +346 -85
- package/cpp/ggml-metal/ggml-metal-ops.h +2 -0
- package/cpp/ggml-metal/ggml-metal.cpp +5 -0
- package/cpp/ggml-metal/ggml-metal.metal +12436 -0
- package/cpp/ggml.c +154 -5
- package/cpp/ggml.h +73 -0
- package/cpp/whisper.cpp +5 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/package.json +1 -1
- package/whisper-rn.podspec +1 -1
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
package/cpp/ggml-metal/ggml-metal-ops.cpp

@@ -10,6 +10,8 @@
 
 #include <cassert>
 #include <algorithm>
+#include <limits>
+#include <cmath>
 
 static wsp_ggml_metal_buffer_id wsp_ggml_metal_get_buffer_id(const wsp_ggml_tensor * t) {
     if (!t) {
@@ -310,6 +312,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = wsp_ggml_metal_op_sum_rows(ctx, idx);
             } break;
+        case WSP_GGML_OP_CUMSUM:
+            {
+                n_fuse = wsp_ggml_metal_op_cumsum(ctx, idx);
+            } break;
         case WSP_GGML_OP_SOFT_MAX:
             {
                 n_fuse = wsp_ggml_metal_op_soft_max(ctx, idx);
@@ -364,6 +370,10 @@ static int wsp_ggml_metal_op_encode_impl(wsp_ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = wsp_ggml_metal_op_im2col(ctx, idx);
             } break;
+        case WSP_GGML_OP_CONV_2D:
+            {
+                n_fuse = wsp_ggml_metal_op_conv_2d(ctx, idx);
+            } break;
         case WSP_GGML_OP_CONV_TRANSPOSE_1D:
            {
                n_fuse = wsp_ggml_metal_op_conv_transpose_1d(ctx, idx);
@@ -534,7 +544,7 @@ int wsp_ggml_metal_op_repeat(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_repeat(lib, op->type);
 
@@ -580,7 +590,7 @@ int wsp_ggml_metal_op_acc(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F32);
     WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
@@ -689,7 +699,7 @@ int wsp_ggml_metal_op_scale(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     float scale;
     float bias;
@@ -728,7 +738,7 @@ int wsp_ggml_metal_op_clamp(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     float min;
     float max;
@@ -767,7 +777,7 @@ int wsp_ggml_metal_op_unary(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     int64_t n = wsp_ggml_nelements(op);
 
@@ -797,7 +807,7 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     if (op->src[1]) {
         WSP_GGML_ASSERT(wsp_ggml_are_same_shape(op->src[0], op->src[1]));
@@ -829,18 +839,6 @@ int wsp_ggml_metal_op_glu(wsp_ggml_metal_op_t ctx, int idx) {
 
     const int32_t nth = std::min(wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
 
-    //[encoder setComputePipelineState:pipeline];
-    //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-    //if (src1) {
-    //    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-    //} else {
-    //    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-    //}
-    //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-    //[encoder setBytes:&args length:sizeof(args) atIndex:3];
-
-    //[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
     wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
     wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
     wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -902,7 +900,7 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     wsp_ggml_metal_kargs_sum_rows args = {
         /*.ne00 =*/ ne00,
@@ -936,14 +934,6 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
 
     const size_t smem = wsp_ggml_metal_pipeline_get_smem(pipeline);
 
-    //[encoder setComputePipelineState:pipeline];
-    //[encoder setBytes:&args length:sizeof(args) atIndex:0];
-    //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-    //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-    //[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
-    //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
     wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
     wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
     wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -956,6 +946,149 @@ int wsp_ggml_metal_op_sum_rows(wsp_ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int wsp_ggml_metal_op_cumsum(wsp_ggml_metal_op_t ctx, int idx) {
+    wsp_ggml_tensor * op = ctx->node(idx);
+
+    wsp_ggml_metal_library_t lib = ctx->lib;
+    wsp_ggml_metal_encoder_t enc = ctx->enc;
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
+
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    wsp_ggml_metal_pipeline_t pipeline_blk = wsp_ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
+
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
+        nth *= 2;
+    }
+
+    WSP_GGML_ASSERT(ne00 <= nth*nth);
+
+    const int64_t net0 = (ne00 + nth - 1) / nth;
+    const int64_t net1 = ne01;
+    const int64_t net2 = ne02;
+    const int64_t net3 = ne03;
+
+    const uint64_t nbt0 = sizeof(float);
+    const uint64_t nbt1 = net0*nbt0;
+    const uint64_t nbt2 = net1*nbt1;
+    const uint64_t nbt3 = net2*nbt2;
+
+    const size_t smem = WSP_GGML_PAD(32*sizeof(float), 16);
+
+    wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
+    wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
+
+    wsp_ggml_metal_buffer_id bid_tmp = bid_dst;
+    bid_tmp.offs += wsp_ggml_nbytes(op);
+
+    {
+        wsp_ggml_metal_kargs_cumsum_blk args = {
+            /*.ne00 =*/ ne00,
+            /*.ne01 =*/ ne01,
+            /*.ne02 =*/ ne02,
+            /*.ne03 =*/ ne03,
+            /*.nb00 =*/ nb00,
+            /*.nb01 =*/ nb01,
+            /*.nb02 =*/ nb02,
+            /*.nb03 =*/ nb03,
+            /*.net0 =*/ net0,
+            /*.net1 =*/ net1,
+            /*.net2 =*/ net2,
+            /*.net3 =*/ net3,
+            /*.nbt0 =*/ nbt0,
+            /*.nbt1 =*/ nbt1,
+            /*.nbt2 =*/ nbt2,
+            /*.nbt3 =*/ nbt3,
+            /*.outb =*/ ne00 > nth,
+        };
+
+        wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
+        wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+        wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+        wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
+        wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
+
+        wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+        wsp_ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
+    }
+
+    if (ne00 > nth) {
+        wsp_ggml_metal_op_concurrency_reset(ctx);
+
+        {
+            wsp_ggml_metal_kargs_cumsum_blk args = {
+                /*.ne00 =*/ net0,
+                /*.ne01 =*/ net1,
+                /*.ne02 =*/ net2,
+                /*.ne03 =*/ net3,
+                /*.nb00 =*/ nbt0,
+                /*.nb01 =*/ nbt1,
+                /*.nb02 =*/ nbt2,
+                /*.nb03 =*/ nbt3,
+                /*.net0 =*/ net0,
+                /*.net1 =*/ net1,
+                /*.net2 =*/ net2,
+                /*.net3 =*/ net3,
+                /*.nbt0 =*/ nbt0,
+                /*.nbt1 =*/ nbt1,
+                /*.nbt2 =*/ nbt2,
+                /*.nbt3 =*/ nbt3,
+                /*.outb =*/ false,
+            };
+
+            wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
+            wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+            wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
+            wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
+            wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
+
+            wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+            wsp_ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
+        }
+
+        wsp_ggml_metal_op_concurrency_reset(ctx);
+
+        {
+            wsp_ggml_metal_pipeline_t pipeline_add = wsp_ggml_metal_library_get_pipeline_cumsum_add(lib, op);
+
+            wsp_ggml_metal_kargs_cumsum_add args = {
+                /*.ne00 =*/ ne00,
+                /*.ne01 =*/ ne01,
+                /*.ne02 =*/ ne02,
+                /*.ne03 =*/ ne03,
+                /*.nb00 =*/ nb00,
+                /*.nb01 =*/ nb01,
+                /*.nb02 =*/ nb02,
+                /*.nb03 =*/ nb03,
+                /*.net0 =*/ net0,
+                /*.net1 =*/ net1,
+                /*.net2 =*/ net2,
+                /*.net3 =*/ net3,
+                /*.nbt0 =*/ nbt0,
+                /*.nbt1 =*/ nbt1,
+                /*.nbt2 =*/ nbt2,
+                /*.nbt3 =*/ nbt3,
+            };
+
+            wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_add);
+            wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+            wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
+            wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+
+            wsp_ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
+        }
+    }
+
+    return 1;
+}
+
 int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_tensor * op = ctx->node(idx);
 
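For context on the WSP_GGML_OP_CUMSUM addition above: the host code splits each row into chunks of `nth` elements, runs a block-scan kernel (`cumsum_blk`) that also writes each chunk's total into a scratch region placed right after the destination tensor (`bid_tmp`), and, when a row spans more than one chunk, scans those totals and folds them back in with a second pass (`cumsum_add`). Below is a minimal CPU-side sketch of the same two-pass scheme, for illustration only — plain C++, not the Metal kernels, and the helper name is made up:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Two-pass inclusive prefix sum over one row, mirroring the blk/add split:
// pass 1: scan each block of `nth` elements and record the block's total,
// pass 2: scan the block totals, pass 3: add the scanned totals back.
static void cumsum_two_pass(std::vector<float> & x, int nth) {
    const int n    = (int) x.size();
    const int nblk = (n + nth - 1)/nth;          // net0 in the Metal host code

    std::vector<float> tot(nblk, 0.0f);          // plays the role of bid_tmp

    for (int b = 0; b < nblk; ++b) {             // "cumsum_blk": local scans
        float acc = 0.0f;
        for (int i = b*nth; i < std::min(n, (b + 1)*nth); ++i) {
            acc += x[i];
            x[i] = acc;
        }
        tot[b] = acc;
    }

    for (int b = 1; b < nblk; ++b) {             // scan of the block totals
        tot[b] += tot[b - 1];
    }

    for (int b = 1; b < nblk; ++b) {             // "cumsum_add": add offsets
        for (int i = b*nth; i < std::min(n, (b + 1)*nth); ++i) {
            x[i] += tot[b - 1];
        }
    }
}

int main() {
    std::vector<float> x = {1, 2, 3, 4, 5, 6, 7};
    cumsum_two_pass(x, /*nth =*/ 4);
    for (float v : x) printf("%g ", v);          // prints: 1 3 6 10 15 21 28
    printf("\n");
}
```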
@@ -967,7 +1100,7 @@ int wsp_ggml_metal_op_get_rows(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
 
@@ -1012,7 +1145,7 @@ int wsp_ggml_metal_op_set_rows(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
 
@@ -1076,7 +1209,7 @@ int wsp_ggml_metal_op_soft_max(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     float scale;
     float max_bias;
@@ -1164,7 +1297,7 @@ int wsp_ggml_metal_op_ssm_conv(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     wsp_ggml_metal_kargs_ssm_conv args = {
         /*.ne00 =*/ ne00,
@@ -1219,7 +1352,7 @@ int wsp_ggml_metal_op_ssm_scan(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     const wsp_ggml_tensor * src3 = op->src[3];
     const wsp_ggml_tensor * src4 = op->src[4];
@@ -1305,7 +1438,7 @@ int wsp_ggml_metal_op_rwkv(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     const int64_t B = op->op == WSP_GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
     const int64_t T = op->src[0]->ne[2];
@@ -1346,7 +1479,7 @@ int wsp_ggml_metal_op_cpy(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
 
@@ -1419,7 +1552,7 @@ int wsp_ggml_metal_op_pool_2d(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     const int32_t * opts = op->op_params;
     wsp_ggml_op_pool op_pool = (wsp_ggml_op_pool) opts[0];
@@ -1483,7 +1616,7 @@ int wsp_ggml_metal_op_mul_mat(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     WSP_GGML_ASSERT(ne00 == ne10);
 
@@ -1724,7 +1857,7 @@ int wsp_ggml_metal_op_mul_mat_id(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     // src2 = ids
     WSP_GGML_ASSERT(op->src[2]->type == WSP_GGML_TYPE_I32);
@@ -1970,7 +2103,9 @@ size_t wsp_ggml_metal_op_flash_attn_ext_extra_pad(const wsp_ggml_tensor * op) {
     const bool has_mask = op->src[3] != nullptr;
 
     if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
-
+        // note: always reserve the padding space to avoid graph reallocations
+        //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
+        const bool has_kvpad = true;
 
         if (has_kvpad) {
             res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
@@ -1979,7 +2114,8 @@ size_t wsp_ggml_metal_op_flash_attn_ext_extra_pad(const wsp_ggml_tensor * op) {
                 (has_mask ? wsp_ggml_type_size(WSP_GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
         }
     } else {
-        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
+        //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
+        const bool has_kvpad = true;
 
         if (has_kvpad) {
             res += OP_FLASH_ATTN_EXT_NCPSG*(
@@ -2015,9 +2151,10 @@ size_t wsp_ggml_metal_op_flash_attn_ext_extra_blk(const wsp_ggml_tensor * op) {
     const bool is_vec = wsp_ggml_metal_op_flash_attn_ext_use_vec(op);
 
     // this optimization is not useful for the vector kernels
-
-
-
+    // note: always reserve the blk buffer to avoid graph reallocations
+    //if (is_vec) {
+    //    return res;
+    //}
 
     const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
     const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
@@ -2044,13 +2181,16 @@ size_t wsp_ggml_metal_op_flash_attn_ext_extra_tmp(const wsp_ggml_tensor * op) {
 
     size_t res = 0;
 
-
+    // note: always reserve the temp buffer to avoid graph reallocations
+    //if (wsp_ggml_metal_op_flash_attn_ext_use_vec(op)) {
+    if (true) {
         const int64_t nwg = 32;
+        const int64_t ne01_max = std::min(ne01, 32);
 
         // temp buffer for writing the results from each workgroup
         // - ne20: the size of the Value head
         // - + 2: the S and M values for each intermediate result
-        res += wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(
+        res += wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
     }
 
     return res;
@@ -2179,8 +2319,6 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
         wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
 
         need_sync = true;
-    } else {
-        assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
     }
 
     if (has_mask) {
@@ -2210,8 +2348,6 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
         wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
 
         need_sync = true;
-    } else {
-        assert(wsp_ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
     }
 
     if (need_sync) {
@@ -2351,8 +2487,6 @@ int wsp_ggml_metal_op_flash_attn_ext(wsp_ggml_metal_op_t ctx, int idx) {
         wsp_ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
 
         need_sync = true;
-    } else {
-        assert(wsp_ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
     }
 
     if (need_sync) {
@@ -2683,7 +2817,7 @@ int wsp_ggml_metal_op_l2_norm(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
@@ -2731,7 +2865,7 @@ int wsp_ggml_metal_op_group_norm(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     const int32_t ngrp = ((const int32_t *) op->op_params)[0];
 
@@ -2786,7 +2920,7 @@ int wsp_ggml_metal_op_norm(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
@@ -2922,7 +3056,7 @@ int wsp_ggml_metal_op_rope(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     // make sure we have one or more position id(ne10) per token(ne02)
     WSP_GGML_ASSERT(ne10 % ne02 == 0);
@@ -3016,7 +3150,7 @@ int wsp_ggml_metal_op_im2col(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
     const int32_t s1 = ((const int32_t *)(op->op_params))[1];
@@ -3077,6 +3211,84 @@ int wsp_ggml_metal_op_im2col(wsp_ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int wsp_ggml_metal_op_conv_2d(wsp_ggml_metal_op_t ctx, int idx) {
+    wsp_ggml_tensor * op = ctx->node(idx);
+
+    wsp_ggml_metal_library_t lib = ctx->lib;
+    wsp_ggml_metal_encoder_t enc = ctx->enc;
+
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
+    WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
+
+    const int32_t s0 = ((const int32_t *) op->op_params)[0];
+    const int32_t s1 = ((const int32_t *) op->op_params)[1];
+    const int32_t p0 = ((const int32_t *) op->op_params)[2];
+    const int32_t p1 = ((const int32_t *) op->op_params)[3];
+    const int32_t d0 = ((const int32_t *) op->op_params)[4];
+    const int32_t d1 = ((const int32_t *) op->op_params)[5];
+
+    wsp_ggml_metal_kargs_conv_2d args = {
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.nb10 =*/ nb10,
+        /*.nb11 =*/ nb11,
+        /*.nb12 =*/ nb12,
+        /*.nb13 =*/ nb13,
+        /*.nb0 =*/ nb0,
+        /*.nb1 =*/ nb1,
+        /*.nb2 =*/ nb2,
+        /*.nb3 =*/ nb3,
+        /*.IW =*/ ne10,
+        /*.IH =*/ ne11,
+        /*.KW =*/ ne00,
+        /*.KH =*/ ne01,
+        /*.IC =*/ ne02,
+        /*.OC =*/ ne03,
+        /*.OW =*/ ne0,
+        /*.OH =*/ ne1,
+        /*.N =*/ ne3,
+        /*.s0 =*/ s0,
+        /*.s1 =*/ s1,
+        /*.p0 =*/ p0,
+        /*.p1 =*/ p1,
+        /*.d0 =*/ d0,
+        /*.d1 =*/ d1,
+    };
+
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_conv_2d(lib, op);
+
+    int nth = wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
+    nth = std::min(nth, 256);
+    nth = std::max(nth, 1);
+
+    const uint64_t n_out = wsp_ggml_nelements(op);
+
+    uint64_t tg = (n_out + nth - 1)/nth;
+    tg = std::max<uint64_t>(tg, 1);
+    tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
+
+    wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
+    wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[0]), 1);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op->src[1]), 2);
+    wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 3);
+
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
+
+    return 1;
+}
+
 int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_tensor * op = ctx->node(idx);
 
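A quick note on the new WSP_GGML_OP_CONV_2D host code above: it launches a flat 1-D grid sized to the number of output elements — `tg = ceil(n_out/nth)` threadgroups of at most 256 threads, clamped to `INT_MAX` — which suggests one thread per output element. The sketch below, with made-up tensor sizes, simply works through that launch arithmetic together with the usual convolution output-size formula; it is illustrative, not code from the package:

```cpp
#include <cstdint>
#include <cstdio>

// Output extent of a convolution along one axis (standard formula used by
// ggml to derive the dst shape for a 2-D convolution).
static int64_t conv_out_size(int64_t in, int64_t k, int s, int p, int d) {
    return (in + 2*p - d*(k - 1) - 1)/s + 1;
}

int main() {
    // hypothetical sizes: 3000x80 input, 3x3 kernel, stride 1, pad 1, dilation 1
    const int64_t IW = 3000, IH = 80, KW = 3, KH = 3, OC = 384, N = 1;
    const int     s0 = 1, s1 = 1, p0 = 1, p1 = 1, d0 = 1, d1 = 1;

    const int64_t OW = conv_out_size(IW, KW, s0, p0, d0);
    const int64_t OH = conv_out_size(IH, KH, s1, p1, d1);

    // launch arithmetic mirroring the encoder call above
    const uint64_t n_out = (uint64_t)(OW*OH*OC*N);   // wsp_ggml_nelements(op)
    const int      nth   = 256;                      // capped threads per group
    const uint64_t tg    = (n_out + nth - 1)/nth;    // threadgroup count

    printf("OW=%lld OH=%lld n_out=%llu threadgroups=%llu\n",
           (long long) OW, (long long) OH,
           (unsigned long long) n_out, (unsigned long long) tg);
}
```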
@@ -3088,7 +3300,7 @@ int wsp_ggml_metal_op_conv_transpose_1d(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
 
@@ -3133,7 +3345,7 @@ int wsp_ggml_metal_op_conv_transpose_2d(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
 
@@ -3187,7 +3399,7 @@ int wsp_ggml_metal_op_upscale(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     const float sf0 = (float)ne0/op->src[0]->ne[0];
     const float sf1 = (float)ne1/op->src[0]->ne[1];
@@ -3240,7 +3452,7 @@ int wsp_ggml_metal_op_pad(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     wsp_ggml_metal_kargs_pad args = {
         /*.ne00 =*/ ne00,
@@ -3284,7 +3496,7 @@ int wsp_ggml_metal_op_pad_reflect_1d(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     wsp_ggml_metal_kargs_pad_reflect_1d args = {
         /*.ne00 =*/ ne00,
@@ -3328,7 +3540,7 @@ int wsp_ggml_metal_op_arange(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_metal_encoder_t enc = ctx->enc;
 
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     float start;
     float step;
@@ -3346,12 +3558,6 @@ int wsp_ggml_metal_op_arange(wsp_ggml_metal_op_t ctx, int idx) {
 
     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_arange(lib, op);
 
-    //[encoder setComputePipelineState:pipeline];
-    //[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
-    //[encoder setBytes:&args length:sizeof(args) atIndex:1];
-
-    //[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
     wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
     wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
     wsp_ggml_metal_encoder_set_buffer (enc, wsp_ggml_metal_get_buffer_id(op), 1);
@@ -3370,7 +3576,7 @@ int wsp_ggml_metal_op_timestep_embedding(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     const int dim = op->op_params[0];
     const int max_period = op->op_params[1];
@@ -3404,7 +3610,7 @@ int wsp_ggml_metal_op_argmax(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     wsp_ggml_metal_kargs_argmax args = {
         /*.ne00 = */ ne00,
@@ -3440,38 +3646,93 @@ int wsp_ggml_metal_op_argsort(wsp_ggml_metal_op_t ctx, int idx) {
     wsp_ggml_metal_library_t lib = ctx->lib;
     wsp_ggml_metal_encoder_t enc = ctx->enc;
 
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(op->src[0]));
+
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_argsort(lib, op);
 
     // bitonic sort requires the number of elements to be power of 2
-
-    while (
-
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
     }
 
-
-
-    const int64_t nrows = wsp_ggml_nrows(op->src[0]);
+    const int npr = (ne00 + nth - 1)/nth;
 
     // Metal kernels require the buffer size to be multiple of 16 bytes
     // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
-    const size_t smem = WSP_GGML_PAD(
+    const size_t smem = WSP_GGML_PAD(nth*sizeof(int32_t), 16);
+
+    wsp_ggml_metal_buffer_id bid_src0 = wsp_ggml_metal_get_buffer_id(op->src[0]);
+    wsp_ggml_metal_buffer_id bid_dst = wsp_ggml_metal_get_buffer_id(op);
+
+    wsp_ggml_metal_buffer_id bid_tmp = bid_dst;
+    bid_tmp.offs += wsp_ggml_nbytes(op);
+
+    if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
+        std::swap(bid_dst, bid_tmp);
+    }
 
     wsp_ggml_metal_kargs_argsort args = {
-        /*.
-        /*.
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
     };
 
     wsp_ggml_metal_encoder_set_pipeline(enc, pipeline);
     wsp_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
-    wsp_ggml_metal_encoder_set_buffer (enc,
-    wsp_ggml_metal_encoder_set_buffer (enc,
+    wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+    wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
 
     wsp_ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-    wsp_ggml_metal_encoder_dispatch_threadgroups(enc,
+    wsp_ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
+
+    wsp_ggml_metal_pipeline_t pipeline_merge = wsp_ggml_metal_library_get_pipeline_argsort_merge(lib, op);
+
+    int len = nth;
+
+    while (len < ne00) {
+        wsp_ggml_metal_op_concurrency_reset(ctx);
+
+        wsp_ggml_metal_kargs_argsort_merge args_merge = {
+            .ne00 = ne00,
+            .ne01 = ne01,
+            .ne02 = ne02,
+            .ne03 = ne03,
+            .nb00 = nb00,
+            .nb01 = nb01,
+            .nb02 = nb02,
+            .nb03 = nb03,
+            .len = len,
+        };
+
+        // merges per row
+        const int nm = (ne00 + 2*len - 1) / (2*len);
+
+        const int nth = std::min(512, wsp_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
+
+        wsp_ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+        wsp_ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
+        wsp_ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+        wsp_ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+        wsp_ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
+
+        wsp_ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
+
+        std::swap(bid_dst, bid_tmp);
+
+        len <<= 1;
+    }
 
     return 1;
 }
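The argsort rewrite above replaces the single power-of-two bitonic sort per row with a chunked scheme: each row is split into `npr` runs of `nth` elements, the bitonic kernel sorts each run, and an `argsort_merge` pass then repeatedly merges pairs of runs (doubling `len` each iteration) while ping-ponging between the destination buffer and a scratch region after it; the up-front `ceil(log2(npr)) % 2` swap makes the final pass land in the real destination. Below is a small CPU-side sketch of that sort-then-merge structure, using std::sort/std::merge in place of the Metal kernels (illustrative only, with made-up input data):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Sketch of the chunked argsort: sort fixed-size runs of indices, then
// double the run length each pass, ping-ponging between two buffers.
int main() {
    std::vector<float> vals = {7, 3, 9, 1, 4, 8, 2, 6, 5, 0};
    const int n   = (int) vals.size();
    const int nth = 4;                                  // initial run length

    std::vector<int> dst(n), tmp(n);
    for (int i = 0; i < n; ++i) dst[i] = i;             // argsort indices

    auto less = [&](int a, int b) { return vals[a] < vals[b]; };

    for (int i = 0; i < n; i += nth) {                  // per-run sort pass
        std::sort(dst.begin() + i, dst.begin() + std::min(n, i + nth), less);
    }

    for (int len = nth; len < n; len <<= 1) {           // merge passes
        for (int i = 0; i < n; i += 2*len) {
            const int mid = std::min(n, i + len);
            const int end = std::min(n, i + 2*len);
            std::merge(dst.begin() + i,   dst.begin() + mid,
                       dst.begin() + mid, dst.begin() + end,
                       tmp.begin() + i, less);
        }
        std::swap(dst, tmp);                            // ping-pong buffers
    }

    for (int i : dst) printf("%d ", i);                 // indices in sorted order
    printf("\n");
}
```

Here the buffers are swapped as whole vectors, so the result always ends up in `dst`; the GPU path works with two fixed buffers instead, which is why the host code pre-swaps `bid_dst` and `bid_tmp` based on the parity of the number of merge passes.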
@@ -3485,7 +3746,7 @@ int wsp_ggml_metal_op_leaky_relu(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     float slope;
     memcpy(&slope, op->op_params, sizeof(float));
@@ -3521,7 +3782,7 @@ int wsp_ggml_metal_op_opt_step_adamw(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
 
@@ -3557,7 +3818,7 @@ int wsp_ggml_metal_op_opt_step_sgd(wsp_ggml_metal_op_t ctx, int idx) {
     WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     WSP_GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     WSP_GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
-    WSP_GGML_TENSOR_LOCALS(
+    WSP_GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
     wsp_ggml_metal_pipeline_t pipeline = wsp_ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
 