whisper.rn 0.4.2 → 0.5.0-rc.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -3
- package/android/build.gradle +70 -11
- package/android/src/main/CMakeLists.txt +28 -1
- package/android/src/main/java/com/rnwhisper/JSCallInvokerResolver.java +40 -0
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +80 -27
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +21 -9
- package/android/src/main/java/com/rnwhisper/WhisperVadContext.java +1 -1
- package/android/src/main/jni.cpp +79 -2
- package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +5 -0
- package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +5 -0
- package/cpp/ggml-backend.cpp +36 -18
- package/cpp/ggml-backend.h +1 -1
- package/cpp/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/ggml-cpu/arch/arm/repack.cpp +13 -12
- package/cpp/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/ggml-cpu/common.h +3 -2
- package/cpp/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/ggml-cpu/ggml-cpu.c +95 -17
- package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/ggml-cpu/ops.cpp +775 -74
- package/cpp/ggml-cpu/ops.h +7 -0
- package/cpp/ggml-cpu/quants.c +25 -24
- package/cpp/ggml-cpu/repack.cpp +15 -14
- package/cpp/ggml-cpu/simd-mappings.h +211 -33
- package/cpp/ggml-cpu/vec.cpp +26 -2
- package/cpp/ggml-cpu/vec.h +99 -45
- package/cpp/ggml-cpu.h +2 -0
- package/cpp/ggml-impl.h +125 -183
- package/cpp/ggml-metal-impl.h +27 -0
- package/cpp/ggml-metal.m +298 -41
- package/cpp/ggml-quants.c +6 -6
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +269 -40
- package/cpp/ggml.h +122 -2
- package/cpp/gguf.cpp +5 -1
- package/cpp/jsi/RNWhisperJSI.cpp +681 -0
- package/cpp/jsi/RNWhisperJSI.h +44 -0
- package/cpp/jsi/ThreadPool.h +100 -0
- package/cpp/whisper.cpp +4 -0
- package/cpp/whisper.h +2 -0
- package/ios/RNWhisper.h +3 -0
- package/ios/RNWhisper.mm +66 -31
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/jest/mock.js +1 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +83 -2
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +83 -2
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +4 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +18 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +2 -3
- package/src/NativeRNWhisper.ts +2 -0
- package/src/index.ts +162 -33
- package/whisper-rn.podspec +6 -3
package/cpp/ggml-quants.c
CHANGED
@@ -568,14 +568,14 @@ static float make_qkx2_quants(int n, int nmax, const float * WSP_GGML_RESTRICT x
     }
     float iscale = nmax/(max - min);
     float scale = 1/iscale;
-    float best_mad = 0;
+    float best_error = 0;
     for (int i = 0; i < n; ++i) {
         int l = nearest_int(iscale*(x[i] - min));
         L[i] = MAX(0, MIN(nmax, l));
         float diff = scale * L[i] + min - x[i];
         diff = use_mad ? fabsf(diff) : diff * diff;
         float w = weights[i];
-        best_mad += w * diff;
+        best_error += w * diff;
     }
     if (nstep < 1) {
         *the_min = -min;
@@ -601,18 +601,18 @@ static float make_qkx2_quants(int n, int nmax, const float * WSP_GGML_RESTRICT x
             this_min = 0;
             this_scale = sum_xl / sum_l2;
         }
-        float mad = 0;
+        float cur_error = 0;
         for (int i = 0; i < n; ++i) {
             float diff = this_scale * Laux[i] + this_min - x[i];
             diff = use_mad ? fabsf(diff) : diff * diff;
             float w = weights[i];
-            mad += w * diff;
+            cur_error += w * diff;
         }
-        if (mad < best_mad) {
+        if (cur_error < best_error) {
             for (int i = 0; i < n; ++i) {
                 L[i] = Laux[i];
            }
-            best_mad = mad;
+            best_error = cur_error;
            scale = this_scale;
            min = this_min;
        }
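This hunk appears to be a pure rename: the accumulator formerly called `best_mad` now reads `best_error`, which better describes what it holds, a weighted absolute or squared reconstruction error depending on `use_mad`. A minimal standalone sketch of that criterion (illustrative only, not the package's code; all names here are invented):

```c
// Weighted quantization error as accumulated by best_error / cur_error:
// use_mad selects mean-absolute deviation, otherwise squared deviation.
#include <math.h>
#include <stdbool.h>

static float weighted_error(int n, const float * x, const float * q,
                            const float * w, bool use_mad) {
    float err = 0.0f;
    for (int i = 0; i < n; ++i) {
        float diff = q[i] - x[i];  // reconstruction error per element
        err += w[i] * (use_mad ? fabsf(diff) : diff * diff);
    }
    return err;  // the quantity compared via cur_error < best_error
}
```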
package/cpp/ggml-whisper-sim.metallib
CHANGED (binary file; contents not shown)
package/cpp/ggml-whisper.metallib
CHANGED (binary file; contents not shown)
package/cpp/ggml.c
CHANGED
@@ -61,9 +61,6 @@
 #define m512i(p) (__m512i)(p)
 #endif
 
-// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
-float wsp_ggml_table_f32_f16[1 << 16];
-
 #if defined(__linux__) || \
     defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
     (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)
@@ -936,6 +933,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "TRANSPOSE",
     "GET_ROWS",
     "GET_ROWS_BACK",
+    "SET_ROWS",
     "DIAG",
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
@@ -947,6 +945,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "CONV_TRANSPOSE_1D",
     "IM2COL",
     "IM2COL_BACK",
+    "CONV_2D",
     "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
@@ -984,9 +983,11 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
+
+    "GLU",
 };
 
-static_assert(WSP_GGML_OP_COUNT == 83, "WSP_GGML_OP_COUNT != 83");
+static_assert(WSP_GGML_OP_COUNT == 86, "WSP_GGML_OP_COUNT != 86");
 
 static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "none",
@@ -1032,6 +1033,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "transpose(x)",
     "get_rows(x)",
     "get_rows_back(x)",
+    "set_rows(x)",
     "diag(x)",
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
@@ -1043,6 +1045,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "conv_transpose_1d(x)",
     "im2col(x)",
     "im2col_back(x)",
+    "conv_2d(x)",
     "conv_2d_dw(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
@@ -1080,9 +1083,11 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
+
+    "glu(x)",
 };
 
-static_assert(WSP_GGML_OP_COUNT == 83, "WSP_GGML_OP_COUNT != 83");
+static_assert(WSP_GGML_OP_COUNT == 86, "WSP_GGML_OP_COUNT != 86");
 
 static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");
 
@@ -1108,6 +1113,15 @@ static const char * WSP_GGML_UNARY_OP_NAME[WSP_GGML_UNARY_OP_COUNT] = {
 static_assert(WSP_GGML_UNARY_OP_COUNT == 15, "WSP_GGML_UNARY_OP_COUNT != 15");
 
 
+static const char * WSP_GGML_GLU_OP_NAME[WSP_GGML_GLU_OP_COUNT] = {
+    "REGLU",
+    "GEGLU",
+    "SWIGLU",
+};
+
+static_assert(WSP_GGML_GLU_OP_COUNT == 3, "WSP_GGML_GLU_OP_COUNT != 3");
+
+
 static_assert(sizeof(struct wsp_ggml_object)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_object size must be a multiple of WSP_GGML_MEM_ALIGN");
 static_assert(sizeof(struct wsp_ggml_tensor)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_tensor size must be a multiple of WSP_GGML_MEM_ALIGN");
 
@@ -1210,11 +1224,19 @@ const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op) {
     return WSP_GGML_UNARY_OP_NAME[op];
 }
 
+const char * wsp_ggml_glu_op_name(enum wsp_ggml_glu_op op) {
+    return WSP_GGML_GLU_OP_NAME[op];
+}
+
 const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t) {
     if (t->op == WSP_GGML_OP_UNARY) {
         enum wsp_ggml_unary_op uop = wsp_ggml_get_unary_op(t);
         return wsp_ggml_unary_op_name(uop);
     }
+    if (t->op == WSP_GGML_OP_GLU) {
+        enum wsp_ggml_glu_op gop = wsp_ggml_get_glu_op(t);
+        return wsp_ggml_glu_op_name(gop);
+    }
     return wsp_ggml_op_name(t->op);
 }
 
@@ -1351,6 +1373,12 @@ bool wsp_ggml_is_contiguous_channels(const struct wsp_ggml_tensor * tensor) {
         tensor->nb[2] == wsp_ggml_type_size(tensor->type);
 }
 
+bool wsp_ggml_is_contiguous_rows(const struct wsp_ggml_tensor * tensor) {
+    return
+        tensor->ne[0] == wsp_ggml_blck_size(tensor->type) ||
+        tensor->nb[0] == wsp_ggml_type_size(tensor->type);
+}
+
 static inline bool wsp_ggml_is_padded_1d(const struct wsp_ggml_tensor * tensor) {
     static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function");
 
@@ -1422,14 +1450,6 @@ struct wsp_ggml_context * wsp_ggml_init(struct wsp_ggml_init_params params) {
         // initialize time system (required on Windows)
         wsp_ggml_time_init();
 
-        for (int i = 0; i < (1 << 16); ++i) {
-            union {
-                uint16_t u16;
-                wsp_ggml_fp16_t fp16;
-            } u = {i};
-            wsp_ggml_table_f32_f16[i] = WSP_GGML_COMPUTE_FP16_TO_FP32(u.fp16);
-        }
-
         is_first_call = false;
     }
 
@@ -1733,6 +1753,11 @@ enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tens
     return (enum wsp_ggml_unary_op) wsp_ggml_get_op_params_i32(tensor, 0);
 }
 
+enum wsp_ggml_glu_op wsp_ggml_get_glu_op(const struct wsp_ggml_tensor * tensor) {
+    WSP_GGML_ASSERT(tensor->op == WSP_GGML_OP_GLU);
+    return (enum wsp_ggml_glu_op) wsp_ggml_get_op_params_i32(tensor, 0);
+}
+
 const char * wsp_ggml_get_name(const struct wsp_ggml_tensor * tensor) {
     return tensor->name;
 }
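The new `wsp_ggml_is_contiguous_rows` predicate accepts a tensor whenever each row is dense, even if the rows themselves are strided. A sketch of what passes and what fails (illustrative only; shapes invented, include path depends on the build):

```c
// Illustrative check, not package code: a plain 2-D tensor has dense rows,
// while a transposed view strides within a row and is rejected.
#include "ggml.h"

void contiguous_rows_demo(struct wsp_ggml_context * ctx) {
    struct wsp_ggml_tensor * t = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 8, 4);
    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(t));   // nb[0] == type size

    // Transposing swaps strides: nb[0] becomes the old row stride, and for
    // F32 ne[0] != block size, so the predicate returns false.
    struct wsp_ggml_tensor * tt = wsp_ggml_transpose(ctx, t);
    WSP_GGML_ASSERT(!wsp_ggml_is_contiguous_rows(tt));
}
```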
@@ -2612,6 +2637,114 @@ struct wsp_ggml_tensor * wsp_ggml_exp_inplace(
     return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_EXP);
 }
 
+// wsp_ggml_glu
+
+static struct wsp_ggml_tensor * wsp_ggml_glu_impl(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * b,
+        enum wsp_ggml_glu_op op,
+        bool swapped) {
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(a));
+
+    if (b) {
+        WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(b));
+        WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b));
+        WSP_GGML_ASSERT(a->type == b->type);
+    }
+
+    int64_t ne[WSP_GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < WSP_GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, WSP_GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
+
+    wsp_ggml_set_op_params_i32(result, 0, (int32_t) op);
+    wsp_ggml_set_op_params_i32(result, 1, (int32_t) swapped);
+
+    result->op = WSP_GGML_OP_GLU;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct wsp_ggml_tensor * wsp_ggml_glu(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        enum wsp_ggml_glu_op op,
+        bool swapped) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, op, swapped);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_glu_split(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * b,
+        enum wsp_ggml_glu_op op) {
+    return wsp_ggml_glu_impl(ctx, a, b, op, false);
+}
+
+// wsp_ggml_reglu
+
+struct wsp_ggml_tensor * wsp_ggml_reglu(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_REGLU, false);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_reglu_swapped(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_REGLU, true);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_reglu_split(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * b) {
+    return wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_REGLU, false);
+}
+
+// wsp_ggml_geglu
+
+struct wsp_ggml_tensor * wsp_ggml_geglu(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_GEGLU, false);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_swapped(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_GEGLU, true);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_split(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * b) {
+    return wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_GEGLU, false);
+}
+
+// wsp_ggml_swiglu
+
+struct wsp_ggml_tensor * wsp_ggml_swiglu(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_SWIGLU, false);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_swiglu_swapped(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_SWIGLU, true);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_swiglu_split(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * b) {
+    return wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_SWIGLU, false);
+}
+
 // wsp_ggml_norm
 
 static struct wsp_ggml_tensor * wsp_ggml_norm_impl(
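A shape-level sketch of how these builders compose (illustrative only: tensor names and sizes are invented, and the include path depends on how your build exposes the vendored headers). Per `wsp_ggml_glu_impl` above, the fused form halves dimension 0, while the split form keeps the first operand's full shape:

```c
// Fused vs split GLU, shown for SWIGLU; the same holds for REGLU and GEGLU.
#include "ggml.h"

void glu_shapes_demo(struct wsp_ggml_context * ctx) {
    const int64_t d_ff = 1024, n_tok = 8;

    // Fused form: one activation holding the [gate | up] halves along dim 0.
    struct wsp_ggml_tensor * gate_up = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 2*d_ff, n_tok);
    struct wsp_ggml_tensor * fused   = wsp_ggml_swiglu(ctx, gate_up);
    WSP_GGML_ASSERT(fused->ne[0] == d_ff);  // ne[0] halved when b == NULL

    // Split form: gate and up as separate, same-shape tensors.
    struct wsp_ggml_tensor * gate  = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, d_ff, n_tok);
    struct wsp_ggml_tensor * up    = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, d_ff, n_tok);
    struct wsp_ggml_tensor * split = wsp_ggml_glu_split(ctx, gate, up, WSP_GGML_GLU_OP_SWIGLU);
    WSP_GGML_ASSERT(split->ne[0] == d_ff);  // b != NULL keeps a's full shape
}
```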
@@ -3395,6 +3528,35 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows_back(
     return result;
 }
 
+// wsp_ggml_set_rows
+
+struct wsp_ggml_tensor * wsp_ggml_set_rows(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * b,
+        struct wsp_ggml_tensor * c) {
+    WSP_GGML_ASSERT(a->ne[0] == b->ne[0]);
+    WSP_GGML_ASSERT(a->ne[2] == b->ne[2]);
+    WSP_GGML_ASSERT(a->ne[3] == b->ne[3]);
+    WSP_GGML_ASSERT(b->ne[1] == c->ne[0]);
+    WSP_GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
+    WSP_GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
+    WSP_GGML_ASSERT(c->ne[3] == 1);
+    WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_I64);
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(a));
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(b));
+
+    struct wsp_ggml_tensor * result = wsp_ggml_view_tensor(ctx, a);
+
+    result->op = WSP_GGML_OP_SET_ROWS;
+    result->src[0] = b;
+    result->src[1] = c;
+
+    return result;
+}
+
 // wsp_ggml_diag
 
 struct wsp_ggml_tensor * wsp_ggml_diag(
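A sketch of the intended call pattern (illustrative only; all names and sizes are invented): scatter a block of F32 rows into a larger destination at 64-bit row indices. The asserts above pin the contract: `b` is F32, `c` is I64, and `b` carries one row per index in `c`.

```c
// Scatter n_upd rows into dst at the row indices held in `rows`.
#include "ggml.h"

struct wsp_ggml_tensor * scatter_rows_demo(struct wsp_ggml_context * ctx) {
    const int64_t n_emb = 512, n_dst = 64, n_upd = 4;

    struct wsp_ggml_tensor * dst  = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_emb, n_dst);
    struct wsp_ggml_tensor * src  = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_emb, n_upd);
    struct wsp_ggml_tensor * rows = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I64, n_upd);

    // Contract per the asserts: src->ne[0] == dst->ne[0] and
    // src->ne[1] == rows->ne[0]. The result is a view of dst that
    // carries the SET_ROWS op with src and rows as its operands.
    return wsp_ggml_set_rows(ctx, dst, src, rows);
}
```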
@@ -4131,6 +4293,44 @@ struct wsp_ggml_tensor * wsp_ggml_conv_2d_dw_direct(
     return result;
 }
 
+// wsp_ggml_conv_2d_direct
+
+struct wsp_ggml_tensor * wsp_ggml_conv_2d_direct(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,  // convolution kernel [KW, KH, IC, OC]
+        struct wsp_ggml_tensor * b,  // input data [W, H, C, N]
+        int s0,                      // stride dimension 0
+        int s1,                      // stride dimension 1
+        int p0,                      // padding dimension 0
+        int p1,                      // padding dimension 1
+        int d0,                      // dilation dimension 0
+        int d1) {                    // dilation dimension 1
+
+    WSP_GGML_ASSERT(a->ne[2] == b->ne[2]);
+    //WSP_GGML_ASSERT(a->type == b->type);
+
+    int64_t ne[4];
+    ne[0] = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+    ne[1] = wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
+    ne[2] = a->ne[3];
+    ne[3] = b->ne[3];
+
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, b->type, 4, ne);
+
+    wsp_ggml_set_op_params_i32(result, 0, s0);
+    wsp_ggml_set_op_params_i32(result, 1, s1);
+    wsp_ggml_set_op_params_i32(result, 2, p0);
+    wsp_ggml_set_op_params_i32(result, 3, p1);
+    wsp_ggml_set_op_params_i32(result, 4, d0);
+    wsp_ggml_set_op_params_i32(result, 5, d1);
+
+    result->op = WSP_GGML_OP_CONV_2D;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // wsp_ggml_conv_transpose_2d_p0
 
 static int64_t wsp_ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
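The output spatial sizes come from `wsp_ggml_calc_conv_output_size`, which in ggml is the standard convolution size formula `(in + 2*p - d*(k - 1) - 1)/s + 1`. A hedged shape sketch (sizes invented; the local helper restates the formula rather than calling the library's static function, under the assumption that it matches):

```c
// Shape check for the new direct 2-D convolution builder.
#include "ggml.h"

static int64_t conv_out(int64_t in, int64_t k, int s, int p, int d) {
    return (in + 2*p - d*(k - 1) - 1)/s + 1;  // assumed ggml size formula
}

struct wsp_ggml_tensor * conv2d_demo(struct wsp_ggml_context * ctx) {
    // Kernel [KW, KH, IC, OC] and input [W, H, C, N]; a->ne[2] must equal b->ne[2].
    struct wsp_ggml_tensor * k = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, 3, 3, 80, 384);
    struct wsp_ggml_tensor * x = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, 100, 100, 80, 1);

    // Stride 1, padding 1, dilation 1 preserves W and H:
    // conv_out(100, 3, 1, 1, 1) == 100, so the result is [100, 100, 384, 1].
    return wsp_ggml_conv_2d_direct(ctx, k, x, 1, 1, 1, 1, 1, 1);
}
```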
@@ -4247,24 +4447,21 @@ struct wsp_ggml_tensor * wsp_ggml_pool_2d_back(
     return result;
 }
 
-// wsp_ggml_upscale
+// wsp_ggml_upscale / wsp_ggml_interpolate
 
-static struct wsp_ggml_tensor * wsp_ggml_upscale_impl(
+static struct wsp_ggml_tensor * wsp_ggml_interpolate_impl(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a,
-        int ne0,
-        int ne1,
-        int ne2,
-        int ne3,
-        enum wsp_ggml_scale_mode mode) {
-    WSP_GGML_ASSERT(a->ne[0] <= ne0);
-    WSP_GGML_ASSERT(a->ne[1] <= ne1);
-    WSP_GGML_ASSERT(a->ne[2] <= ne2);
-    WSP_GGML_ASSERT(a->ne[3] <= ne3);
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3,
+        uint32_t mode) {
+    WSP_GGML_ASSERT((mode & 0xFF) < WSP_GGML_SCALE_MODE_COUNT);
 
     struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
 
-    wsp_ggml_set_op_params_i32(result, 0, mode);
+    wsp_ggml_set_op_params_i32(result, 0, (int32_t)mode);
 
     result->op = WSP_GGML_OP_UPSCALE;
     result->src[0] = a;
@@ -4277,7 +4474,8 @@ struct wsp_ggml_tensor * wsp_ggml_upscale(
     struct wsp_ggml_tensor * a,
     int scale_factor,
     enum wsp_ggml_scale_mode mode) {
-    return wsp_ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
+    WSP_GGML_ASSERT(scale_factor > 1);
+    return wsp_ggml_interpolate_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
 }
 
 struct wsp_ggml_tensor * wsp_ggml_upscale_ext(
@@ -4288,7 +4486,18 @@ struct wsp_ggml_tensor * wsp_ggml_upscale_ext(
     int ne2,
     int ne3,
     enum wsp_ggml_scale_mode mode) {
-    return wsp_ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
+    return wsp_ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_interpolate(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3,
+        uint32_t mode) {
+    return wsp_ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
 }
 
 // wsp_ggml_pad
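The practical difference: `wsp_ggml_upscale` keeps its integer-factor interface (now asserting `scale_factor > 1`), while the new `wsp_ggml_interpolate` takes arbitrary 64-bit target sizes, and dropping the `a->ne[i] <= ne_i` asserts means downscaling is no longer rejected at graph-build time. A sketch (illustrative; the `WSP_GGML_SCALE_MODE_*` names assume the usual ggml scale-mode enum carries the WSP_ prefix here):

```c
// Old and new scaling entry points side by side.
#include "ggml.h"

void scaling_demo(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a) {
    // Integer-factor helper, unchanged interface.
    struct wsp_ggml_tensor * up2 = wsp_ggml_upscale(ctx, a, 2, WSP_GGML_SCALE_MODE_NEAREST);

    // New entry point: arbitrary per-dimension target sizes, including
    // shrinking, which the removed asserts used to forbid.
    struct wsp_ggml_tensor * half = wsp_ggml_interpolate(ctx, a,
        a->ne[0]/2, a->ne[1]/2, a->ne[2], a->ne[3], WSP_GGML_SCALE_MODE_BILINEAR);
    (void) up2; (void) half;
}
```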
@@ -5815,19 +6024,32 @@ static void wsp_ggml_compute_backward(
     WSP_GGML_ASSERT(!src2_needs_grads || wsp_ggml_are_same_shape(src2, cgraph->grads[isrc2]));
 }
 
-static void wsp_ggml_visit_parents(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * node) {
+static size_t wsp_ggml_visit_parents(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * node) {
     // check if already visited
-    if (wsp_ggml_hash_insert(&cgraph->visited_hash_set, node) == WSP_GGML_HASHSET_ALREADY_EXISTS) {
-        return;
+    size_t node_hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
+    WSP_GGML_ASSERT(node_hash_pos != WSP_GGML_HASHSET_FULL);
+    if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
+        // This is the first time we see this node in the current graph.
+        cgraph->visited_hash_set.keys[node_hash_pos] = node;
+        wsp_ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
+        cgraph->use_counts[node_hash_pos] = 0;
+    } else {
+        // already visited
+        return node_hash_pos;
     }
 
     for (int i = 0; i < WSP_GGML_MAX_SRC; ++i) {
         const int k =
             (cgraph->order == WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
             (cgraph->order == WSP_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (WSP_GGML_MAX_SRC-1-i) :
-            /* unknown order, just fall back to using i*/ i;
-        if (node->src[k]) {
-            wsp_ggml_visit_parents(cgraph, node->src[k]);
+            /* unknown order, just fall back to using i */ i;
+
+        struct wsp_ggml_tensor * src = node->src[k];
+        if (src) {
+            size_t src_hash_pos = wsp_ggml_visit_parents(cgraph, src);
+
+            // Update the use count for this operand.
+            cgraph->use_counts[src_hash_pos]++;
         }
     }
 
@@ -5851,6 +6073,8 @@ static void wsp_ggml_visit_parents(struct wsp_ggml_cgraph * cgraph, struct wsp_g
         cgraph->nodes[cgraph->n_nodes] = node;
         cgraph->n_nodes++;
     }
+
+    return node_hash_pos;
 }
 
 static void wsp_ggml_build_forward_impl(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor, bool expand) {
@@ -5988,6 +6212,7 @@ static size_t wsp_ggml_graph_nbytes(size_t size, bool grads) {
     incr_ptr_aligned(&p, sizeof(struct wsp_ggml_cgraph), 1);
     incr_ptr_aligned(&p, size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)); // nodes
     incr_ptr_aligned(&p, size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)); // leafs
+    incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
     incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)); // hash keys
     if (grads) {
         incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)); // grads
@@ -6017,11 +6242,12 @@ struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom(struct wsp_ggml_context * ctx
 
     void * p = cgraph + 1;
 
-    struct wsp_ggml_tensor ** nodes_ptr …
-    struct wsp_ggml_tensor ** leafs_ptr …
-
-    struct wsp_ggml_tensor ** …
-    struct wsp_ggml_tensor ** …
+    struct wsp_ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *));
+    struct wsp_ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *));
+    int32_t * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
+    struct wsp_ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *));
+    struct wsp_ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)) : NULL;
+    struct wsp_ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct wsp_ggml_tensor *), sizeof(struct wsp_ggml_tensor *)) : NULL;
 
     wsp_ggml_bitset_t * hash_used = incr_ptr_aligned(&p, wsp_ggml_bitset_size(hash_size) * sizeof(wsp_ggml_bitset_t), sizeof(wsp_ggml_bitset_t));
 
@@ -6036,6 +6262,7 @@ struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom(struct wsp_ggml_context * ctx
         /*.grads      =*/ grads_ptr,
         /*.grad_accs  =*/ grad_accs_ptr,
         /*.leafs      =*/ leafs_ptr,
+        /*.use_counts =*/ use_counts_ptr,
         /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
        /*.order      =*/ WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
     };
@@ -6062,7 +6289,8 @@ struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph0, int
         /*.grads            =*/ NULL, // gradients would need visited_hash_set
         /*.grad_accs        =*/ NULL,
         /*.leafs            =*/ NULL,
-        /*.visited_hash_set =*/ { 0, NULL, NULL },
+        /*.use_counts       =*/ cgraph0->use_counts,
+        /*.visited_hash_set =*/ cgraph0->visited_hash_set,
         /*.order            =*/ cgraph0->order,
     };
 
@@ -6089,7 +6317,8 @@ void wsp_ggml_graph_cpy(struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * d
     for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
         // copy all hashset keys (tensors) that are in use
         if (wsp_ggml_bitset_get(src->visited_hash_set.used, i)) {
-            wsp_ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+            size_t new_hash_pos = wsp_ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+            dst->use_counts[new_hash_pos] = src->use_counts[i];
         }
     }
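Taken together, these hunks thread a per-node use count through graph construction: `wsp_ggml_visit_parents` now returns the node's slot in `visited_hash_set`, and each operand's slot is incremented once per consuming edge. A debugging-style sketch of that bookkeeping (illustrative only; it reads internal cgraph fields and the internal `wsp_ggml_hash_find` helper, so treat every name here as an assumption about the vendored headers rather than a public API):

```c
// After building the graph, an operand consumed twice should show count 2.
#include "ggml-impl.h"

void use_count_demo(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf,
                    struct wsp_ggml_tensor * a) {
    // `a` feeds the multiply twice, so it has two consuming edges.
    struct wsp_ggml_tensor * sq = wsp_ggml_mul(ctx, a, a);
    wsp_ggml_build_forward_expand(gf, sq);  // runs wsp_ggml_visit_parents

    size_t pos = wsp_ggml_hash_find(&gf->visited_hash_set, a);
    // use_counts is indexed by hash slot and incremented once per edge.
    WSP_GGML_ASSERT(gf->use_counts[pos] == 2);
}
```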