whisper.rn 0.4.0-rc.5 → 0.4.0-rc.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +5 -5
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +7 -2
- package/android/src/main/jni.cpp +3 -2
- package/cpp/ggml-alloc.h +1 -1
- package/cpp/ggml-metal-whisper.metal +1497 -169
- package/cpp/ggml-metal.m +530 -53
- package/cpp/ggml-quants.c +2 -2
- package/cpp/ggml.c +264 -99
- package/cpp/ggml.h +21 -7
- package/cpp/rn-whisper.cpp +3 -0
- package/cpp/rn-whisper.h +3 -2
- package/ios/RNWhisperContext.mm +10 -6
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/index.d.ts +5 -0
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/index.ts +5 -0
- package/src/version.json +1 -1
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -4
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +0 -8
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
- package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +0 -19
package/cpp/ggml.c
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe"
|
|
1
|
+
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
|
|
2
2
|
#define _USE_MATH_DEFINES // For M_PI on MSVC
|
|
3
3
|
|
|
4
4
|
#include "ggml-impl.h"
|
|
@@ -33,7 +33,7 @@
|
|
|
33
33
|
// we should just be careful :)
|
|
34
34
|
#pragma warning(disable: 4244 4267)
|
|
35
35
|
|
|
36
|
-
// disable POSIX deprecation
|
|
36
|
+
// disable POSIX deprecation warnings
|
|
37
37
|
// these functions are never going away, anyway
|
|
38
38
|
#pragma warning(disable: 4996)
|
|
39
39
|
#endif
|
|
@@ -1395,7 +1395,7 @@ inline static void wsp_ggml_vec_step_f32 (const int n, float * y, const float *
|
|
|
1395
1395
|
inline static void wsp_ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
|
|
1396
1396
|
inline static void wsp_ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
|
|
1397
1397
|
// Element-wise ReLU: y[k] = x[k] when x[k] > 0, otherwise 0, for k in [0, n).
// Any input that fails the "> 0" comparison (zero, negatives, NaN) maps to 0.
inline static void wsp_ggml_vec_relu_f32 (const int n, float * y, const float * x) {
    for (int k = 0; k < n; ++k) {
        const float v = x[k];
        if (v > 0.f) {
            y[k] = v;
        } else {
            y[k] = 0.f;
        }
    }
}
|
|
1398
|
-
inline static void
|
|
1398
|
+
// Element-wise leaky ReLU with slope `ns` on the negative side, for k in [0, n):
//   y[k] = max-part(x[k]) + ns * min-part(x[k])
// where max-part keeps values > 0 and min-part keeps values < 0 (each is 0 otherwise).
inline static void wsp_ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) {
    for (int k = 0; k < n; ++k) {
        const float v        = x[k];
        const float positive = (v > 0.f)  ? v : 0.f;
        const float negative = (v < 0.0f) ? v : 0.f;
        // Same sum-of-parts form as before, so the result for NaN or infinite ns is unchanged.
        y[k] = positive + ns * negative;
    }
}
|
|
1399
1399
|
|
|
1400
1400
|
static const float GELU_COEF_A = 0.044715f;
|
|
1401
1401
|
static const float GELU_QUICK_COEF = -1.702f;
|
|
@@ -1623,7 +1623,9 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
|
|
|
1623
1623
|
"POOL_1D",
|
|
1624
1624
|
"POOL_2D",
|
|
1625
1625
|
"UPSCALE",
|
|
1626
|
+
"PAD",
|
|
1626
1627
|
"ARGSORT",
|
|
1628
|
+
"LEAKY_RELU",
|
|
1627
1629
|
|
|
1628
1630
|
"FLASH_ATTN",
|
|
1629
1631
|
"FLASH_FF",
|
|
@@ -1650,7 +1652,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
|
|
|
1650
1652
|
"CROSS_ENTROPY_LOSS_BACK",
|
|
1651
1653
|
};
|
|
1652
1654
|
|
|
1653
|
-
static_assert(WSP_GGML_OP_COUNT ==
|
|
1655
|
+
static_assert(WSP_GGML_OP_COUNT == 72, "WSP_GGML_OP_COUNT != 72");
|
|
1654
1656
|
|
|
1655
1657
|
static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
|
|
1656
1658
|
"none",
|
|
@@ -1707,7 +1709,9 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
|
|
|
1707
1709
|
"pool_1d(x)",
|
|
1708
1710
|
"pool_2d(x)",
|
|
1709
1711
|
"upscale(x)",
|
|
1712
|
+
"pad(x)",
|
|
1710
1713
|
"argsort(x)",
|
|
1714
|
+
"leaky_relu(x)",
|
|
1711
1715
|
|
|
1712
1716
|
"flash_attn(x)",
|
|
1713
1717
|
"flash_ff(x)",
|
|
@@ -1734,7 +1738,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
|
|
|
1734
1738
|
"cross_entropy_loss_back(x,y)",
|
|
1735
1739
|
};
|
|
1736
1740
|
|
|
1737
|
-
static_assert(WSP_GGML_OP_COUNT ==
|
|
1741
|
+
static_assert(WSP_GGML_OP_COUNT == 72, "WSP_GGML_OP_COUNT != 72");
|
|
1738
1742
|
|
|
1739
1743
|
static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");
|
|
1740
1744
|
|
|
@@ -1750,17 +1754,16 @@ static const char * WSP_GGML_UNARY_OP_NAME[WSP_GGML_UNARY_OP_COUNT] = {
|
|
|
1750
1754
|
"GELU",
|
|
1751
1755
|
"GELU_QUICK",
|
|
1752
1756
|
"SILU",
|
|
1753
|
-
"LEAKY",
|
|
1754
1757
|
};
|
|
1755
1758
|
|
|
1756
|
-
static_assert(WSP_GGML_UNARY_OP_COUNT ==
|
|
1759
|
+
static_assert(WSP_GGML_UNARY_OP_COUNT == 10, "WSP_GGML_UNARY_OP_COUNT != 10");
|
|
1757
1760
|
|
|
1758
1761
|
|
|
1759
1762
|
static_assert(sizeof(struct wsp_ggml_object)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_object size must be a multiple of WSP_GGML_MEM_ALIGN");
|
|
1760
1763
|
static_assert(sizeof(struct wsp_ggml_tensor)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_tensor size must be a multiple of WSP_GGML_MEM_ALIGN");
|
|
1761
1764
|
|
|
1762
1765
|
// WARN:
|
|
1763
|
-
// Mis-
|
|
1766
|
+
// Misconfiguration can lead to problems that are hard to reason about:
|
|
1764
1767
|
// * At best it crashes or talks nonsense.
|
|
1765
1768
|
// * At worst it talks slightly differently, in a way that is hard to perceive.
|
|
1766
1769
|
//
|
|
@@ -3830,12 +3833,25 @@ struct wsp_ggml_tensor * wsp_ggml_relu_inplace(
|
|
|
3830
3833
|
return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_RELU);
|
|
3831
3834
|
}
|
|
3832
3835
|
|
|
3833
|
-
//
|
|
3836
|
+
// wsp_ggml_leaky_relu
|
|
3834
3837
|
|
|
3835
|
-
struct wsp_ggml_tensor *
|
|
3838
|
+
struct wsp_ggml_tensor * wsp_ggml_leaky_relu(
|
|
3836
3839
|
struct wsp_ggml_context * ctx,
|
|
3837
|
-
struct wsp_ggml_tensor * a) {
|
|
3838
|
-
|
|
3840
|
+
struct wsp_ggml_tensor * a, float negative_slope, bool inplace) {
|
|
3841
|
+
bool is_node = false;
|
|
3842
|
+
|
|
3843
|
+
if (!inplace && (a->grad)) {
|
|
3844
|
+
is_node = true;
|
|
3845
|
+
}
|
|
3846
|
+
|
|
3847
|
+
struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
|
|
3848
|
+
wsp_ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
|
|
3849
|
+
|
|
3850
|
+
result->op = WSP_GGML_OP_LEAKY_RELU;
|
|
3851
|
+
result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
|
|
3852
|
+
result->src[0] = a;
|
|
3853
|
+
|
|
3854
|
+
return result;
|
|
3839
3855
|
}
|
|
3840
3856
|
|
|
3841
3857
|
// wsp_ggml_gelu
|
|
@@ -4022,8 +4038,9 @@ static struct wsp_ggml_tensor * wsp_ggml_group_norm_impl(
|
|
|
4022
4038
|
|
|
4023
4039
|
struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
|
|
4024
4040
|
|
|
4025
|
-
result->op = WSP_GGML_OP_GROUP_NORM;
|
|
4026
4041
|
result->op_params[0] = n_groups;
|
|
4042
|
+
|
|
4043
|
+
result->op = WSP_GGML_OP_GROUP_NORM;
|
|
4027
4044
|
result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
|
|
4028
4045
|
result->src[0] = a;
|
|
4029
4046
|
result->src[1] = NULL; // TODO: maybe store epsilon here?
|
|
@@ -4075,17 +4092,18 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat(
|
|
|
4075
4092
|
|
|
4076
4093
|
struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
|
|
4077
4094
|
struct wsp_ggml_context * ctx,
|
|
4078
|
-
struct wsp_ggml_tensor * as[],
|
|
4095
|
+
struct wsp_ggml_tensor * const as[],
|
|
4096
|
+
int n_as,
|
|
4079
4097
|
struct wsp_ggml_tensor * ids,
|
|
4080
4098
|
int id,
|
|
4081
4099
|
struct wsp_ggml_tensor * b) {
|
|
4082
4100
|
|
|
4083
|
-
int64_t n_as = ids->ne[0];
|
|
4084
|
-
|
|
4085
4101
|
WSP_GGML_ASSERT(ids->type == WSP_GGML_TYPE_I32);
|
|
4086
|
-
WSP_GGML_ASSERT(
|
|
4102
|
+
WSP_GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
|
|
4103
|
+
WSP_GGML_ASSERT(ids->ne[1] == b->ne[1]);
|
|
4104
|
+
WSP_GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
|
|
4087
4105
|
WSP_GGML_ASSERT(n_as > 0 && n_as <= WSP_GGML_MAX_SRC - 2);
|
|
4088
|
-
WSP_GGML_ASSERT(id >= 0 && id <
|
|
4106
|
+
WSP_GGML_ASSERT(id >= 0 && id < ids->ne[0]);
|
|
4089
4107
|
|
|
4090
4108
|
bool is_node = false;
|
|
4091
4109
|
|
|
@@ -4097,13 +4115,14 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
|
|
|
4097
4115
|
struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
|
|
4098
4116
|
|
|
4099
4117
|
wsp_ggml_set_op_params_i32(result, 0, id);
|
|
4118
|
+
wsp_ggml_set_op_params_i32(result, 1, n_as);
|
|
4100
4119
|
|
|
4101
4120
|
result->op = WSP_GGML_OP_MUL_MAT_ID;
|
|
4102
4121
|
result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
|
|
4103
4122
|
result->src[0] = ids;
|
|
4104
4123
|
result->src[1] = b;
|
|
4105
4124
|
|
|
4106
|
-
for (
|
|
4125
|
+
for (int i = 0; i < n_as; i++) {
|
|
4107
4126
|
struct wsp_ggml_tensor * a = as[i];
|
|
4108
4127
|
WSP_GGML_ASSERT(wsp_ggml_are_same_shape(as[0], a));
|
|
4109
4128
|
WSP_GGML_ASSERT(wsp_ggml_can_mul_mat(a, b));
|
|
@@ -4731,7 +4750,9 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows(
|
|
|
4731
4750
|
struct wsp_ggml_context * ctx,
|
|
4732
4751
|
struct wsp_ggml_tensor * a,
|
|
4733
4752
|
struct wsp_ggml_tensor * b) {
|
|
4734
|
-
WSP_GGML_ASSERT(
|
|
4753
|
+
WSP_GGML_ASSERT(a->ne[2] == b->ne[1]);
|
|
4754
|
+
WSP_GGML_ASSERT(b->ne[3] == 1);
|
|
4755
|
+
WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32);
|
|
4735
4756
|
|
|
4736
4757
|
bool is_node = false;
|
|
4737
4758
|
|
|
@@ -4741,7 +4762,7 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows(
|
|
|
4741
4762
|
|
|
4742
4763
|
// TODO: implement non F32 return
|
|
4743
4764
|
//struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
|
|
4744
|
-
struct wsp_ggml_tensor * result =
|
|
4765
|
+
struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
|
|
4745
4766
|
|
|
4746
4767
|
result->op = WSP_GGML_OP_GET_ROWS;
|
|
4747
4768
|
result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
|
|
@@ -5519,6 +5540,30 @@ static struct wsp_ggml_tensor * wsp_ggml_upscale_impl(
|
|
|
5519
5540
|
return result;
|
|
5520
5541
|
}
|
|
5521
5542
|
|
|
5543
|
+
struct wsp_ggml_tensor * wsp_ggml_pad(
|
|
5544
|
+
struct wsp_ggml_context * ctx,
|
|
5545
|
+
struct wsp_ggml_tensor * a,
|
|
5546
|
+
int p0, int p1, int p2, int p3) {
|
|
5547
|
+
bool is_node = false;
|
|
5548
|
+
|
|
5549
|
+
if (a->grad) {
|
|
5550
|
+
WSP_GGML_ASSERT(false); // TODO: implement backward
|
|
5551
|
+
is_node = true;
|
|
5552
|
+
}
|
|
5553
|
+
|
|
5554
|
+
struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, a->type,
|
|
5555
|
+
a->ne[0] + p0,
|
|
5556
|
+
a->ne[1] + p1,
|
|
5557
|
+
a->ne[2] + p2,
|
|
5558
|
+
a->ne[3] + p3);
|
|
5559
|
+
|
|
5560
|
+
result->op = WSP_GGML_OP_PAD;
|
|
5561
|
+
result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
|
|
5562
|
+
result->src[0] = a;
|
|
5563
|
+
|
|
5564
|
+
return result;
|
|
5565
|
+
}
|
|
5566
|
+
|
|
5522
5567
|
struct wsp_ggml_tensor * wsp_ggml_upscale(
|
|
5523
5568
|
struct wsp_ggml_context * ctx,
|
|
5524
5569
|
struct wsp_ggml_tensor * a,
|
|
@@ -7520,7 +7565,7 @@ static void wsp_ggml_compute_forward_acc_f32(
|
|
|
7520
7565
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0));
|
|
7521
7566
|
|
|
7522
7567
|
// view src0 and dst with these strides and data offset in bytes during acc
|
|
7523
|
-
// nb0 is
|
|
7568
|
+
// nb0 is implicitly element_size because src0 and dst are contiguous
|
|
7524
7569
|
size_t nb1 = ((int32_t *) dst->op_params)[0];
|
|
7525
7570
|
size_t nb2 = ((int32_t *) dst->op_params)[1];
|
|
7526
7571
|
size_t nb3 = ((int32_t *) dst->op_params)[2];
|
|
@@ -7714,8 +7759,10 @@ static void wsp_ggml_compute_forward_mul_f32(
|
|
|
7714
7759
|
const int ith = params->ith;
|
|
7715
7760
|
const int nth = params->nth;
|
|
7716
7761
|
|
|
7762
|
+
// TODO: add broadcast support to the OpenCL kernel
|
|
7717
7763
|
#ifdef WSP_GGML_USE_CLBLAST
|
|
7718
7764
|
if (src1->backend == WSP_GGML_BACKEND_GPU) {
|
|
7765
|
+
WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1));
|
|
7719
7766
|
if (ith == 0) {
|
|
7720
7767
|
wsp_ggml_cl_mul(src0, src1, dst);
|
|
7721
7768
|
}
|
|
@@ -8981,10 +9028,9 @@ static void wsp_ggml_compute_forward_silu(
|
|
|
8981
9028
|
} break;
|
|
8982
9029
|
}
|
|
8983
9030
|
}
|
|
9031
|
+
// wsp_ggml_compute_forward_leaky_relu
|
|
8984
9032
|
|
|
8985
|
-
|
|
8986
|
-
|
|
8987
|
-
static void wsp_ggml_compute_forward_leaky_f32(
|
|
9033
|
+
static void wsp_ggml_compute_forward_leaky_relu_f32(
|
|
8988
9034
|
const struct wsp_ggml_compute_params * params,
|
|
8989
9035
|
const struct wsp_ggml_tensor * src0,
|
|
8990
9036
|
struct wsp_ggml_tensor * dst) {
|
|
@@ -8998,24 +9044,27 @@ static void wsp_ggml_compute_forward_leaky_f32(
|
|
|
8998
9044
|
const int n = wsp_ggml_nrows(src0);
|
|
8999
9045
|
const int nc = src0->ne[0];
|
|
9000
9046
|
|
|
9047
|
+
float negative_slope;
|
|
9048
|
+
memcpy(&negative_slope, dst->op_params, sizeof(float));
|
|
9049
|
+
|
|
9001
9050
|
assert(dst->nb[0] == sizeof(float));
|
|
9002
9051
|
assert(src0->nb[0] == sizeof(float));
|
|
9003
9052
|
|
|
9004
9053
|
for (int i = 0; i < n; i++) {
|
|
9005
|
-
|
|
9054
|
+
wsp_ggml_vec_leaky_relu_f32(nc,
|
|
9006
9055
|
(float *) ((char *) dst->data + i*( dst->nb[1])),
|
|
9007
|
-
(float *) ((char *) src0->data + i*(src0->nb[1])));
|
|
9056
|
+
(float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
|
|
9008
9057
|
}
|
|
9009
9058
|
}
|
|
9010
9059
|
|
|
9011
|
-
static void
|
|
9060
|
+
static void wsp_ggml_compute_forward_leaky_relu(
|
|
9012
9061
|
const struct wsp_ggml_compute_params * params,
|
|
9013
9062
|
const struct wsp_ggml_tensor * src0,
|
|
9014
9063
|
struct wsp_ggml_tensor * dst) {
|
|
9015
9064
|
switch (src0->type) {
|
|
9016
9065
|
case WSP_GGML_TYPE_F32:
|
|
9017
9066
|
{
|
|
9018
|
-
|
|
9067
|
+
wsp_ggml_compute_forward_leaky_relu_f32(params, src0, dst);
|
|
9019
9068
|
} break;
|
|
9020
9069
|
default:
|
|
9021
9070
|
{
|
|
@@ -9504,8 +9553,11 @@ static bool wsp_ggml_compute_forward_mul_mat_use_blas(
|
|
|
9504
9553
|
const int64_t ne0 = dst->ne[0];
|
|
9505
9554
|
const int64_t ne1 = dst->ne[1];
|
|
9506
9555
|
|
|
9556
|
+
// NOTE: with WSP_GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
|
|
9557
|
+
// all the experts for each batch element and the processing would become incredibly slow
|
|
9507
9558
|
// TODO: find the optimal values for these
|
|
9508
|
-
if (
|
|
9559
|
+
if (dst->op != WSP_GGML_OP_MUL_MAT_ID &&
|
|
9560
|
+
wsp_ggml_is_contiguous(src0) &&
|
|
9509
9561
|
wsp_ggml_is_contiguous(src1) &&
|
|
9510
9562
|
//src0->type == WSP_GGML_TYPE_F32 &&
|
|
9511
9563
|
src1->type == WSP_GGML_TYPE_F32 &&
|
|
@@ -9519,11 +9571,16 @@ static bool wsp_ggml_compute_forward_mul_mat_use_blas(
|
|
|
9519
9571
|
}
|
|
9520
9572
|
#endif
|
|
9521
9573
|
|
|
9574
|
+
// off1 = offset in i11 and i1
|
|
9575
|
+
// cne1 = ne11 and ne1
|
|
9576
|
+
// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
|
|
9577
|
+
// during WSP_GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
|
|
9522
9578
|
static void wsp_ggml_compute_forward_mul_mat(
|
|
9523
9579
|
const struct wsp_ggml_compute_params * params,
|
|
9524
9580
|
const struct wsp_ggml_tensor * src0,
|
|
9525
9581
|
const struct wsp_ggml_tensor * src1,
|
|
9526
|
-
struct wsp_ggml_tensor * dst
|
|
9582
|
+
struct wsp_ggml_tensor * dst,
|
|
9583
|
+
int64_t off1, int64_t cne1) {
|
|
9527
9584
|
int64_t t0 = wsp_ggml_perf_time_us();
|
|
9528
9585
|
UNUSED(t0);
|
|
9529
9586
|
|
|
@@ -9591,10 +9648,9 @@ static void wsp_ggml_compute_forward_mul_mat(
|
|
|
9591
9648
|
const int64_t i03 = i13/r3;
|
|
9592
9649
|
const int64_t i02 = i12/r2;
|
|
9593
9650
|
|
|
9594
|
-
const void * x = (char *) src0->data +
|
|
9595
|
-
const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
|
|
9596
|
-
|
|
9597
|
-
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
|
|
9651
|
+
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
|
|
9652
|
+
const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
|
|
9653
|
+
float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);
|
|
9598
9654
|
|
|
9599
9655
|
if (type != WSP_GGML_TYPE_F32) {
|
|
9600
9656
|
float * const wdata = params->wdata;
|
|
@@ -9611,10 +9667,10 @@ static void wsp_ggml_compute_forward_mul_mat(
|
|
|
9611
9667
|
}
|
|
9612
9668
|
|
|
9613
9669
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
|
9614
|
-
|
|
9615
|
-
|
|
9616
|
-
|
|
9617
|
-
|
|
9670
|
+
cne1, ne01, ne10,
|
|
9671
|
+
1.0f, y, ne10,
|
|
9672
|
+
x, ne00,
|
|
9673
|
+
0.0f, d, ne01);
|
|
9618
9674
|
}
|
|
9619
9675
|
}
|
|
9620
9676
|
|
|
@@ -9630,6 +9686,7 @@ static void wsp_ggml_compute_forward_mul_mat(
|
|
|
9630
9686
|
const size_t row_size = ne10*wsp_ggml_type_size(vec_dot_type)/wsp_ggml_blck_size(vec_dot_type);
|
|
9631
9687
|
|
|
9632
9688
|
assert(params->wsize >= ne11*ne12*ne13*row_size);
|
|
9689
|
+
assert(src1->type == WSP_GGML_TYPE_F32);
|
|
9633
9690
|
|
|
9634
9691
|
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
|
9635
9692
|
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
|
@@ -9652,7 +9709,7 @@ static void wsp_ggml_compute_forward_mul_mat(
|
|
|
9652
9709
|
const size_t row_size = ne10*wsp_ggml_type_size(vec_dot_type)/wsp_ggml_blck_size(vec_dot_type);
|
|
9653
9710
|
|
|
9654
9711
|
const int64_t nr0 = ne01; // src0 rows
|
|
9655
|
-
const int64_t nr1 =
|
|
9712
|
+
const int64_t nr1 = cne1*ne12*ne13; // src1 rows
|
|
9656
9713
|
|
|
9657
9714
|
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
|
|
9658
9715
|
|
|
@@ -9694,9 +9751,9 @@ static void wsp_ggml_compute_forward_mul_mat(
|
|
|
9694
9751
|
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
|
|
9695
9752
|
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
|
|
9696
9753
|
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
|
|
9697
|
-
const int64_t i13 = (ir1/(ne12*
|
|
9698
|
-
const int64_t i12 = (ir1 - i13*ne12*
|
|
9699
|
-
const int64_t i11 = (ir1 - i13*ne12*
|
|
9754
|
+
const int64_t i13 = (ir1/(ne12*cne1));
|
|
9755
|
+
const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
|
|
9756
|
+
const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
|
|
9700
9757
|
|
|
9701
9758
|
// broadcast src0 into src1
|
|
9702
9759
|
const int64_t i03 = i13/r3;
|
|
@@ -9736,20 +9793,28 @@ static void wsp_ggml_compute_forward_mul_mat(
|
|
|
9736
9793
|
|
|
9737
9794
|
static void wsp_ggml_compute_forward_mul_mat_id(
|
|
9738
9795
|
const struct wsp_ggml_compute_params * params,
|
|
9796
|
+
const struct wsp_ggml_tensor * src0,
|
|
9797
|
+
const struct wsp_ggml_tensor * src1,
|
|
9739
9798
|
struct wsp_ggml_tensor * dst) {
|
|
9740
9799
|
|
|
9741
|
-
|
|
9742
|
-
|
|
9743
|
-
|
|
9744
|
-
|
|
9800
|
+
if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
|
|
9801
|
+
// during WSP_GGML_TASK_INIT the entire src1 is converted to vec_dot_type
|
|
9802
|
+
wsp_ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
|
|
9803
|
+
return;
|
|
9804
|
+
}
|
|
9745
9805
|
|
|
9746
|
-
const
|
|
9806
|
+
const struct wsp_ggml_tensor * ids = src0;
|
|
9807
|
+
const int id = wsp_ggml_get_op_params_i32(dst, 0);
|
|
9808
|
+
const int n_as = wsp_ggml_get_op_params_i32(dst, 1);
|
|
9747
9809
|
|
|
9748
|
-
|
|
9810
|
+
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
|
9811
|
+
const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
|
|
9749
9812
|
|
|
9750
|
-
|
|
9813
|
+
WSP_GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
|
9751
9814
|
|
|
9752
|
-
|
|
9815
|
+
const struct wsp_ggml_tensor * src0_row = dst->src[row_id + 2];
|
|
9816
|
+
wsp_ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
|
|
9817
|
+
}
|
|
9753
9818
|
}
|
|
9754
9819
|
|
|
9755
9820
|
// wsp_ggml_compute_forward_out_prod
|
|
@@ -10161,7 +10226,7 @@ static void wsp_ggml_compute_forward_set_f32(
|
|
|
10161
10226
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0));
|
|
10162
10227
|
|
|
10163
10228
|
// view src0 and dst with these strides and data offset in bytes during set
|
|
10164
|
-
// nb0 is
|
|
10229
|
+
// nb0 is implicitly element_size because src0 and dst are contiguous
|
|
10165
10230
|
size_t nb1 = ((int32_t *) dst->op_params)[0];
|
|
10166
10231
|
size_t nb2 = ((int32_t *) dst->op_params)[1];
|
|
10167
10232
|
size_t nb3 = ((int32_t *) dst->op_params)[2];
|
|
@@ -10325,21 +10390,30 @@ static void wsp_ggml_compute_forward_get_rows_q(
|
|
|
10325
10390
|
return;
|
|
10326
10391
|
}
|
|
10327
10392
|
|
|
10328
|
-
|
|
10329
|
-
|
|
10393
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
10394
|
+
|
|
10395
|
+
const int64_t nc = ne00;
|
|
10396
|
+
const int64_t nr = wsp_ggml_nelements(src1); WSP_GGML_UNUSED(nr);
|
|
10397
|
+
|
|
10330
10398
|
const enum wsp_ggml_type type = src0->type;
|
|
10331
10399
|
wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
|
|
10332
10400
|
|
|
10333
|
-
assert(
|
|
10334
|
-
assert(
|
|
10335
|
-
assert(
|
|
10401
|
+
assert(ne0 == nc);
|
|
10402
|
+
assert(ne02 == ne11);
|
|
10403
|
+
assert(nb00 == wsp_ggml_type_size(type));
|
|
10404
|
+
assert(wsp_ggml_nrows(dst) == nr);
|
|
10336
10405
|
|
|
10337
|
-
|
|
10338
|
-
|
|
10406
|
+
// TODO: multi-thread
|
|
10407
|
+
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
|
10408
|
+
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
|
10409
|
+
for (int64_t i10 = 0; i10 < ne10; ++i10) {
|
|
10410
|
+
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
|
10339
10411
|
|
|
10340
|
-
|
|
10341
|
-
|
|
10342
|
-
|
|
10412
|
+
wsp_dewsp_quantize_row_q(
|
|
10413
|
+
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
|
|
10414
|
+
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
|
|
10415
|
+
}
|
|
10416
|
+
}
|
|
10343
10417
|
}
|
|
10344
10418
|
}
|
|
10345
10419
|
|
|
@@ -10354,19 +10428,26 @@ static void wsp_ggml_compute_forward_get_rows_f16(
|
|
|
10354
10428
|
return;
|
|
10355
10429
|
}
|
|
10356
10430
|
|
|
10357
|
-
|
|
10358
|
-
const int nr = wsp_ggml_nelements(src1);
|
|
10431
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
10359
10432
|
|
|
10360
|
-
|
|
10361
|
-
|
|
10362
|
-
assert(src0->nb[0] == sizeof(wsp_ggml_fp16_t));
|
|
10433
|
+
const int64_t nc = ne00;
|
|
10434
|
+
const int64_t nr = wsp_ggml_nelements(src1); WSP_GGML_UNUSED(nr);
|
|
10363
10435
|
|
|
10364
|
-
|
|
10365
|
-
|
|
10436
|
+
assert(ne0 == nc);
|
|
10437
|
+
assert(ne02 == ne11);
|
|
10438
|
+
assert(nb00 == sizeof(wsp_ggml_fp16_t));
|
|
10439
|
+
assert(wsp_ggml_nrows(dst) == nr);
|
|
10366
10440
|
|
|
10367
|
-
|
|
10368
|
-
|
|
10369
|
-
|
|
10441
|
+
// TODO: multi-thread
|
|
10442
|
+
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
|
10443
|
+
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
|
10444
|
+
for (int64_t i10 = 0; i10 < ne10; ++i10) {
|
|
10445
|
+
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
|
10446
|
+
|
|
10447
|
+
wsp_ggml_fp16_to_fp32_row(
|
|
10448
|
+
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
|
|
10449
|
+
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
|
|
10450
|
+
}
|
|
10370
10451
|
}
|
|
10371
10452
|
}
|
|
10372
10453
|
}
|
|
@@ -10382,19 +10463,27 @@ static void wsp_ggml_compute_forward_get_rows_f32(
|
|
|
10382
10463
|
return;
|
|
10383
10464
|
}
|
|
10384
10465
|
|
|
10385
|
-
|
|
10386
|
-
const int nr = wsp_ggml_nelements(src1);
|
|
10466
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
10387
10467
|
|
|
10388
|
-
|
|
10389
|
-
|
|
10390
|
-
assert(src0->nb[0] == sizeof(float));
|
|
10468
|
+
const int64_t nc = ne00;
|
|
10469
|
+
const int64_t nr = wsp_ggml_nelements(src1); WSP_GGML_UNUSED(nr);
|
|
10391
10470
|
|
|
10392
|
-
|
|
10393
|
-
|
|
10471
|
+
assert(ne0 == nc);
|
|
10472
|
+
assert(ne02 == ne11);
|
|
10473
|
+
assert(nb00 == sizeof(float));
|
|
10474
|
+
assert(wsp_ggml_nrows(dst) == nr);
|
|
10394
10475
|
|
|
10395
|
-
|
|
10396
|
-
|
|
10397
|
-
|
|
10476
|
+
// TODO: multi-thread
|
|
10477
|
+
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
|
10478
|
+
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
|
10479
|
+
for (int64_t i10 = 0; i10 < ne10; ++i10) {
|
|
10480
|
+
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
|
10481
|
+
|
|
10482
|
+
wsp_ggml_vec_cpy_f32(nc,
|
|
10483
|
+
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
|
|
10484
|
+
(float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
|
|
10485
|
+
}
|
|
10486
|
+
}
|
|
10398
10487
|
}
|
|
10399
10488
|
}
|
|
10400
10489
|
|
|
@@ -12114,6 +12203,7 @@ static void wsp_ggml_compute_forward_upscale_f32(
|
|
|
12114
12203
|
WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
12115
12204
|
|
|
12116
12205
|
const int ith = params->ith;
|
|
12206
|
+
const int nth = params->nth;
|
|
12117
12207
|
|
|
12118
12208
|
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
12119
12209
|
|
|
@@ -12121,16 +12211,17 @@ static void wsp_ggml_compute_forward_upscale_f32(
|
|
|
12121
12211
|
|
|
12122
12212
|
// TODO: optimize
|
|
12123
12213
|
|
|
12124
|
-
for (
|
|
12125
|
-
|
|
12126
|
-
|
|
12127
|
-
|
|
12128
|
-
|
|
12129
|
-
|
|
12130
|
-
|
|
12131
|
-
const
|
|
12214
|
+
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
12215
|
+
const int64_t i03 = i3;
|
|
12216
|
+
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
|
12217
|
+
const int64_t i02 = i2;
|
|
12218
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
12219
|
+
const int64_t i01 = i1 / scale_factor;
|
|
12220
|
+
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
|
12221
|
+
const int64_t i00 = i0 / scale_factor;
|
|
12132
12222
|
|
|
12133
|
-
float *
|
|
12223
|
+
const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
12224
|
+
float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
|
12134
12225
|
|
|
12135
12226
|
*y = *x;
|
|
12136
12227
|
}
|
|
@@ -12155,6 +12246,64 @@ static void wsp_ggml_compute_forward_upscale(
|
|
|
12155
12246
|
}
|
|
12156
12247
|
}
|
|
12157
12248
|
|
|
12249
|
+
// wsp_ggml_compute_forward_pad
|
|
12250
|
+
|
|
12251
|
+
static void wsp_ggml_compute_forward_pad_f32(
|
|
12252
|
+
const struct wsp_ggml_compute_params * params,
|
|
12253
|
+
const struct wsp_ggml_tensor * src0,
|
|
12254
|
+
struct wsp_ggml_tensor * dst) {
|
|
12255
|
+
|
|
12256
|
+
if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
|
|
12257
|
+
return;
|
|
12258
|
+
}
|
|
12259
|
+
|
|
12260
|
+
WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
12261
|
+
WSP_GGML_ASSERT( dst->nb[0] == sizeof(float));
|
|
12262
|
+
|
|
12263
|
+
const int ith = params->ith;
|
|
12264
|
+
const int nth = params->nth;
|
|
12265
|
+
|
|
12266
|
+
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
12267
|
+
|
|
12268
|
+
float * dst_ptr = (float *) dst->data;
|
|
12269
|
+
|
|
12270
|
+
// TODO: optimize
|
|
12271
|
+
|
|
12272
|
+
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
|
12273
|
+
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
|
|
12274
|
+
for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
|
12275
|
+
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
|
12276
|
+
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
|
12277
|
+
|
|
12278
|
+
const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
12279
|
+
|
|
12280
|
+
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
|
12281
|
+
dst_ptr[dst_idx] = *src_ptr;
|
|
12282
|
+
} else {
|
|
12283
|
+
dst_ptr[dst_idx] = 0;
|
|
12284
|
+
}
|
|
12285
|
+
}
|
|
12286
|
+
}
|
|
12287
|
+
}
|
|
12288
|
+
}
|
|
12289
|
+
}
|
|
12290
|
+
|
|
12291
|
+
static void wsp_ggml_compute_forward_pad(
|
|
12292
|
+
const struct wsp_ggml_compute_params * params,
|
|
12293
|
+
const struct wsp_ggml_tensor * src0,
|
|
12294
|
+
struct wsp_ggml_tensor * dst) {
|
|
12295
|
+
switch (src0->type) {
|
|
12296
|
+
case WSP_GGML_TYPE_F32:
|
|
12297
|
+
{
|
|
12298
|
+
wsp_ggml_compute_forward_pad_f32(params, src0, dst);
|
|
12299
|
+
} break;
|
|
12300
|
+
default:
|
|
12301
|
+
{
|
|
12302
|
+
WSP_GGML_ASSERT(false);
|
|
12303
|
+
} break;
|
|
12304
|
+
}
|
|
12305
|
+
}
|
|
12306
|
+
|
|
12158
12307
|
// wsp_ggml_compute_forward_argsort
|
|
12159
12308
|
|
|
12160
12309
|
static void wsp_ggml_compute_forward_argsort_f32(
|
|
@@ -13362,10 +13511,6 @@ static void wsp_ggml_compute_forward_unary(
|
|
|
13362
13511
|
{
|
|
13363
13512
|
wsp_ggml_compute_forward_silu(params, src0, dst);
|
|
13364
13513
|
} break;
|
|
13365
|
-
case WSP_GGML_UNARY_OP_LEAKY:
|
|
13366
|
-
{
|
|
13367
|
-
wsp_ggml_compute_forward_leaky(params, src0, dst);
|
|
13368
|
-
} break;
|
|
13369
13514
|
default:
|
|
13370
13515
|
{
|
|
13371
13516
|
WSP_GGML_ASSERT(false);
|
|
@@ -14037,11 +14182,11 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
|
|
|
14037
14182
|
} break;
|
|
14038
14183
|
case WSP_GGML_OP_MUL_MAT:
|
|
14039
14184
|
{
|
|
14040
|
-
wsp_ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
|
|
14185
|
+
wsp_ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
|
|
14041
14186
|
} break;
|
|
14042
14187
|
case WSP_GGML_OP_MUL_MAT_ID:
|
|
14043
14188
|
{
|
|
14044
|
-
wsp_ggml_compute_forward_mul_mat_id(params, tensor);
|
|
14189
|
+
wsp_ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
|
|
14045
14190
|
} break;
|
|
14046
14191
|
case WSP_GGML_OP_OUT_PROD:
|
|
14047
14192
|
{
|
|
@@ -14147,10 +14292,18 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
|
|
|
14147
14292
|
{
|
|
14148
14293
|
wsp_ggml_compute_forward_upscale(params, tensor->src[0], tensor);
|
|
14149
14294
|
} break;
|
|
14295
|
+
case WSP_GGML_OP_PAD:
|
|
14296
|
+
{
|
|
14297
|
+
wsp_ggml_compute_forward_pad(params, tensor->src[0], tensor);
|
|
14298
|
+
} break;
|
|
14150
14299
|
case WSP_GGML_OP_ARGSORT:
|
|
14151
14300
|
{
|
|
14152
14301
|
wsp_ggml_compute_forward_argsort(params, tensor->src[0], tensor);
|
|
14153
14302
|
} break;
|
|
14303
|
+
case WSP_GGML_OP_LEAKY_RELU:
|
|
14304
|
+
{
|
|
14305
|
+
wsp_ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor);
|
|
14306
|
+
} break;
|
|
14154
14307
|
case WSP_GGML_OP_FLASH_ATTN:
|
|
14155
14308
|
{
|
|
14156
14309
|
const int32_t t = wsp_ggml_get_op_params_i32(tensor, 0);
|
|
@@ -14475,7 +14628,7 @@ void wsp_ggml_build_backward_gradient_checkpointing(
|
|
|
14475
14628
|
// insert new tensors recomputing src, reusing already made replacements,
|
|
14476
14629
|
// remember replacements: remember new tensors with mapping from corresponding gf nodes
|
|
14477
14630
|
// recurse for input tensors,
|
|
14478
|
-
// unless (i.e. terminating when) input tensors are
|
|
14631
|
+
// unless (i.e. terminating when) input tensors are replacements (like checkpoints)
|
|
14479
14632
|
node->src[k] = wsp_ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
|
|
14480
14633
|
}
|
|
14481
14634
|
// insert rewritten backward node with replacements made into resulting backward graph gb
|
|
@@ -15143,10 +15296,18 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
|
|
|
15143
15296
|
{
|
|
15144
15297
|
WSP_GGML_ASSERT(false); // TODO: not implemented
|
|
15145
15298
|
} break;
|
|
15299
|
+
case WSP_GGML_OP_PAD:
|
|
15300
|
+
{
|
|
15301
|
+
WSP_GGML_ASSERT(false); // TODO: not implemented
|
|
15302
|
+
} break;
|
|
15146
15303
|
case WSP_GGML_OP_ARGSORT:
|
|
15147
15304
|
{
|
|
15148
15305
|
WSP_GGML_ASSERT(false); // TODO: not implemented
|
|
15149
15306
|
} break;
|
|
15307
|
+
case WSP_GGML_OP_LEAKY_RELU:
|
|
15308
|
+
{
|
|
15309
|
+
WSP_GGML_ASSERT(false); // TODO: not implemented
|
|
15310
|
+
} break;
|
|
15150
15311
|
case WSP_GGML_OP_FLASH_ATTN:
|
|
15151
15312
|
{
|
|
15152
15313
|
struct wsp_ggml_tensor * flash_grad = NULL;
|
|
@@ -15752,6 +15913,7 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
|
|
|
15752
15913
|
case WSP_GGML_OP_ARGMAX:
|
|
15753
15914
|
case WSP_GGML_OP_REPEAT:
|
|
15754
15915
|
case WSP_GGML_OP_REPEAT_BACK:
|
|
15916
|
+
case WSP_GGML_OP_LEAKY_RELU:
|
|
15755
15917
|
{
|
|
15756
15918
|
n_tasks = 1;
|
|
15757
15919
|
} break;
|
|
@@ -15764,7 +15926,6 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
|
|
|
15764
15926
|
case WSP_GGML_UNARY_OP_TANH:
|
|
15765
15927
|
case WSP_GGML_UNARY_OP_ELU:
|
|
15766
15928
|
case WSP_GGML_UNARY_OP_RELU:
|
|
15767
|
-
case WSP_GGML_UNARY_OP_LEAKY:
|
|
15768
15929
|
{
|
|
15769
15930
|
n_tasks = 1;
|
|
15770
15931
|
} break;
|
|
@@ -15883,6 +16044,10 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
|
|
|
15883
16044
|
{
|
|
15884
16045
|
n_tasks = n_threads;
|
|
15885
16046
|
} break;
|
|
16047
|
+
case WSP_GGML_OP_PAD:
|
|
16048
|
+
{
|
|
16049
|
+
n_tasks = n_threads;
|
|
16050
|
+
} break;
|
|
15886
16051
|
case WSP_GGML_OP_ARGSORT:
|
|
15887
16052
|
{
|
|
15888
16053
|
n_tasks = n_threads;
|