whisper.rn 0.5.3 → 0.5.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +5 -0
- package/android/src/main/jni.cpp +13 -0
- package/cpp/ggml-alloc.c +78 -26
- package/cpp/ggml-alloc.h +9 -0
- package/cpp/ggml-backend-impl.h +1 -1
- package/cpp/ggml-backend-reg.cpp +19 -3
- package/cpp/ggml-backend.cpp +72 -20
- package/cpp/ggml-backend.h +2 -1
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +1004 -0
- package/cpp/ggml-cpu/arch/x86/repack.cpp +6 -6
- package/cpp/ggml-cpu/arch-fallback.h +50 -2
- package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
- package/cpp/ggml-cpu/ggml-cpu.c +139 -58
- package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/ggml-cpu/ops.cpp +170 -18
- package/cpp/ggml-cpu/ops.h +1 -0
- package/cpp/ggml-cpu/repack.cpp +531 -5
- package/cpp/ggml-cpu/repack.h +14 -0
- package/cpp/ggml-cpu/simd-mappings.h +16 -18
- package/cpp/ggml-cpu/vec.cpp +41 -1
- package/cpp/ggml-cpu/vec.h +241 -138
- package/cpp/ggml-cpu.h +1 -0
- package/cpp/ggml-impl.h +0 -4
- package/cpp/ggml-metal/ggml-metal-context.m +26 -16
- package/cpp/ggml-metal/ggml-metal-device.cpp +452 -371
- package/cpp/ggml-metal/ggml-metal-device.h +87 -65
- package/cpp/ggml-metal/ggml-metal-device.m +263 -104
- package/cpp/ggml-metal/ggml-metal-impl.h +58 -4
- package/cpp/ggml-metal/ggml-metal-ops.cpp +415 -98
- package/cpp/ggml-metal/ggml-metal-ops.h +4 -0
- package/cpp/ggml-metal/ggml-metal.cpp +6 -5
- package/cpp/ggml-metal/ggml-metal.metal +404 -34
- package/cpp/ggml.c +110 -31
- package/cpp/ggml.h +51 -12
- package/cpp/jsi/RNWhisperJSI.cpp +1 -0
- package/cpp/whisper.cpp +17 -4
- package/ios/CMakeLists.txt +21 -1
- package/ios/RNWhisperContext.mm +5 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/jest-mock.js +2 -0
- package/lib/commonjs/jest-mock.js.map +1 -1
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +156 -12
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/jest-mock.js +2 -0
- package/lib/module/jest-mock.js.map +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +155 -12
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +1 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +29 -0
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +7 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +1 -0
- package/src/jest-mock.ts +2 -0
- package/src/realtime-transcription/RealtimeTranscriber.ts +179 -9
- package/src/realtime-transcription/types.ts +9 -0
- package/src/version.json +1 -1
|
@@ -1957,6 +1957,8 @@ WSP_GGML_TABLE_END()
|
|
|
1957
1957
|
#define FC_MUL_MV 600
|
|
1958
1958
|
#define FC_MUL_MM 700
|
|
1959
1959
|
#define FC_ROPE 800
|
|
1960
|
+
#define FC_SSM_CONV 900
|
|
1961
|
+
#define FC_COUNT_EQUAL 1000
|
|
1960
1962
|
|
|
1961
1963
|
// op-specific constants
|
|
1962
1964
|
#define OP_FLASH_ATTN_EXT_NQPTG 8
|
|
@@ -2062,6 +2064,10 @@ typedef struct {
|
|
|
2062
2064
|
float bias;
|
|
2063
2065
|
} wsp_ggml_metal_kargs_scale;
|
|
2064
2066
|
|
|
2067
|
+
typedef struct {
|
|
2068
|
+
float val;
|
|
2069
|
+
} wsp_ggml_metal_kargs_fill;
|
|
2070
|
+
|
|
2065
2071
|
typedef struct {
|
|
2066
2072
|
float min;
|
|
2067
2073
|
float max;
|
|
@@ -2712,14 +2718,38 @@ typedef struct {
|
|
|
2712
2718
|
} wsp_ggml_metal_kargs_leaky_relu;
|
|
2713
2719
|
|
|
2714
2720
|
typedef struct {
|
|
2715
|
-
|
|
2716
|
-
|
|
2717
|
-
|
|
2718
|
-
|
|
2721
|
+
int32_t ne00;
|
|
2722
|
+
int32_t ne01;
|
|
2723
|
+
int32_t ne02;
|
|
2724
|
+
int32_t ne03;
|
|
2719
2725
|
uint64_t nb00;
|
|
2720
2726
|
uint64_t nb01;
|
|
2721
2727
|
uint64_t nb02;
|
|
2722
2728
|
uint64_t nb03;
|
|
2729
|
+
int32_t ne0;
|
|
2730
|
+
int32_t ne1;
|
|
2731
|
+
int32_t ne2;
|
|
2732
|
+
int32_t ne3;
|
|
2733
|
+
uint64_t nb0;
|
|
2734
|
+
uint64_t nb1;
|
|
2735
|
+
uint64_t nb2;
|
|
2736
|
+
uint64_t nb3;
|
|
2737
|
+
} wsp_ggml_metal_kargs_tri;
|
|
2738
|
+
|
|
2739
|
+
typedef struct {
|
|
2740
|
+
int32_t ne00;
|
|
2741
|
+
int32_t ne01;
|
|
2742
|
+
int32_t ne02;
|
|
2743
|
+
int32_t ne03;
|
|
2744
|
+
uint64_t nb00;
|
|
2745
|
+
uint64_t nb01;
|
|
2746
|
+
uint64_t nb02;
|
|
2747
|
+
uint64_t nb03;
|
|
2748
|
+
int32_t ne0;
|
|
2749
|
+
int32_t ne1;
|
|
2750
|
+
int32_t ne2;
|
|
2751
|
+
int32_t ne3;
|
|
2752
|
+
int32_t top_k;
|
|
2723
2753
|
} wsp_ggml_metal_kargs_argsort;
|
|
2724
2754
|
|
|
2725
2755
|
typedef struct {
|
|
@@ -2731,6 +2761,11 @@ typedef struct {
|
|
|
2731
2761
|
uint64_t nb01;
|
|
2732
2762
|
uint64_t nb02;
|
|
2733
2763
|
uint64_t nb03;
|
|
2764
|
+
int32_t ne0;
|
|
2765
|
+
int32_t ne1;
|
|
2766
|
+
int32_t ne2;
|
|
2767
|
+
int32_t ne3;
|
|
2768
|
+
int32_t top_k;
|
|
2734
2769
|
int32_t len;
|
|
2735
2770
|
} wsp_ggml_metal_kargs_argsort_merge;
|
|
2736
2771
|
|
|
@@ -2740,6 +2775,25 @@ typedef struct {
|
|
|
2740
2775
|
float step;
|
|
2741
2776
|
} wsp_ggml_metal_kargs_arange;
|
|
2742
2777
|
|
|
2778
|
+
typedef struct {
|
|
2779
|
+
int64_t val;
|
|
2780
|
+
} wsp_ggml_metal_kargs_memset;
|
|
2781
|
+
|
|
2782
|
+
typedef struct {
|
|
2783
|
+
int32_t ne00;
|
|
2784
|
+
int32_t ne01;
|
|
2785
|
+
int32_t ne02;
|
|
2786
|
+
int32_t ne03;
|
|
2787
|
+
uint64_t nb00;
|
|
2788
|
+
uint64_t nb01;
|
|
2789
|
+
uint64_t nb02;
|
|
2790
|
+
uint64_t nb03;
|
|
2791
|
+
uint64_t nb10;
|
|
2792
|
+
uint64_t nb11;
|
|
2793
|
+
uint64_t nb12;
|
|
2794
|
+
uint64_t nb13;
|
|
2795
|
+
} wsp_ggml_metal_kargs_count_equal;
|
|
2796
|
+
|
|
2743
2797
|
typedef struct {
|
|
2744
2798
|
int32_t k0;
|
|
2745
2799
|
int32_t k1;
|
|
@@ -4011,6 +4065,22 @@ kernel void kernel_scale_f32_4(
|
|
|
4011
4065
|
dst[tpig] = src0[tpig] * args.scale + args.bias;
|
|
4012
4066
|
}
|
|
4013
4067
|
|
|
4068
|
+
kernel void kernel_fill_f32(
|
|
4069
|
+
constant wsp_ggml_metal_kargs_fill & args,
|
|
4070
|
+
device const float * src0,
|
|
4071
|
+
device float * dst,
|
|
4072
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
4073
|
+
dst[tpig] = args.val;
|
|
4074
|
+
}
|
|
4075
|
+
|
|
4076
|
+
kernel void kernel_fill_f32_4(
|
|
4077
|
+
constant wsp_ggml_metal_kargs_fill & args,
|
|
4078
|
+
device const float4 * src0,
|
|
4079
|
+
device float4 * dst,
|
|
4080
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
4081
|
+
dst[tpig] = args.val;
|
|
4082
|
+
}
|
|
4083
|
+
|
|
4014
4084
|
kernel void kernel_clamp_f32(
|
|
4015
4085
|
constant wsp_ggml_metal_kargs_clamp & args,
|
|
4016
4086
|
device const float * src0,
|
|
@@ -4357,6 +4427,36 @@ kernel void kernel_exp_f32_4(
|
|
|
4357
4427
|
dst[tpig] = exp(src0[tpig]);
|
|
4358
4428
|
}
|
|
4359
4429
|
|
|
4430
|
+
kernel void kernel_softplus_f32(
|
|
4431
|
+
device const float * src0,
|
|
4432
|
+
device float * dst,
|
|
4433
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
4434
|
+
device const float & x = src0[tpig];
|
|
4435
|
+
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
|
4436
|
+
}
|
|
4437
|
+
|
|
4438
|
+
kernel void kernel_softplus_f32_4(
|
|
4439
|
+
device const float4 * src0,
|
|
4440
|
+
device float4 * dst,
|
|
4441
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
4442
|
+
device const float4 & x = src0[tpig];
|
|
4443
|
+
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
|
4444
|
+
}
|
|
4445
|
+
|
|
4446
|
+
kernel void kernel_expm1_f32(
|
|
4447
|
+
device const float * src0,
|
|
4448
|
+
device float * dst,
|
|
4449
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
4450
|
+
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
|
4451
|
+
}
|
|
4452
|
+
|
|
4453
|
+
kernel void kernel_expm1_f32_4(
|
|
4454
|
+
device const float4 * src0,
|
|
4455
|
+
device float4 * dst,
|
|
4456
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
4457
|
+
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
|
4458
|
+
}
|
|
4459
|
+
|
|
4360
4460
|
kernel void kernel_reglu_f32(
|
|
4361
4461
|
constant wsp_ggml_metal_kargs_glu & args,
|
|
4362
4462
|
device const char * src0,
|
|
@@ -4506,6 +4606,7 @@ kernel void kernel_op_sum_f32(
|
|
|
4506
4606
|
return;
|
|
4507
4607
|
}
|
|
4508
4608
|
|
|
4609
|
+
// TODO: become function constant
|
|
4509
4610
|
const uint nsg = (ntg.x + 31) / 32;
|
|
4510
4611
|
|
|
4511
4612
|
float sumf = 0;
|
|
@@ -4705,6 +4806,75 @@ typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
|
|
|
4705
4806
|
|
|
4706
4807
|
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
|
|
4707
4808
|
|
|
4809
|
+
|
|
4810
|
+
template<uint32_t ttype>
|
|
4811
|
+
bool _wsp_ggml_vec_tri_cmp(const int i, const int r);
|
|
4812
|
+
|
|
4813
|
+
template<>
|
|
4814
|
+
bool _wsp_ggml_vec_tri_cmp</* WSP_GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
|
|
4815
|
+
return i < r;
|
|
4816
|
+
}
|
|
4817
|
+
|
|
4818
|
+
template<>
|
|
4819
|
+
bool _wsp_ggml_vec_tri_cmp</* WSP_GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
|
|
4820
|
+
return i <= r;
|
|
4821
|
+
}
|
|
4822
|
+
|
|
4823
|
+
template<>
|
|
4824
|
+
bool _wsp_ggml_vec_tri_cmp</* WSP_GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
|
|
4825
|
+
return i > r;
|
|
4826
|
+
}
|
|
4827
|
+
|
|
4828
|
+
template<>
|
|
4829
|
+
bool _wsp_ggml_vec_tri_cmp</* WSP_GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
|
|
4830
|
+
return i >= r;
|
|
4831
|
+
}
|
|
4832
|
+
|
|
4833
|
+
template<typename T, int ttype>
|
|
4834
|
+
kernel void kernel_tri(
|
|
4835
|
+
constant wsp_ggml_metal_kargs_tri & args,
|
|
4836
|
+
device const char * src0,
|
|
4837
|
+
device const char * dst,
|
|
4838
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4839
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
4840
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
4841
|
+
const int i3 = tgpig.z;
|
|
4842
|
+
const int i2 = tgpig.y;
|
|
4843
|
+
const int i1 = tgpig.x;
|
|
4844
|
+
|
|
4845
|
+
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
|
4846
|
+
return;
|
|
4847
|
+
}
|
|
4848
|
+
|
|
4849
|
+
device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
|
4850
|
+
device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
|
4851
|
+
|
|
4852
|
+
// Each thread is a single element of the row if ne00 < max threads per
|
|
4853
|
+
// threadgroup, so this will loop once for each index that this thread is
|
|
4854
|
+
// responsible for
|
|
4855
|
+
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
|
4856
|
+
// Use the comparison as a mask for branchless
|
|
4857
|
+
dst_row[i0] = static_cast<T>(_wsp_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
|
|
4858
|
+
}
|
|
4859
|
+
}
|
|
4860
|
+
|
|
4861
|
+
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
|
|
4862
|
+
|
|
4863
|
+
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
|
|
4864
|
+
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
|
|
4865
|
+
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
|
|
4866
|
+
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
|
|
4867
|
+
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
|
|
4868
|
+
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
|
|
4869
|
+
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
|
|
4870
|
+
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
|
|
4871
|
+
#if defined(WSP_GGML_METAL_HAS_BF16)
|
|
4872
|
+
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
|
|
4873
|
+
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
|
|
4874
|
+
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
|
|
4875
|
+
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
|
|
4876
|
+
#endif
|
|
4877
|
+
|
|
4708
4878
|
template<typename T>
|
|
4709
4879
|
kernel void kernel_soft_max(
|
|
4710
4880
|
constant wsp_ggml_metal_kargs_soft_max & args,
|
|
@@ -4990,7 +5160,102 @@ kernel void kernel_ssm_conv_f32_f32_4(
|
|
|
4990
5160
|
x[0] = sumf;
|
|
4991
5161
|
}
|
|
4992
5162
|
|
|
5163
|
+
constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
|
|
5164
|
+
|
|
5165
|
+
// Batched version: each threadgroup processes multiple tokens for better efficiency
|
|
5166
|
+
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
|
|
5167
|
+
kernel void kernel_ssm_conv_f32_f32_batched(
|
|
5168
|
+
constant wsp_ggml_metal_kargs_ssm_conv & args,
|
|
5169
|
+
device const void * src0,
|
|
5170
|
+
device const void * src1,
|
|
5171
|
+
device float * dst,
|
|
5172
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5173
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
5174
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
5175
|
+
// tgpig.x = row index (ir)
|
|
5176
|
+
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
|
|
5177
|
+
// tgpig.z = sequence index (i3)
|
|
5178
|
+
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
|
|
5179
|
+
const short BATCH_SIZE = FC_ssm_conv_bs;
|
|
5180
|
+
|
|
5181
|
+
const int64_t ir = tgpig.x;
|
|
5182
|
+
const int64_t i2_base = tgpig.y * BATCH_SIZE;
|
|
5183
|
+
const int64_t i3 = tgpig.z;
|
|
5184
|
+
const int64_t i2_off = tpitg.x;
|
|
5185
|
+
const int64_t i2 = i2_base + i2_off;
|
|
5186
|
+
|
|
5187
|
+
const int64_t nc = args.ne10; // conv kernel size (typically 4)
|
|
5188
|
+
const int64_t n_t = args.ne1; // number of tokens
|
|
5189
|
+
|
|
5190
|
+
// Bounds check for partial batches at the end
|
|
5191
|
+
if (i2 >= n_t) {
|
|
5192
|
+
return;
|
|
5193
|
+
}
|
|
5194
|
+
|
|
5195
|
+
// Load conv weights (shared across all tokens for this row)
|
|
5196
|
+
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
|
|
5197
|
+
|
|
5198
|
+
// Load source for this specific token
|
|
5199
|
+
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
|
5200
|
+
|
|
5201
|
+
// Output location for this token
|
|
5202
|
+
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
|
5203
|
+
|
|
5204
|
+
float sumf = 0.0f;
|
|
5205
|
+
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
|
5206
|
+
sumf += s[i0] * c[i0];
|
|
5207
|
+
}
|
|
5208
|
+
|
|
5209
|
+
x[0] = sumf;
|
|
5210
|
+
}
|
|
5211
|
+
|
|
5212
|
+
kernel void kernel_ssm_conv_f32_f32_batched_4(
|
|
5213
|
+
constant wsp_ggml_metal_kargs_ssm_conv & args,
|
|
5214
|
+
device const void * src0,
|
|
5215
|
+
device const void * src1,
|
|
5216
|
+
device float * dst,
|
|
5217
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5218
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
5219
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
5220
|
+
// tgpig.x = row index (ir)
|
|
5221
|
+
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
|
|
5222
|
+
// tgpig.z = sequence index (i3)
|
|
5223
|
+
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
|
|
5224
|
+
const short BATCH_SIZE = FC_ssm_conv_bs;
|
|
5225
|
+
|
|
5226
|
+
const int64_t ir = tgpig.x;
|
|
5227
|
+
const int64_t i2_base = tgpig.y * BATCH_SIZE;
|
|
5228
|
+
const int64_t i3 = tgpig.z;
|
|
5229
|
+
const int64_t i2_off = tpitg.x;
|
|
5230
|
+
const int64_t i2 = i2_base + i2_off;
|
|
5231
|
+
|
|
5232
|
+
const int64_t nc = args.ne10; // conv kernel size (typically 4)
|
|
5233
|
+
const int64_t n_t = args.ne1; // number of tokens
|
|
5234
|
+
|
|
5235
|
+
// Bounds check for partial batches at the end
|
|
5236
|
+
if (i2 >= n_t) {
|
|
5237
|
+
return;
|
|
5238
|
+
}
|
|
5239
|
+
|
|
5240
|
+
// Load conv weights (shared across all tokens for this row)
|
|
5241
|
+
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
|
|
5242
|
+
|
|
5243
|
+
// Load source for this specific token
|
|
5244
|
+
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
|
5245
|
+
|
|
5246
|
+
// Output location for this token
|
|
5247
|
+
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
|
5248
|
+
|
|
5249
|
+
float sumf = 0.0f;
|
|
5250
|
+
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
|
|
5251
|
+
sumf += dot(s[i0], c[i0]);
|
|
5252
|
+
}
|
|
5253
|
+
|
|
5254
|
+
x[0] = sumf;
|
|
5255
|
+
}
|
|
5256
|
+
|
|
4993
5257
|
// ref: ggml.c:wsp_ggml_compute_forward_ssm_scan_f32, Mamba-2 part
|
|
5258
|
+
// Optimized version: reduces redundant memory loads by having one thread load shared values
|
|
4994
5259
|
kernel void kernel_ssm_scan_f32(
|
|
4995
5260
|
constant wsp_ggml_metal_kargs_ssm_scan & args,
|
|
4996
5261
|
device const void * src0,
|
|
@@ -5010,7 +5275,15 @@ kernel void kernel_ssm_scan_f32(
|
|
|
5010
5275
|
uint3 tgpg[[threadgroups_per_grid]]) {
|
|
5011
5276
|
constexpr short NW = N_SIMDWIDTH;
|
|
5012
5277
|
|
|
5013
|
-
|
|
5278
|
+
// Shared memory layout:
|
|
5279
|
+
// [0..sgptg*NW-1]: partial sums for reduction (existing)
|
|
5280
|
+
// [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
|
|
5281
|
+
// [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
|
|
5282
|
+
threadgroup float * shared_sums = shared;
|
|
5283
|
+
threadgroup float * shared_x_dt = shared + sgptg * NW;
|
|
5284
|
+
threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
|
|
5285
|
+
|
|
5286
|
+
shared_sums[tpitg.x] = 0.0f;
|
|
5014
5287
|
|
|
5015
5288
|
const int32_t i0 = tpitg.x;
|
|
5016
5289
|
const int32_t i1 = tgpig.x;
|
|
@@ -5050,32 +5323,47 @@ kernel void kernel_ssm_scan_f32(
|
|
|
5050
5323
|
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
|
|
5051
5324
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
5052
5325
|
|
|
5053
|
-
|
|
5054
|
-
|
|
5326
|
+
// Pre-compute x_dt and dA for this batch of tokens
|
|
5327
|
+
// Only first sgptg threads do the loads and expensive math
|
|
5328
|
+
if (i0 < sgptg && i2 + i0 < n_t) {
|
|
5329
|
+
// ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
|
|
5330
|
+
device const float * x_t = x + i0 * args.ns12;
|
|
5331
|
+
device const float * dt_t = dt + i0 * args.ns21;
|
|
5332
|
+
|
|
5333
|
+
const float dt0 = dt_t[0];
|
|
5055
5334
|
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
|
|
5056
|
-
|
|
5057
|
-
|
|
5335
|
+
shared_x_dt[i0] = x_t[0] * dtsp;
|
|
5336
|
+
shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
|
|
5337
|
+
}
|
|
5338
|
+
|
|
5339
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
5340
|
+
|
|
5341
|
+
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
|
|
5342
|
+
const float x_dt = shared_x_dt[t];
|
|
5343
|
+
const float dA = exp(shared_dA[t] * A0);
|
|
5058
5344
|
|
|
5059
5345
|
s = (s0 * dA) + (B[i0] * x_dt);
|
|
5060
5346
|
|
|
5061
5347
|
const float sumf = simd_sum(s * C[i0]);
|
|
5062
5348
|
|
|
5063
5349
|
if (tiisg == 0) {
|
|
5064
|
-
|
|
5350
|
+
shared_sums[t*NW + sgitg] = sumf;
|
|
5065
5351
|
}
|
|
5066
5352
|
|
|
5067
5353
|
// recurse
|
|
5068
5354
|
s0 = s;
|
|
5069
5355
|
|
|
5070
|
-
x += args.ns12;
|
|
5071
|
-
dt += args.ns21;
|
|
5072
5356
|
B += args.ns42;
|
|
5073
5357
|
C += args.ns52;
|
|
5074
5358
|
}
|
|
5075
5359
|
|
|
5360
|
+
// Advance pointers for next batch
|
|
5361
|
+
x += sgptg * args.ns12;
|
|
5362
|
+
dt += sgptg * args.ns21;
|
|
5363
|
+
|
|
5076
5364
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
5077
5365
|
|
|
5078
|
-
const float sumf = simd_sum(
|
|
5366
|
+
const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
|
|
5079
5367
|
|
|
5080
5368
|
if (tiisg == 0 && i2 + sgitg < n_t) {
|
|
5081
5369
|
y[sgitg*nh*nr] = sumf;
|
|
@@ -7432,11 +7720,12 @@ kernel void kernel_argsort_f32_i32(
|
|
|
7432
7720
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
7433
7721
|
// bitonic sort
|
|
7434
7722
|
const int col = tpitg[0];
|
|
7723
|
+
const int ib = tgpig[0] / args.ne01;
|
|
7435
7724
|
|
|
7436
|
-
const int i00 =
|
|
7437
|
-
const int i01 =
|
|
7438
|
-
const int i02 =
|
|
7439
|
-
const int i03 =
|
|
7725
|
+
const int i00 = ib*ntg.x;
|
|
7726
|
+
const int i01 = tgpig[0] % args.ne01;
|
|
7727
|
+
const int i02 = tgpig[1];
|
|
7728
|
+
const int i03 = tgpig[2];
|
|
7440
7729
|
|
|
7441
7730
|
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
|
|
7442
7731
|
|
|
@@ -7472,9 +7761,11 @@ kernel void kernel_argsort_f32_i32(
|
|
|
7472
7761
|
}
|
|
7473
7762
|
}
|
|
7474
7763
|
|
|
7764
|
+
const int64_t i0 = ib*args.top_k;
|
|
7765
|
+
|
|
7475
7766
|
// copy the result to dst without the padding
|
|
7476
|
-
if (
|
|
7477
|
-
dst +=
|
|
7767
|
+
if (i0 + col < args.ne0 && col < args.top_k) {
|
|
7768
|
+
dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
|
|
7478
7769
|
|
|
7479
7770
|
dst[col] = shmem_i32[col];
|
|
7480
7771
|
}
|
|
@@ -7509,22 +7800,22 @@ kernel void kernel_argsort_merge_f32_i32(
|
|
|
7509
7800
|
|
|
7510
7801
|
const int start = im * (2 * args.len);
|
|
7511
7802
|
|
|
7512
|
-
const int len0 = MIN(args.len, MAX(0, args.
|
|
7513
|
-
const int len1 = MIN(args.len, MAX(0, args.
|
|
7803
|
+
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
|
|
7804
|
+
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
|
|
7514
7805
|
|
|
7515
7806
|
const int total = len0 + len1;
|
|
7516
7807
|
|
|
7517
7808
|
device const int32_t * tmp0 = tmp + start
|
|
7518
|
-
+ i01*args.
|
|
7519
|
-
+ i02*args.
|
|
7520
|
-
+ i03*args.
|
|
7809
|
+
+ i01*args.ne0
|
|
7810
|
+
+ i02*args.ne0*args.ne01
|
|
7811
|
+
+ i03*args.ne0*args.ne01*args.ne02;
|
|
7521
7812
|
|
|
7522
7813
|
device const int32_t * tmp1 = tmp0 + args.len;
|
|
7523
7814
|
|
|
7524
7815
|
dst += start
|
|
7525
|
-
+ i01*args.
|
|
7526
|
-
+ i02*args.
|
|
7527
|
-
+ i03*args.
|
|
7816
|
+
+ i01*args.top_k
|
|
7817
|
+
+ i02*args.top_k*args.ne01
|
|
7818
|
+
+ i03*args.top_k*args.ne01*args.ne02;
|
|
7528
7819
|
|
|
7529
7820
|
device const float * src0_row = (device const float *)(src0
|
|
7530
7821
|
+ args.nb01*i01
|
|
@@ -7538,7 +7829,11 @@ kernel void kernel_argsort_merge_f32_i32(
|
|
|
7538
7829
|
const int chunk = (total + ntg.x - 1) / ntg.x;
|
|
7539
7830
|
|
|
7540
7831
|
const int k0 = tpitg.x * chunk;
|
|
7541
|
-
const int k1 =
|
|
7832
|
+
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
|
|
7833
|
+
|
|
7834
|
+
if (k0 >= args.top_k) {
|
|
7835
|
+
return;
|
|
7836
|
+
}
|
|
7542
7837
|
|
|
7543
7838
|
if (k0 >= total) {
|
|
7544
7839
|
return;
|
|
@@ -8512,6 +8807,7 @@ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_
|
|
|
8512
8807
|
|
|
8513
8808
|
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, wsp_dewsp_quantize_f32, float4x4, 1, wsp_dewsp_quantize_f32, 32, 32>;
|
|
8514
8809
|
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, wsp_dewsp_quantize_f32, float4x4, 1, wsp_dewsp_quantize_f32, 40, 40>;
|
|
8810
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, wsp_dewsp_quantize_f32, float4x4, 1, wsp_dewsp_quantize_f32, 48, 48>;
|
|
8515
8811
|
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, wsp_dewsp_quantize_f32, float4x4, 1, wsp_dewsp_quantize_f32, 64, 64>;
|
|
8516
8812
|
template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, wsp_dewsp_quantize_f32, float4x4, 1, wsp_dewsp_quantize_f32, 72, 72>;
|
|
8517
8813
|
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, wsp_dewsp_quantize_f32, float4x4, 1, wsp_dewsp_quantize_f32, 80, 80>;
|
|
@@ -8525,6 +8821,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at
|
|
|
8525
8821
|
|
|
8526
8822
|
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_f16, half4x4, 1, wsp_dewsp_quantize_f16, 32, 32>;
|
|
8527
8823
|
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_f16, half4x4, 1, wsp_dewsp_quantize_f16, 40, 40>;
|
|
8824
|
+
template [[host_name("kernel_flash_attn_ext_f16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_f16, half4x4, 1, wsp_dewsp_quantize_f16, 48, 48>;
|
|
8528
8825
|
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_f16, half4x4, 1, wsp_dewsp_quantize_f16, 64, 64>;
|
|
8529
8826
|
template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_f16, half4x4, 1, wsp_dewsp_quantize_f16, 72, 72>;
|
|
8530
8827
|
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_f16, half4x4, 1, wsp_dewsp_quantize_f16, 80, 80>;
|
|
@@ -8539,6 +8836,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at
|
|
|
8539
8836
|
#if defined(WSP_GGML_METAL_HAS_BF16)
|
|
8540
8837
|
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat4x4, 1, wsp_dewsp_quantize_bf16, 32, 32>;
|
|
8541
8838
|
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat4x4, 1, wsp_dewsp_quantize_bf16, 40, 40>;
|
|
8839
|
+
template [[host_name("kernel_flash_attn_ext_bf16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat4x4, 1, wsp_dewsp_quantize_bf16, 48, 48>;
|
|
8542
8840
|
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat4x4, 1, wsp_dewsp_quantize_bf16, 64, 64>;
|
|
8543
8841
|
template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat4x4, 1, wsp_dewsp_quantize_bf16, 72, 72>;
|
|
8544
8842
|
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat4x4, 1, wsp_dewsp_quantize_bf16, 80, 80>;
|
|
@@ -8553,6 +8851,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at
|
|
|
8553
8851
|
|
|
8554
8852
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, wsp_dewsp_quantize_q4_0, block_q4_0, 2, wsp_dewsp_quantize_q4_0, 32, 32>;
|
|
8555
8853
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, wsp_dewsp_quantize_q4_0, block_q4_0, 2, wsp_dewsp_quantize_q4_0, 40, 40>;
|
|
8854
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, wsp_dewsp_quantize_q4_0, block_q4_0, 2, wsp_dewsp_quantize_q4_0, 48, 48>;
|
|
8556
8855
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, wsp_dewsp_quantize_q4_0, block_q4_0, 2, wsp_dewsp_quantize_q4_0, 64, 64>;
|
|
8557
8856
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, wsp_dewsp_quantize_q4_0, block_q4_0, 2, wsp_dewsp_quantize_q4_0, 72, 72>;
|
|
8558
8857
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, wsp_dewsp_quantize_q4_0, block_q4_0, 2, wsp_dewsp_quantize_q4_0, 80, 80>;
|
|
@@ -8566,6 +8865,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at
|
|
|
8566
8865
|
|
|
8567
8866
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, wsp_dewsp_quantize_q4_1, block_q4_1, 2, wsp_dewsp_quantize_q4_1, 32, 32>;
|
|
8568
8867
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, wsp_dewsp_quantize_q4_1, block_q4_1, 2, wsp_dewsp_quantize_q4_1, 40, 40>;
|
|
8868
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, wsp_dewsp_quantize_q4_1, block_q4_1, 2, wsp_dewsp_quantize_q4_1, 48, 48>;
|
|
8569
8869
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, wsp_dewsp_quantize_q4_1, block_q4_1, 2, wsp_dewsp_quantize_q4_1, 64, 64>;
|
|
8570
8870
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, wsp_dewsp_quantize_q4_1, block_q4_1, 2, wsp_dewsp_quantize_q4_1, 72, 72>;
|
|
8571
8871
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, wsp_dewsp_quantize_q4_1, block_q4_1, 2, wsp_dewsp_quantize_q4_1, 80, 80>;
|
|
@@ -8579,6 +8879,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at
|
|
|
8579
8879
|
|
|
8580
8880
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, wsp_dewsp_quantize_q5_0, block_q5_0, 2, wsp_dewsp_quantize_q5_0, 32, 32>;
|
|
8581
8881
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, wsp_dewsp_quantize_q5_0, block_q5_0, 2, wsp_dewsp_quantize_q5_0, 40, 40>;
|
|
8882
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, wsp_dewsp_quantize_q5_0, block_q5_0, 2, wsp_dewsp_quantize_q5_0, 48, 48>;
|
|
8582
8883
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, wsp_dewsp_quantize_q5_0, block_q5_0, 2, wsp_dewsp_quantize_q5_0, 64, 64>;
|
|
8583
8884
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, wsp_dewsp_quantize_q5_0, block_q5_0, 2, wsp_dewsp_quantize_q5_0, 72, 72>;
|
|
8584
8885
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, wsp_dewsp_quantize_q5_0, block_q5_0, 2, wsp_dewsp_quantize_q5_0, 80, 80>;
|
|
@@ -8592,6 +8893,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at
|
|
|
8592
8893
|
|
|
8593
8894
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, wsp_dewsp_quantize_q5_1, block_q5_1, 2, wsp_dewsp_quantize_q5_1, 32, 32>;
|
|
8594
8895
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, wsp_dewsp_quantize_q5_1, block_q5_1, 2, wsp_dewsp_quantize_q5_1, 40, 40>;
|
|
8896
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, wsp_dewsp_quantize_q5_1, block_q5_1, 2, wsp_dewsp_quantize_q5_1, 48, 48>;
|
|
8595
8897
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, wsp_dewsp_quantize_q5_1, block_q5_1, 2, wsp_dewsp_quantize_q5_1, 64, 64>;
|
|
8596
8898
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, wsp_dewsp_quantize_q5_1, block_q5_1, 2, wsp_dewsp_quantize_q5_1, 72, 72>;
|
|
8597
8899
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, wsp_dewsp_quantize_q5_1, block_q5_1, 2, wsp_dewsp_quantize_q5_1, 80, 80>;
|
|
@@ -8605,6 +8907,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at
|
|
|
8605
8907
|
|
|
8606
8908
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, wsp_dewsp_quantize_q8_0, block_q8_0, 2, wsp_dewsp_quantize_q8_0, 32, 32>;
|
|
8607
8909
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, wsp_dewsp_quantize_q8_0, block_q8_0, 2, wsp_dewsp_quantize_q8_0, 40, 40>;
|
|
8910
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, wsp_dewsp_quantize_q8_0, block_q8_0, 2, wsp_dewsp_quantize_q8_0, 48, 48>;
|
|
8608
8911
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, wsp_dewsp_quantize_q8_0, block_q8_0, 2, wsp_dewsp_quantize_q8_0, 64, 64>;
|
|
8609
8912
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, wsp_dewsp_quantize_q8_0, block_q8_0, 2, wsp_dewsp_quantize_q8_0, 72, 72>;
|
|
8610
8913
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, wsp_dewsp_quantize_q8_0, block_q8_0, 2, wsp_dewsp_quantize_q8_0, 80, 80>;
|
|
@@ -11661,6 +11964,7 @@ typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
|
|
|
11661
11964
|
template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
|
|
11662
11965
|
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
|
|
11663
11966
|
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
|
|
11967
|
+
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
|
|
11664
11968
|
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
|
|
11665
11969
|
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
|
|
11666
11970
|
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
|
|
@@ -12071,9 +12375,6 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
|
|
|
12071
12375
|
|
|
12072
12376
|
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, wsp_dewsp_quantize_f32, float, float4x4, half, half2x4>;
|
|
12073
12377
|
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, wsp_dewsp_quantize_f16, half, half4x4, half, half2x4>;
|
|
12074
|
-
#if defined(WSP_GGML_METAL_HAS_BF16)
|
|
12075
|
-
template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat, bfloat4x4, half, half2x4>;
|
|
12076
|
-
#endif
|
|
12077
12378
|
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, wsp_dewsp_quantize_q4_0, float, float4x4, half, half2x4>;
|
|
12078
12379
|
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, wsp_dewsp_quantize_q4_1, float, float4x4, half, half2x4>;
|
|
12079
12380
|
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, wsp_dewsp_quantize_q5_0, float, float4x4, half, half2x4>;
|
|
@@ -12129,9 +12430,6 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
|
|
|
12129
12430
|
|
|
12130
12431
|
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, wsp_dewsp_quantize_f32, float, float4x4, half, half2x4>;
|
|
12131
12432
|
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, wsp_dewsp_quantize_f16, half, half4x4, half, half2x4>;
|
|
12132
|
-
#if defined(WSP_GGML_METAL_HAS_BF16)
|
|
12133
|
-
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat, bfloat4x4, half, half2x4>;
|
|
12134
|
-
#endif
|
|
12135
12433
|
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, wsp_dewsp_quantize_q4_0, float, float4x4, half, half2x4>;
|
|
12136
12434
|
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, wsp_dewsp_quantize_q4_1, float, float4x4, half, half2x4>;
|
|
12137
12435
|
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, wsp_dewsp_quantize_q5_0, float, float4x4, half, half2x4>;
|
|
@@ -12434,3 +12732,75 @@ kernel void kernel_opt_step_sgd_f32(
|
|
|
12434
12732
|
|
|
12435
12733
|
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
|
|
12436
12734
|
}
|
|
12735
|
+
|
|
12736
|
+
// Constant fill: each thread stores args.val into the element of dst that is
// addressed by its position in the grid.
//
// There is no bounds check in this kernel - presumably the host dispatches
// exactly one thread per destination element (confirm at the call site).
template<typename T>
kernel void kernel_memset(
        constant wsp_ggml_metal_kargs_fill & args,
        device T * dst,
        uint tpig[[thread_position_in_grid]]) {
    // one destination element per thread
    device T & slot = dst[tpig];

    slot = args.val;
}
|
|
12743
|
+
|
|
12744
|
+
typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
|
|
12745
|
+
|
|
12746
|
+
template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
|
|
12747
|
+
|
|
12748
|
+
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
|
|
12749
|
+
|
|
12750
|
+
// Counts the positions at which src0 and src1 hold equal values of type T and
// atomically adds that count to *dst.
//
// Work split: one threadgroup per row, selected by the threadgroup position in
// the grid (x -> i1, y -> i2, z -> i3). The threads of the group stride over
// the ne00 elements of that row. The per-thread counts are reduced inside each
// simdgroup, staged in shmem_i32 (one slot per simdgroup) and reduced a second
// time by simdgroup 0, whose first thread publishes the result.
//
// args      - extents (ne00..ne03) and byte strides (nb00..nb03 for src0,
//             nb10..nb13 for src1)
// src0/src1 - operands, addressed in bytes and reinterpreted as T
// dst       - counter; this kernel only ever adds to it, so it is presumably
//             zeroed by the host before dispatch (confirm at the call site)
// shmem_i32 - threadgroup scratch; slots [0, FC_count_equal_nsg) are read back
//
// NOTE(review): assumes FC_count_equal_nsg equals the number of simdgroups in
// the threadgroup - otherwise slots are read that no simdgroup wrote, or
// partial counts are dropped. Confirm against the host-side pipeline setup.
template<typename T>
kernel void kernel_count_equal(
        constant wsp_ggml_metal_kargs_count_equal & args,
        device const char * src0,
        device const char * src1,
        device atomic_int * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const short NSG = FC_count_equal_nsg;

    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    // the test depends only on the threadgroup position, so either the whole
    // threadgroup leaves here or none of it does - the barrier below is safe
    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    int sum = 0;

    device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
    device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;

    // each thread visits elements tpitg.x, tpitg.x + ntg.x, ... of the row
    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        const T v0 = *(device const T *)(base0 + i0*args.nb00);
        const T v1 = *(device const T *)(base1 + i0*args.nb10);
        sum += (v0 == v1);
    }

    // stage 1: reduce within each simdgroup
    sum = simd_sum(sum);

    if (tiisg == 0) {
        shmem_i32[sgitg] = sum;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // stage 2: simdgroup 0 combines the per-simdgroup partial counts
    if (sgitg == 0) {
        // stay in integer arithmetic: a float accumulator cannot represent
        // every count above 2^24, and stage 1 already reduces ints
        int v = 0;
        if (tpitg.x < NSG) {
            v = shmem_i32[tpitg.x];
        }

        const int total = simd_sum(v);
        if (tpitg.x == 0) {
            atomic_fetch_add_explicit(dst, total, memory_order_relaxed);
        }
    }
}
|
|
12803
|
+
|
|
12804
|
+
typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
|
|
12805
|
+
|
|
12806
|
+
template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;
|
|
Binary file
|