whisper.rn 0.5.0-rc.9 → 0.5.1
- package/android/build.gradle +2 -1
- package/android/gradle.properties +1 -1
- package/cpp/ggml-alloc.c +265 -141
- package/cpp/ggml-backend-impl.h +4 -1
- package/cpp/ggml-backend-reg.cpp +30 -13
- package/cpp/ggml-backend.cpp +221 -38
- package/cpp/ggml-backend.h +17 -1
- package/cpp/ggml-common.h +17 -0
- package/cpp/ggml-cpu/amx/amx.cpp +4 -2
- package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/ggml-cpu/arch-fallback.h +32 -2
- package/cpp/ggml-cpu/common.h +14 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +13 -6
- package/cpp/ggml-cpu/ggml-cpu.c +70 -42
- package/cpp/ggml-cpu/ggml-cpu.cpp +35 -28
- package/cpp/ggml-cpu/ops.cpp +1587 -1177
- package/cpp/ggml-cpu/ops.h +5 -8
- package/cpp/ggml-cpu/quants.c +35 -0
- package/cpp/ggml-cpu/quants.h +8 -0
- package/cpp/ggml-cpu/repack.cpp +458 -47
- package/cpp/ggml-cpu/repack.h +22 -0
- package/cpp/ggml-cpu/simd-mappings.h +89 -60
- package/cpp/ggml-cpu/traits.cpp +2 -2
- package/cpp/ggml-cpu/traits.h +1 -1
- package/cpp/ggml-cpu/vec.cpp +170 -26
- package/cpp/ggml-cpu/vec.h +506 -63
- package/cpp/ggml-cpu.h +1 -1
- package/cpp/ggml-impl.h +119 -9
- package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
- package/cpp/ggml-metal/ggml-metal-common.h +52 -0
- package/cpp/ggml-metal/ggml-metal-context.h +33 -0
- package/cpp/ggml-metal/ggml-metal-context.m +600 -0
- package/cpp/ggml-metal/ggml-metal-device.cpp +1376 -0
- package/cpp/ggml-metal/ggml-metal-device.h +226 -0
- package/cpp/ggml-metal/ggml-metal-device.m +1312 -0
- package/cpp/ggml-metal/ggml-metal-impl.h +722 -0
- package/cpp/ggml-metal/ggml-metal-ops.cpp +3158 -0
- package/cpp/ggml-metal/ggml-metal-ops.h +82 -0
- package/cpp/ggml-metal/ggml-metal.cpp +718 -0
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/cpp/ggml-metal-impl.h +90 -51
- package/cpp/ggml-metal.h +1 -6
- package/cpp/ggml-opt.cpp +97 -41
- package/cpp/ggml-opt.h +25 -6
- package/cpp/ggml-quants.c +111 -16
- package/cpp/ggml-quants.h +6 -0
- package/cpp/ggml.c +486 -98
- package/cpp/ggml.h +221 -16
- package/cpp/gguf.cpp +8 -1
- package/cpp/jsi/RNWhisperJSI.cpp +25 -6
- package/cpp/jsi/ThreadPool.h +3 -3
- package/cpp/whisper.cpp +100 -76
- package/cpp/whisper.h +1 -0
- package/ios/CMakeLists.txt +6 -1
- package/ios/RNWhisper.mm +6 -6
- package/ios/RNWhisperContext.mm +2 -0
- package/ios/RNWhisperVadContext.mm +16 -13
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +13 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +13 -0
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +6 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/realtime-transcription/RealtimeTranscriber.ts +17 -0
- package/src/realtime-transcription/types.ts +6 -0
- package/src/version.json +1 -1
- package/whisper-rn.podspec +8 -9
- package/cpp/ggml-metal.m +0 -6284
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
package/cpp/ggml.c CHANGED

@@ -1,6 +1,14 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows
 #define _USE_MATH_DEFINES // For M_PI on MSVC
 
+// GGML build info
+#ifndef WSP_GGML_VERSION
+#define WSP_GGML_VERSION "unknown"
+#endif
+#ifndef WSP_GGML_COMMIT
+#define WSP_GGML_COMMIT "unknown"
+#endif
+
 #include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-threading.h"
@@ -202,19 +210,34 @@ void wsp_ggml_print_backtrace(void) {
 }
 #endif
 
+static wsp_ggml_abort_callback_t g_abort_callback = NULL;
+
+// Set the abort callback (passing null will restore original abort functionality: printing a message to stdout)
+WSP_GGML_API wsp_ggml_abort_callback_t wsp_ggml_set_abort_callback(wsp_ggml_abort_callback_t callback) {
+    wsp_ggml_abort_callback_t ret_val = g_abort_callback;
+    g_abort_callback = callback;
+    return ret_val;
+}
+
 void wsp_ggml_abort(const char * file, int line, const char * fmt, ...) {
     fflush(stdout);
 
-    fprintf(stderr, "%s:%d: ", file, line);
+    char message[2048];
+    int offset = snprintf(message, sizeof(message), "%s:%d: ", file, line);
 
     va_list args;
     va_start(args, fmt);
-    vfprintf(stderr, fmt, args);
+    vsnprintf(message + offset, sizeof(message) - offset, fmt, args);
     va_end(args);
 
-    fprintf(stderr, "\n");
+    if (g_abort_callback) {
+        g_abort_callback(message);
+    } else {
+        // default: print error and backtrace to stderr
+        fprintf(stderr, "%s\n", message);
+        wsp_ggml_print_backtrace();
+    }
 
-    wsp_ggml_print_backtrace();
     abort();
 }
 
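Note: the hunk above makes fatal ggml errors interceptable — the formatted message is handed to a process-wide callback when one is registered, and the library still calls abort() afterwards. A minimal sketch of how a host app might hook it (the handler below is hypothetical, not part of the package):

```c
#include <stdio.h>

// Assumed handler: forward the message to the app's crash reporter/logger.
static void on_ggml_abort(const char * message) {
    fprintf(stderr, "[whisper.rn] fatal ggml error: %s\n", message);
}

static void install_abort_hook(void) {
    // the previously installed callback is returned (NULL if none)
    wsp_ggml_abort_callback_t prev = wsp_ggml_set_abort_callback(on_ggml_abort);
    (void) prev;
}
```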
@@ -458,6 +481,14 @@ bool wsp_ggml_guid_matches(wsp_ggml_guid_t guid_a, wsp_ggml_guid_t guid_b) {
     return memcmp(guid_a, guid_b, sizeof(wsp_ggml_guid)) == 0;
 }
 
+const char * wsp_ggml_version(void) {
+    return WSP_GGML_VERSION;
+}
+
+const char * wsp_ggml_commit(void) {
+    return WSP_GGML_COMMIT;
+}
+
 //
 // timing
 //
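The two getters above simply surface the build-info macros defined at the top of the file. A trivial, hypothetical caller (nothing below is part of the package):

```c
#include <stdio.h>

// Illustrative only: log which ggml build whisper.rn was compiled against.
static void log_ggml_build(void) {
    printf("ggml %s (commit %s)\n", wsp_ggml_version(), wsp_ggml_commit());
}
```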
@@ -559,9 +590,6 @@ FILE * wsp_ggml_fopen(const char * fname, const char * mode) {
 #endif
 
 }
-static void wsp_ggml_vec_dot_f32(int n, float * WSP_GGML_RESTRICT s, size_t bs, const float * WSP_GGML_RESTRICT x, size_t bx, const float * WSP_GGML_RESTRICT y, size_t by, int nrc);
-static void wsp_ggml_vec_dot_f16(int n, float * WSP_GGML_RESTRICT s, size_t bs, wsp_ggml_fp16_t * WSP_GGML_RESTRICT x, size_t bx, wsp_ggml_fp16_t * WSP_GGML_RESTRICT y, size_t by, int nrc);
-static void wsp_ggml_vec_dot_bf16(int n, float * WSP_GGML_RESTRICT s, size_t bs, wsp_ggml_bf16_t * WSP_GGML_RESTRICT x, size_t bx, wsp_ggml_bf16_t * WSP_GGML_RESTRICT y, size_t by, int nrc);
 
 static const struct wsp_ggml_type_traits type_traits[WSP_GGML_TYPE_COUNT] = {
     [WSP_GGML_TYPE_I8] = {
@@ -667,6 +695,14 @@ static const struct wsp_ggml_type_traits type_traits[WSP_GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .from_float_ref = (wsp_ggml_from_float_t) wsp_quantize_row_q8_1_ref,
     },
+    [WSP_GGML_TYPE_MXFP4] = {
+        .type_name = "mxfp4",
+        .blck_size = QK_MXFP4,
+        .type_size = sizeof(block_mxfp4),
+        .is_quantized = true,
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_mxfp4,
+        .from_float_ref = (wsp_ggml_from_float_t)wsp_quantize_row_mxfp4_ref,
+    },
     [WSP_GGML_TYPE_Q2_K] = {
         .type_name = "q2_K",
         .blck_size = QK_K,
@@ -894,6 +930,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
 
     "DUP",
     "ADD",
+    "ADD_ID",
     "ADD1",
     "ACC",
     "SUB",
@@ -945,7 +982,9 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "CONV_TRANSPOSE_1D",
     "IM2COL",
     "IM2COL_BACK",
+    "IM2COL_3D",
     "CONV_2D",
+    "CONV_3D",
     "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
@@ -983,17 +1022,19 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
+    "OPT_STEP_SGD",
 
     "GLU",
 };
 
-static_assert(WSP_GGML_OP_COUNT == 86, "WSP_GGML_OP_COUNT != 86");
+static_assert(WSP_GGML_OP_COUNT == 90, "WSP_GGML_OP_COUNT != 90");
 
 static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "none",
 
     "x",
     "x+y",
+    "x[i]+y",
     "x+y",
     "view(x,nb,offset)+=y->x",
     "x-y",
@@ -1045,7 +1086,9 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "conv_transpose_1d(x)",
     "im2col(x)",
     "im2col_back(x)",
+    "im2col_3d(x)",
     "conv_2d(x)",
+    "conv_3d(x)",
     "conv_2d_dw(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
@@ -1083,15 +1126,15 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
+    "sgd(x)",
 
     "glu(x)",
 };
 
-static_assert(WSP_GGML_OP_COUNT == 86, "WSP_GGML_OP_COUNT != 86");
+static_assert(WSP_GGML_OP_COUNT == 90, "WSP_GGML_OP_COUNT != 90");
 
 static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");
 
-
 static const char * WSP_GGML_UNARY_OP_NAME[WSP_GGML_UNARY_OP_COUNT] = {
     "ABS",
     "SGN",
@@ -1117,9 +1160,12 @@ static const char * WSP_GGML_GLU_OP_NAME[WSP_GGML_GLU_OP_COUNT] = {
     "REGLU",
     "GEGLU",
     "SWIGLU",
+    "SWIGLU_OAI",
+    "GEGLU_ERF",
+    "GEGLU_QUICK",
 };
 
-static_assert(WSP_GGML_GLU_OP_COUNT == 3, "WSP_GGML_GLU_OP_COUNT != 3");
+static_assert(WSP_GGML_GLU_OP_COUNT == 6, "WSP_GGML_GLU_OP_COUNT != 6");
 
 
 static_assert(sizeof(struct wsp_ggml_object)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_object size must be a multiple of WSP_GGML_MEM_ALIGN");
@@ -1287,6 +1333,7 @@ enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype) {
         case WSP_GGML_FTYPE_MOSTLY_Q5_0: wtype = WSP_GGML_TYPE_Q5_0; break;
         case WSP_GGML_FTYPE_MOSTLY_Q5_1: wtype = WSP_GGML_TYPE_Q5_1; break;
         case WSP_GGML_FTYPE_MOSTLY_Q8_0: wtype = WSP_GGML_TYPE_Q8_0; break;
+        case WSP_GGML_FTYPE_MOSTLY_MXFP4: wtype = WSP_GGML_TYPE_MXFP4; break;
         case WSP_GGML_FTYPE_MOSTLY_Q2_K: wtype = WSP_GGML_TYPE_Q2_K; break;
         case WSP_GGML_FTYPE_MOSTLY_Q3_K: wtype = WSP_GGML_TYPE_Q3_K; break;
         case WSP_GGML_FTYPE_MOSTLY_Q4_K: wtype = WSP_GGML_TYPE_Q4_K; break;
@@ -1937,6 +1984,27 @@ struct wsp_ggml_tensor * wsp_ggml_add_cast(
     return wsp_ggml_add_cast_impl(ctx, a, b, type);
 }
 
+struct wsp_ggml_tensor * wsp_ggml_add_id(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        struct wsp_ggml_tensor  * b,
+        struct wsp_ggml_tensor  * ids) {
+
+    WSP_GGML_ASSERT(a->ne[0] == b->ne[0]);
+    WSP_GGML_ASSERT(a->ne[1] == ids->ne[0]);
+    WSP_GGML_ASSERT(a->ne[2] == ids->ne[1]);
+    WSP_GGML_ASSERT(ids->type == WSP_GGML_TYPE_I32);
+
+    struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
+
+    result->op     = WSP_GGML_OP_ADD_ID;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = ids;
+
+    return result;
+}
+
 // wsp_ggml_add1
 
 static struct wsp_ggml_tensor * wsp_ggml_add1_impl(
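wsp_ggml_add_id adds rows of b selected by ids to the rows of a — the pattern used for per-expert biases in mixture-of-experts layers. A hedged graph-building sketch; every identifier below is an assumption, not taken from this diff:

```c
// Shapes implied by the asserts above:
//   a:   [n0, n1, n2]      e.g. [n_embd, n_tokens, n_expert_used]
//   b:   [n0, n_expert]    rows of per-expert bias
//   ids: [n1, n2] (I32)    row index into b for each (token, slot)
struct wsp_ggml_tensor * biased = wsp_ggml_add_id(ctx, cur, bias, ids);
```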
@@ -2745,6 +2813,61 @@ struct wsp_ggml_tensor * wsp_ggml_swiglu_split(
     return wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_SWIGLU, false);
 }
 
+// wsp_ggml_geglu_erf
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_erf(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_GEGLU_ERF, false);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_erf_swapped(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_GEGLU_ERF, true);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_erf_split(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        struct wsp_ggml_tensor  * b) {
+    return wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_GEGLU_ERF, false);
+}
+
+// wsp_ggml_geglu_quick
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_quick(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_GEGLU_QUICK, false);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_quick_swapped(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_GEGLU_QUICK, true);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_quick_split(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        struct wsp_ggml_tensor  * b) {
+    return wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_GEGLU_QUICK, false);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_swiglu_oai(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        struct wsp_ggml_tensor  * b,
+        float                     alpha,
+        float                     limit) {
+    struct wsp_ggml_tensor * result = wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_SWIGLU_OAI, false);
+    wsp_ggml_set_op_params_f32(result, 2, alpha);
+    wsp_ggml_set_op_params_f32(result, 3, limit);
+
+    return result;
+}
+
 // wsp_ggml_norm
 
 static struct wsp_ggml_tensor * wsp_ggml_norm_impl(
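These constructors all funnel into wsp_ggml_glu_impl, so they appear to follow the established GLU conventions: the single-tensor forms split a in half along ne[0] into value and gate, the _split forms take separate tensors, and _swapped flips the two halves. A sketch — the tensor names and the alpha/limit values are illustrative assumptions:

```c
struct wsp_ggml_tensor * y0 = wsp_ggml_geglu_quick(ctx, x);        // value/gate packed in x
struct wsp_ggml_tensor * y1 = wsp_ggml_geglu_erf_split(ctx, x, g); // separate value and gate
struct wsp_ggml_tensor * y2 = wsp_ggml_swiglu_oai(ctx, x, g, /*alpha=*/1.702f, /*limit=*/7.0f);
```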
@@ -3002,12 +3125,14 @@ static struct wsp_ggml_tensor * wsp_ggml_scale_impl(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor  * a,
         float                     s,
+        float                     b,
         bool                      inplace) {
     WSP_GGML_ASSERT(wsp_ggml_is_padded_1d(a));
 
     struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
 
-    wsp_ggml_set_op_params(result, &s, sizeof(s));
+    float params[2] = { s, b };
+    wsp_ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = WSP_GGML_OP_SCALE;
     result->src[0] = a;
@@ -3019,14 +3144,30 @@ struct wsp_ggml_tensor * wsp_ggml_scale(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor  * a,
         float                     s) {
-    return wsp_ggml_scale_impl(ctx, a, s, false);
+    return wsp_ggml_scale_impl(ctx, a, s, 0.0, false);
 }
 
 struct wsp_ggml_tensor * wsp_ggml_scale_inplace(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor  * a,
         float                     s) {
-    return wsp_ggml_scale_impl(ctx, a, s, true);
+    return wsp_ggml_scale_impl(ctx, a, s, 0.0, true);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_scale_bias(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        float                     s,
+        float                     b) {
+    return wsp_ggml_scale_impl(ctx, a, s, b, false);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_scale_bias_inplace(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        float                     s,
+        float                     b) {
+    return wsp_ggml_scale_impl(ctx, a, s, b, true);
 }
 
 // wsp_ggml_set
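With the extra operand, WSP_GGML_OP_SCALE now computes s*x + b elementwise; the old entry points keep their behavior by passing b = 0. Illustrative usage (x is an assumed tensor):

```c
struct wsp_ggml_tensor * y = wsp_ggml_scale_bias(ctx, x, 0.5f, 1.0f); // y = 0.5*x + 1.0
struct wsp_ggml_tensor * z = wsp_ggml_scale(ctx, x, 0.5f);            // unchanged: z = 0.5*x
```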
@@ -3490,6 +3631,7 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows(
         struct wsp_ggml_tensor  * a,
         struct wsp_ggml_tensor  * b) {
     WSP_GGML_ASSERT(a->ne[2] == b->ne[1]);
+    WSP_GGML_ASSERT(a->ne[3] == b->ne[2]);
     WSP_GGML_ASSERT(b->ne[3] == 1);
     WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32);
 
@@ -3543,7 +3685,7 @@ struct wsp_ggml_tensor * wsp_ggml_set_rows(
     WSP_GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
     WSP_GGML_ASSERT(c->ne[3] == 1);
     WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_F32);
-    WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_I64);
+    WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_I64 || c->type == WSP_GGML_TYPE_I32);
 
     WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(a));
     WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(b));
@@ -3553,6 +3695,7 @@ struct wsp_ggml_tensor * wsp_ggml_set_rows(
     result->op     = WSP_GGML_OP_SET_ROWS;
     result->src[0] = b;
     result->src[1] = c;
+    result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)
 
     return result;
 }
@@ -3651,9 +3794,10 @@ static struct wsp_ggml_tensor * wsp_ggml_soft_max_impl(
     if (mask) {
         WSP_GGML_ASSERT(mask->type == WSP_GGML_TYPE_F16 || mask->type == WSP_GGML_TYPE_F32);
         WSP_GGML_ASSERT(wsp_ggml_is_contiguous(mask));
-        WSP_GGML_ASSERT(wsp_ggml_is_matrix(mask));
         WSP_GGML_ASSERT(mask->ne[0] == a->ne[0]);
         WSP_GGML_ASSERT(mask->ne[1] >= a->ne[1]);
+        WSP_GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
+        WSP_GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
     }
 
     if (max_bias > 0.0f) {
@@ -3693,6 +3837,22 @@ struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
     return wsp_ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
 }
 
+void wsp_ggml_soft_max_add_sinks(
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * sinks) {
+    if (!sinks) {
+        a->src[2] = NULL;
+        return;
+    }
+
+    WSP_GGML_ASSERT(a->op == WSP_GGML_OP_SOFT_MAX);
+    WSP_GGML_ASSERT(a->src[2] == NULL);
+    WSP_GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
+    WSP_GGML_ASSERT(sinks->type == WSP_GGML_TYPE_F32);
+
+    a->src[2] = sinks;
+}
+
 // wsp_ggml_soft_max_ext_back
 
 static struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_back_impl(
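wsp_ggml_soft_max_add_sinks mutates an already-built soft_max node rather than constructing a new one: it stashes the sinks tensor in src[2], or clears it when passed NULL. A hedged sketch — kq, kq_mask, kq_scale and sinks are assumed names:

```c
// sinks must be F32 with one entry per head (a->src[0]->ne[2], per the asserts).
struct wsp_ggml_tensor * att = wsp_ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);
wsp_ggml_soft_max_add_sinks(att, sinks);
```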
@@ -3740,6 +3900,7 @@ static struct wsp_ggml_tensor * wsp_ggml_rope_impl(
         struct wsp_ggml_tensor  * b,
         struct wsp_ggml_tensor  * c,
         int                       n_dims,
+        int                       sections[WSP_GGML_MROPE_SECTIONS],
         int                       mode,
         int                       n_ctx_orig,
         float                     freq_base,
@@ -3753,15 +3914,19 @@ static struct wsp_ggml_tensor * wsp_ggml_rope_impl(
 
     WSP_GGML_ASSERT(wsp_ggml_is_vector(b));
     WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32);
-    WSP_GGML_ASSERT(a->ne[2] == b->ne[0]);
+
+    bool mrope_used = mode & WSP_GGML_ROPE_TYPE_MROPE;
+    if (mrope_used) {
+        WSP_GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
+    } else {
+        WSP_GGML_ASSERT(a->ne[2] == b->ne[0]);
+    }
 
     if (c) {
         WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_F32);
         WSP_GGML_ASSERT(c->ne[0] >= n_dims / 2);
     }
 
-    int sections[4] = {0, 0, 0, 0};
-
     struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
 
     int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
@@ -3771,7 +3936,11 @@ static struct wsp_ggml_tensor * wsp_ggml_rope_impl(
     memcpy(params + 8, &attn_factor, sizeof(float));
     memcpy(params + 9, &beta_fast, sizeof(float));
     memcpy(params + 10, &beta_slow, sizeof(float));
-
+    if (mrope_used && sections) {
+        memcpy(params + 11, sections, sizeof(int32_t) * WSP_GGML_MROPE_SECTIONS);
+    } else {
+        memset(params + 11, 0, sizeof(int32_t) * WSP_GGML_MROPE_SECTIONS);
+    }
     wsp_ggml_set_op_params(result, params, sizeof(params));
 
     result->op = WSP_GGML_OP_ROPE;
@@ -3789,7 +3958,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope(
         int                       n_dims,
         int                       mode) {
     return wsp_ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
+        ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
     );
 }
 
@@ -3799,7 +3968,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_multi(
         struct wsp_ggml_tensor  * b,
         struct wsp_ggml_tensor  * c,
         int                       n_dims,
-        int                       sections[4],
+        int                       sections[WSP_GGML_MROPE_SECTIONS],
         int                       mode,
         int                       n_ctx_orig,
         float                     freq_base,
@@ -3808,36 +3977,31 @@ struct wsp_ggml_tensor * wsp_ggml_rope_multi(
         float                     attn_factor,
         float                     beta_fast,
         float                     beta_slow) {
-    // Multimodal Rotary Position Embedding
-    WSP_GGML_ASSERT((mode & WSP_GGML_ROPE_TYPE_MROPE) == WSP_GGML_ROPE_TYPE_MROPE);
-
-    WSP_GGML_ASSERT(wsp_ggml_is_vector(b));
-    WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32);
-    WSP_GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
-
-    if (c) {
-        WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_F32);
-        WSP_GGML_ASSERT(c->ne[0] >= n_dims / 2);
-    }
-
-    struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
-
-    int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
-    memcpy(params + 5, &freq_base, sizeof(float));
-    memcpy(params + 6, &freq_scale, sizeof(float));
-    memcpy(params + 7, &ext_factor, sizeof(float));
-    memcpy(params + 8, &attn_factor, sizeof(float));
-    memcpy(params + 9, &beta_fast, sizeof(float));
-    memcpy(params + 10, &beta_slow, sizeof(float));
-    memcpy(&params[11], sections, sizeof(int)*4);
-    wsp_ggml_set_op_params(result, params, sizeof(params));
-
-    result->op = WSP_GGML_OP_ROPE;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
+    return wsp_ggml_rope_impl(
+        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, false
+    );
+}
 
-    return result;
+struct wsp_ggml_tensor * wsp_ggml_rope_multi_inplace(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        struct wsp_ggml_tensor  * b,
+        struct wsp_ggml_tensor  * c,
+        int                       n_dims,
+        int                       sections[WSP_GGML_MROPE_SECTIONS],
+        int                       mode,
+        int                       n_ctx_orig,
+        float                     freq_base,
+        float                     freq_scale,
+        float                     ext_factor,
+        float                     attn_factor,
+        float                     beta_fast,
+        float                     beta_slow) {
+    return wsp_ggml_rope_impl(
+        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, true
+    );
 }
 
 struct wsp_ggml_tensor * wsp_ggml_rope_inplace(
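wsp_ggml_rope_multi is now a thin wrapper over wsp_ggml_rope_impl and gains an _inplace variant. A hypothetical M-RoPE call; the section split and frequency parameters below are placeholders, not values from this diff:

```c
// pos must carry 4 position ids per token when WSP_GGML_ROPE_TYPE_MROPE is set
// (see the assert in wsp_ggml_rope_impl above).
int sections[WSP_GGML_MROPE_SECTIONS] = { 16, 24, 24, 0 };
struct wsp_ggml_tensor * cur = wsp_ggml_rope_multi(
        ctx, q, pos, NULL, n_rot, sections, WSP_GGML_ROPE_TYPE_MROPE,
        n_ctx_orig, 10000.0f, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f);
```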
@@ -3847,7 +4011,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_inplace(
         int                       n_dims,
         int                       mode) {
     return wsp_ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
+        ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
     );
 }
 
@@ -3866,7 +4030,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_ext(
         float                     beta_fast,
         float                     beta_slow) {
     return wsp_ggml_rope_impl(
-        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, false
     );
 }
@@ -3886,7 +4050,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_ext_inplace(
         float                     beta_fast,
         float                     beta_slow) {
     return wsp_ggml_rope_impl(
-        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, true
     );
 }
@@ -3905,7 +4069,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_custom(
         float                     beta_fast,
         float                     beta_slow) {
     return wsp_ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, false
     );
 }
@@ -3924,7 +4088,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
         float                     beta_fast,
         float                     beta_slow) {
     return wsp_ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, true
     );
 }
@@ -4122,14 +4286,13 @@ struct wsp_ggml_tensor * wsp_ggml_conv_1d_dw(
         int                       s0,
         int                       p0,
         int                       d0) {
-    struct wsp_ggml_tensor * new_a = wsp_ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
     struct wsp_ggml_tensor * new_b = wsp_ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
 
-    struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx,
+    struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, new_b, s0, 0, p0, 0, d0, 0, false, WSP_GGML_TYPE_F16);
 
     struct wsp_ggml_tensor * result = wsp_ggml_mul_mat(ctx, im2col, a);
 
-    result = wsp_ggml_reshape_3d(ctx, result,
+    result = wsp_ggml_reshape_3d(ctx, result, result->ne[0], result->ne[2], 1);
 
     return result;
 }
@@ -4210,6 +4373,91 @@ struct wsp_ggml_tensor * wsp_ggml_conv_2d(
     return result;
 }
 
+// a: [OC*IC, KD, KH, KW]
+// b: [N*IC, ID, IH, IW]
+// result: [N*OD, OH, OW, IC * KD * KH * KW]
+struct wsp_ggml_tensor * wsp_ggml_im2col_3d(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        struct wsp_ggml_tensor  * b,
+        int64_t                   IC,
+        int                       s0, // stride width
+        int                       s1, // stride height
+        int                       s2, // stride depth
+        int                       p0, // padding width
+        int                       p1, // padding height
+        int                       p2, // padding depth
+        int                       d0, // dilation width
+        int                       d1, // dilation height
+        int                       d2, // dilation depth
+        enum wsp_ggml_type        dst_type) {
+    const int64_t N  = b->ne[3] / IC;
+    const int64_t ID = b->ne[2];
+    const int64_t IH = b->ne[1];
+    const int64_t IW = b->ne[0];
+
+    const int64_t OC = a->ne[3] / IC;
+    UNUSED(OC);
+    const int64_t KD = a->ne[2];
+    const int64_t KH = a->ne[1];
+    const int64_t KW = a->ne[0];
+    const int64_t OD = wsp_ggml_calc_conv_output_size(ID, KD, s2, p2, d2);
+    const int64_t OH = wsp_ggml_calc_conv_output_size(IH, KH, s1, p1, d1);
+    const int64_t OW = wsp_ggml_calc_conv_output_size(IW, KW, s0, p0, d0);
+
+    WSP_GGML_ASSERT((OD > 0) && "b too small compared to a");
+    WSP_GGML_ASSERT((OH > 0) && "b too small compared to a");
+    WSP_GGML_ASSERT((OW > 0) && "b too small compared to a");
+
+
+    const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N};
+
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, dst_type, 4, ne);
+    int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC};
+    wsp_ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = WSP_GGML_OP_IM2COL_3D;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// a: [OC*IC, KD, KH, KW]
+// b: [N*IC, ID, IH, IW]
+// result: [N*OC, OD, OH, OW]
+struct wsp_ggml_tensor * wsp_ggml_conv_3d(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        struct wsp_ggml_tensor  * b,
+        int64_t                   IC,
+        int                       s0, // stride width
+        int                       s1, // stride height
+        int                       s2, // stride depth
+        int                       p0, // padding width
+        int                       p1, // padding height
+        int                       p2, // padding depth
+        int                       d0, // dilation width
+        int                       d1, // dilation height
+        int                       d2  // dilation depth
+        ) {
+    struct wsp_ggml_tensor * im2col = wsp_ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW]
+
+    int64_t OC = a->ne[3] / IC;
+    int64_t N  = b->ne[3] / IC;
+    struct wsp_ggml_tensor * result =
+        wsp_ggml_mul_mat(ctx,
+                wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW]
+                wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC)); // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW]
+
+    int64_t OD = im2col->ne[3] / N;
+    result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW]
+    result = wsp_ggml_cont(ctx, wsp_ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW]
+    result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW]
+
+    return result;
+}
+
 // wsp_ggml_conv_2d_sk_p0
 
 struct wsp_ggml_tensor * wsp_ggml_conv_2d_sk_p0(
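wsp_ggml_conv_3d builds on the new im2col_3d op plus a mat-mul, with N and OC folded into ne[3] because ggml tensors have only four dims. Each spatial extent follows the usual dilated-convolution formula O = (I + 2*p - d*(k-1) - 1)/s + 1. A hedged shape sketch — all sizes below are made up for illustration:

```c
// IC = 4, OC = 8, N = 1, 3x3x3 kernel, stride 1, pad 1, dilation 1:
// O = (I + 2*1 - 1*(3-1) - 1)/1 + 1 = I, so spatial extents are preserved.
//   a: [3, 3, 3, 8*4]   b: [IW, IH, ID, 1*4]   ->   out: [IW, IH, ID, 1*8]
struct wsp_ggml_tensor * out = wsp_ggml_conv_3d(ctx, a, b, /*IC=*/4,
        /*s0,s1,s2=*/1, 1, 1, /*p0,p1,p2=*/1, 1, 1, /*d0,d1,d2=*/1, 1, 1);
```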
@@ -4331,6 +4579,56 @@ struct wsp_ggml_tensor * wsp_ggml_conv_2d_direct(
     return result;
 }
 
+// wsp_ggml_conv_3d_direct
+
+struct wsp_ggml_tensor * wsp_ggml_conv_3d_direct(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        struct wsp_ggml_tensor  * b,
+        int                       s0,
+        int                       s1,
+        int                       s2,
+        int                       p0,
+        int                       p1,
+        int                       p2,
+        int                       d0,
+        int                       d1,
+        int                       d2,
+        int                       c,
+        int                       n,
+        int                       oc) {
+
+    WSP_GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
+    WSP_GGML_ASSERT(b->ne[3] == (int64_t) c * n);
+
+    int64_t ne[4];
+    ne[0] = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+    ne[1] = wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
+    ne[2] = wsp_ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
+    ne[3] = (int64_t) oc * n;
+
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
+
+    wsp_ggml_set_op_params_i32(result, 0, s0);
+    wsp_ggml_set_op_params_i32(result, 1, s1);
+    wsp_ggml_set_op_params_i32(result, 2, s2);
+    wsp_ggml_set_op_params_i32(result, 3, p0);
+    wsp_ggml_set_op_params_i32(result, 4, p1);
+    wsp_ggml_set_op_params_i32(result, 5, p2);
+    wsp_ggml_set_op_params_i32(result, 6, d0);
+    wsp_ggml_set_op_params_i32(result, 7, d1);
+    wsp_ggml_set_op_params_i32(result, 8, d2);
+    wsp_ggml_set_op_params_i32(result, 9, c);
+    wsp_ggml_set_op_params_i32(result, 10, n);
+    wsp_ggml_set_op_params_i32(result, 11, oc);
+
+    result->op     = WSP_GGML_OP_CONV_3D;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // wsp_ggml_conv_transpose_2d_p0
 
 static int64_t wsp_ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
@@ -4509,11 +4807,36 @@ struct wsp_ggml_tensor * wsp_ggml_pad(
         int                       p1,
         int                       p2,
         int                       p3) {
+    return wsp_ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_pad_ext(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        int                       lp0,
+        int                       rp0,
+        int                       lp1,
+        int                       rp1,
+        int                       lp2,
+        int                       rp2,
+        int                       lp3,
+        int                       rp3
+        ) {
     struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, a->type,
-            a->ne[0] + p0,
-            a->ne[1] + p1,
-            a->ne[2] + p2,
-            a->ne[3] + p3);
+            a->ne[0] + lp0 + rp0,
+            a->ne[1] + lp1 + rp1,
+            a->ne[2] + lp2 + rp2,
+            a->ne[3] + lp3 + rp3);
+
+    wsp_ggml_set_op_params_i32(result, 0, lp0);
+    wsp_ggml_set_op_params_i32(result, 1, rp0);
+    wsp_ggml_set_op_params_i32(result, 2, lp1);
+    wsp_ggml_set_op_params_i32(result, 3, rp1);
+    wsp_ggml_set_op_params_i32(result, 4, lp2);
+    wsp_ggml_set_op_params_i32(result, 5, rp2);
+    wsp_ggml_set_op_params_i32(result, 6, lp3);
+    wsp_ggml_set_op_params_i32(result, 7, rp3);
 
     result->op = WSP_GGML_OP_PAD;
     result->src[0] = a;
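wsp_ggml_pad_ext generalizes padding to independent left/right amounts per dimension, and wsp_ggml_pad becomes the right-only special case. Illustrative call (a is an assumed tensor):

```c
// One element of left padding and two of right padding on dim 0 only.
struct wsp_ggml_tensor * padded = wsp_ggml_pad_ext(ctx, a, 1, 2, 0, 0, 0, 0, 0, 0);
```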
@@ -4609,12 +4932,8 @@ struct wsp_ggml_tensor * wsp_ggml_timestep_embedding(
        struct wsp_ggml_tensor  * timesteps,
         int                       dim,
         int                       max_period) {
-    int actual_dim = dim;
-    if (dim % 2 != 0) {
-        actual_dim = dim + 1;
-    }
 
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32,
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, dim, timesteps->ne[0]);
 
     wsp_ggml_set_op_params_i32(result, 0, dim);
     wsp_ggml_set_op_params_i32(result, 1, max_period);
@@ -4674,13 +4993,17 @@ struct wsp_ggml_tensor * wsp_ggml_flash_attn_ext(
     WSP_GGML_ASSERT(wsp_ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
+    WSP_GGML_ASSERT(q->ne[3] == k->ne[3]);
+    WSP_GGML_ASSERT(q->ne[3] == v->ne[3]);
+
     if (mask) {
         WSP_GGML_ASSERT(wsp_ggml_is_contiguous(mask));
-        WSP_GGML_ASSERT(mask->ne[2] == 1);
-        WSP_GGML_ASSERT(mask->ne[3] == 1);
         WSP_GGML_ASSERT(mask->ne[1] >= WSP_GGML_PAD(q->ne[1], WSP_GGML_KQ_MASK_PAD) &&
                 "the Flash-Attention kernel requires the mask to be padded to WSP_GGML_KQ_MASK_PAD and at least n_queries big");
         //WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(mask, qk));
+
+        WSP_GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
+        WSP_GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
     }
 
     if (max_bias > 0.0f) {
@@ -4722,6 +5045,22 @@ enum wsp_ggml_prec wsp_ggml_flash_attn_ext_get_prec(
     return (enum wsp_ggml_prec) prec_i32;
 }
 
+void wsp_ggml_flash_attn_ext_add_sinks(
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * sinks) {
+    if (!sinks) {
+        a->src[4] = NULL;
+        return;
+    }
+
+    WSP_GGML_ASSERT(a->op == WSP_GGML_OP_FLASH_ATTN_EXT);
+    WSP_GGML_ASSERT(a->src[4] == NULL);
+    WSP_GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
+    WSP_GGML_ASSERT(sinks->type == WSP_GGML_TYPE_F32);
+
+    a->src[4] = sinks;
+}
+
 // wsp_ggml_flash_attn_back
 
 struct wsp_ggml_tensor * wsp_ggml_flash_attn_back(
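Same pattern as the soft_max variant, with the sinks tensor stored in src[4] of an existing flash-attention node. A hedged sketch — the wsp_ggml_flash_attn_ext argument list shown here follows the upstream ggml API and the tensor names are assumptions:

```c
struct wsp_ggml_tensor * fa = wsp_ggml_flash_attn_ext(
        ctx, q, k, v, kq_mask, kq_scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f);
wsp_ggml_flash_attn_ext_add_sinks(fa, sinks);
```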
@@ -4808,7 +5147,6 @@ struct wsp_ggml_tensor * wsp_ggml_ssm_conv(
     const int64_t n_s = sx->ne[2];
 
     // TODO: maybe support other strides than 1?
-    // FIXME: this is always true?
     WSP_GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
     WSP_GGML_ASSERT(sx->ne[1] == d_inner);
     WSP_GGML_ASSERT(n_t >= 0);
@@ -4831,36 +5169,49 @@ struct wsp_ggml_tensor * wsp_ggml_ssm_scan(
         struct wsp_ggml_tensor  * dt,
         struct wsp_ggml_tensor  * A,
         struct wsp_ggml_tensor  * B,
-        struct wsp_ggml_tensor  * C) {
+        struct wsp_ggml_tensor  * C,
+        struct wsp_ggml_tensor  * ids) {
     WSP_GGML_ASSERT(wsp_ggml_is_contiguous(s));
-    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(x));
     WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dt));
     WSP_GGML_ASSERT(wsp_ggml_is_contiguous(A));
-    WSP_GGML_ASSERT(wsp_ggml_is_matrix(A));
-    WSP_GGML_ASSERT(wsp_ggml_is_3d(B));
-    WSP_GGML_ASSERT(wsp_ggml_is_3d(s));
+    WSP_GGML_ASSERT(x->nb[0] == wsp_ggml_type_size(x->type));
     WSP_GGML_ASSERT(B->nb[0] == wsp_ggml_type_size(B->type));
     WSP_GGML_ASSERT(C->nb[0] == wsp_ggml_type_size(C->type));
-    WSP_GGML_ASSERT(wsp_ggml_are_same_shape(x, dt));
+    WSP_GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
+    WSP_GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
+    WSP_GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
     WSP_GGML_ASSERT(wsp_ggml_are_same_shape(B, C));
+    WSP_GGML_ASSERT(ids->type == WSP_GGML_TYPE_I32);
 
     {
         const int64_t d_state = s->ne[0];
-        const int64_t d_inner = s->ne[1];
-        const int64_t n_seq_tokens = x->ne[1];
-        const int64_t n_seqs = x->ne[2];
-
-        WSP_GGML_ASSERT(s->ne[2] == n_seqs);
-        WSP_GGML_ASSERT(x->ne[0] == d_inner);
-        WSP_GGML_ASSERT(A->ne[0] == d_state);
-        WSP_GGML_ASSERT(A->ne[1] == d_inner);
+        const int64_t head_dim = x->ne[0];
+        const int64_t n_head = x->ne[1];
+        const int64_t n_seq_tokens = x->ne[2];
+        const int64_t n_seqs = x->ne[3];
+
+        WSP_GGML_ASSERT(dt->ne[0] == n_head);
+        WSP_GGML_ASSERT(dt->ne[1] == n_seq_tokens);
+        WSP_GGML_ASSERT(dt->ne[2] == n_seqs);
+        WSP_GGML_ASSERT(wsp_ggml_is_3d(dt));
+        WSP_GGML_ASSERT(s->ne[1] == head_dim);
+        WSP_GGML_ASSERT(s->ne[2] == n_head);
         WSP_GGML_ASSERT(B->ne[0] == d_state);
-        WSP_GGML_ASSERT(B->ne[1] == n_seq_tokens);
-        WSP_GGML_ASSERT(B->ne[2] == n_seqs);
+        WSP_GGML_ASSERT(B->ne[2] == n_seq_tokens);
+        WSP_GGML_ASSERT(B->ne[3] == n_seqs);
+        WSP_GGML_ASSERT(ids->ne[0] == n_seqs);
+        WSP_GGML_ASSERT(wsp_ggml_is_vector(ids));
+        WSP_GGML_ASSERT(A->ne[1] == n_head);
+        WSP_GGML_ASSERT(wsp_ggml_is_matrix(A));
+
+        if (A->ne[0] != 1) {
+            // Mamba-1 has more granular decay factors
+            WSP_GGML_ASSERT(A->ne[0] == d_state);
+        }
     }
 
     // concatenated y + ssm_states
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, wsp_ggml_nelements(x) + wsp_ggml_nelements(s));
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, wsp_ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
 
     result->op = WSP_GGML_OP_SSM_SCAN;
     result->src[0] = s;
@@ -4869,6 +5220,7 @@ struct wsp_ggml_tensor * wsp_ggml_ssm_scan(
     result->src[3] = A;
     result->src[4] = B;
     result->src[5] = C;
+    result->src[6] = ids;
 
     return result;
 }
@@ -5424,6 +5776,28 @@ struct wsp_ggml_tensor * wsp_ggml_opt_step_adamw(
     return result;
 }
 
+// opt_step_sgd
+
+struct wsp_ggml_tensor * wsp_ggml_opt_step_sgd(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor  * a,
+        struct wsp_ggml_tensor  * grad,
+        struct wsp_ggml_tensor  * params) {
+    WSP_GGML_ASSERT(a->flags & WSP_GGML_TENSOR_FLAG_PARAM);
+    WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, grad));
+    WSP_GGML_ASSERT(params->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(wsp_ggml_nelements(params) == 2);
+
+    struct wsp_ggml_tensor * result = wsp_ggml_view_tensor(ctx, a);
+
+    result->op     = WSP_GGML_OP_OPT_STEP_SGD;
+    result->src[0] = a;
+    result->src[1] = grad;
+    result->src[2] = params;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct wsp_ggml_hash_set wsp_ggml_hash_set_new(size_t size) {
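The SGD step mirrors wsp_ggml_opt_step_adamw but takes only a 2-element F32 params tensor, per the asserts above; in upstream ggml these two slots appear to hold the learning rate and weight decay. Hypothetical setup (weights and grad are assumed tensors):

```c
struct wsp_ggml_tensor * sgd_params = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 2);
struct wsp_ggml_tensor * step = wsp_ggml_opt_step_sgd(ctx, weights, grad, sgd_params);
```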
@@ -5692,7 +6066,7 @@ static void wsp_ggml_compute_backward(
         } break;
         case WSP_GGML_OP_MEAN: {
             if (src0_needs_grads) {
-                wsp_ggml_add1_or_set(ctx, cgraph, isrc0, wsp_ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
+                wsp_ggml_add1_or_set(ctx, cgraph, isrc0, wsp_ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
             }
         } break;
         case WSP_GGML_OP_REPEAT: {
@@ -5769,7 +6143,7 @@ static void wsp_ggml_compute_backward(
             if (src0_needs_grads) {
                 float s;
                 memcpy(&s, tensor->op_params, sizeof(float));
-                wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_scale_impl(ctx, grad, s, false));
+                wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_scale_impl(ctx, grad, s, 0.0, false));
             }
         } break;
         case WSP_GGML_OP_SET: {
@@ -6009,13 +6383,28 @@ static void wsp_ggml_compute_backward(
             }
             WSP_GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
         } break;
+        case WSP_GGML_OP_GLU: {
+            switch (wsp_ggml_get_glu_op(tensor)) {
+                case WSP_GGML_GLU_OP_SWIGLU: {
+                    if (src0_needs_grads) {
+                        WSP_GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
+                        wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_silu_back(ctx, wsp_ggml_mul(ctx, grad, src1), src0));
+                    }
+                    if (src1_needs_grads) {
+                        wsp_ggml_add_or_set(ctx, cgraph, isrc1, wsp_ggml_mul(ctx, wsp_ggml_silu(ctx, src0), grad));
+                    }
+                } break;
+                default: {
+                    WSP_GGML_ABORT("unsupported glu op for backward pass: %s", wsp_ggml_glu_op_name(wsp_ggml_get_glu_op(tensor)));
+                } //break;
+            }
+        } break;
         case WSP_GGML_OP_NONE: {
             // noop
         } break;
         case WSP_GGML_OP_COUNT:
         default: {
-
-            WSP_GGML_ABORT("fatal error");
+            WSP_GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, wsp_ggml_op_name(tensor->op));
         } //break;
     }
 
@@ -6522,20 +6911,18 @@ static struct wsp_ggml_tensor * wsp_ggml_graph_get_parent(const struct wsp_ggml_
 static void wsp_ggml_graph_dump_dot_node_edge(FILE * fp, const struct wsp_ggml_cgraph * gb, struct wsp_ggml_tensor * node, struct wsp_ggml_tensor * parent, const char * label) {
     struct wsp_ggml_tensor * gparent = wsp_ggml_graph_get_parent(gb, node);
     struct wsp_ggml_tensor * gparent0 = wsp_ggml_graph_get_parent(gb, parent);
-    fprintf(fp, " \"%p\"
+    fprintf(fp, " \"%p\" -> \"%p\" [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
             gparent0 ? (void *) gparent0 : (void *) parent,
-            gparent0 ? "g" : "x",
             gparent ? (void *) gparent : (void *) node,
-            gparent ? "g" : "x",
             gparent ? "empty" : "vee",
             gparent ? "dashed" : "solid",
             label);
 }
 
 static void wsp_ggml_graph_dump_dot_leaf_edge(FILE * fp, struct wsp_ggml_tensor * node, struct wsp_ggml_tensor * parent, const char * label) {
-    fprintf(fp, " \"%p\"
-            (void *) parent,
-            (void *) node,
+    fprintf(fp, " \"%p\" -> \"%p\" [ label = \"%s\"; ]\n",
+            (void *) parent,
+            (void *) node,
             label);
 }
 
@@ -6756,6 +7143,7 @@ size_t wsp_ggml_wsp_quantize_chunk(
         case WSP_GGML_TYPE_Q5_0: result = wsp_quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case WSP_GGML_TYPE_Q5_1: result = wsp_quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case WSP_GGML_TYPE_Q8_0: result = wsp_quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case WSP_GGML_TYPE_MXFP4: result = wsp_quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case WSP_GGML_TYPE_Q2_K: result = wsp_quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case WSP_GGML_TYPE_Q3_K: result = wsp_quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case WSP_GGML_TYPE_Q4_K: result = wsp_quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
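The quantization dispatcher gains the matching MXFP4 branch. An illustrative call through the updated dispatch — buffer names and sizes are assumptions:

```c
// src_f32 holds nrows * n_per_row floats; dst_buf must be large enough for
// the MXFP4 row size. A NULL imatrix means no importance weighting.
size_t bytes = wsp_ggml_wsp_quantize_chunk(
        WSP_GGML_TYPE_MXFP4, src_f32, dst_buf,
        /*start=*/0, nrows, n_per_row, /*imatrix=*/NULL);
```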
|