whisper.rn 0.5.0-rc.8 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/ggml-alloc.c +1 -15
- package/cpp/ggml-backend-reg.cpp +17 -8
- package/cpp/ggml-backend.cpp +15 -22
- package/cpp/ggml-common.h +17 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/ggml-cpu/arch-fallback.h +34 -0
- package/cpp/ggml-cpu/ggml-cpu.c +22 -1
- package/cpp/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/ggml-cpu/ops.cpp +870 -211
- package/cpp/ggml-cpu/ops.h +3 -8
- package/cpp/ggml-cpu/quants.c +35 -0
- package/cpp/ggml-cpu/quants.h +8 -0
- package/cpp/ggml-cpu/repack.cpp +458 -47
- package/cpp/ggml-cpu/repack.h +22 -0
- package/cpp/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/ggml-cpu/traits.cpp +2 -2
- package/cpp/ggml-cpu/traits.h +1 -1
- package/cpp/ggml-cpu/vec.cpp +12 -9
- package/cpp/ggml-cpu/vec.h +107 -13
- package/cpp/ggml-impl.h +77 -0
- package/cpp/ggml-metal-impl.h +51 -12
- package/cpp/ggml-metal.m +610 -115
- package/cpp/ggml-opt.cpp +97 -41
- package/cpp/ggml-opt.h +25 -6
- package/cpp/ggml-quants.c +110 -16
- package/cpp/ggml-quants.h +6 -0
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +314 -88
- package/cpp/ggml.h +137 -11
- package/cpp/gguf.cpp +8 -1
- package/cpp/jsi/RNWhisperJSI.cpp +23 -6
- package/cpp/whisper.cpp +15 -6
- package/ios/RNWhisper.mm +6 -6
- package/ios/RNWhisperContext.mm +2 -0
- package/ios/RNWhisperVadContext.mm +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +28 -2
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +28 -2
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +1 -0
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +6 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/realtime-transcription/RealtimeTranscriber.ts +32 -0
- package/src/realtime-transcription/types.ts +6 -0
package/cpp/ggml.c
CHANGED

@@ -1,6 +1,14 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows
 #define _USE_MATH_DEFINES // For M_PI on MSVC
 
+// GGML build info
+#ifndef WSP_GGML_VERSION
+#define WSP_GGML_VERSION "unknown"
+#endif
+#ifndef WSP_GGML_COMMIT
+#define WSP_GGML_COMMIT "unknown"
+#endif
+
 #include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-threading.h"

@@ -202,19 +210,34 @@ void wsp_ggml_print_backtrace(void) {
 }
 #endif
 
+static wsp_ggml_abort_callback_t g_abort_callback = NULL;
+
+// Set the abort callback (passing null will restore original abort functionality: printing a message to stdout)
+WSP_GGML_API wsp_ggml_abort_callback_t wsp_ggml_set_abort_callback(wsp_ggml_abort_callback_t callback) {
+    wsp_ggml_abort_callback_t ret_val = g_abort_callback;
+    g_abort_callback = callback;
+    return ret_val;
+}
+
 void wsp_ggml_abort(const char * file, int line, const char * fmt, ...) {
     fflush(stdout);
 
-    fprintf(stderr, "%s:%d: ", file, line);
+    char message[2048];
+    int offset = snprintf(message, sizeof(message), "%s:%d: ", file, line);
 
     va_list args;
     va_start(args, fmt);
-    vfprintf(stderr, fmt, args);
+    vsnprintf(message + offset, sizeof(message) - offset, fmt, args);
    va_end(args);
 
-    fprintf(stderr, "\n");
+    if (g_abort_callback) {
+        g_abort_callback(message);
+    } else {
+        // default: print error and backtrace to stderr
+        fprintf(stderr, "%s\n", message);
+        wsp_ggml_print_backtrace();
+    }
 
-    wsp_ggml_print_backtrace();
     abort();
 }
 
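Note: this gives a host application a hook to capture the formatted abort message before the process dies; as the hunk above shows, `abort()` is still called afterwards. A minimal usage sketch (the handler name is illustrative):

    static void my_abort_handler(const char * message) {
        // forward the message to the app's own logging before abort() runs
        fprintf(stderr, "[whisper.rn] %s\n", message);
    }

    // install; the previously installed callback is returned so it can be restored
    wsp_ggml_abort_callback_t prev = wsp_ggml_set_abort_callback(my_abort_handler);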
@@ -458,6 +481,14 @@ bool wsp_ggml_guid_matches(wsp_ggml_guid_t guid_a, wsp_ggml_guid_t guid_b) {
     return memcmp(guid_a, guid_b, sizeof(wsp_ggml_guid)) == 0;
 }
 
+const char * wsp_ggml_version(void) {
+    return WSP_GGML_VERSION;
+}
+
+const char * wsp_ggml_commit(void) {
+    return WSP_GGML_COMMIT;
+}
+
 //
 // timing
 //
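Note: the two accessors simply surface the fallback macros defined at the top of the file, so both report "unknown" unless WSP_GGML_VERSION / WSP_GGML_COMMIT are injected at build time. Usage:

    printf("ggml version %s (commit %s)\n", wsp_ggml_version(), wsp_ggml_commit());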
@@ -559,9 +590,6 @@ FILE * wsp_ggml_fopen(const char * fname, const char * mode) {
 #endif
 
 }
-static void wsp_ggml_vec_dot_f32(int n, float * WSP_GGML_RESTRICT s, size_t bs, const float * WSP_GGML_RESTRICT x, size_t bx, const float * WSP_GGML_RESTRICT y, size_t by, int nrc);
-static void wsp_ggml_vec_dot_f16(int n, float * WSP_GGML_RESTRICT s, size_t bs, wsp_ggml_fp16_t * WSP_GGML_RESTRICT x, size_t bx, wsp_ggml_fp16_t * WSP_GGML_RESTRICT y, size_t by, int nrc);
-static void wsp_ggml_vec_dot_bf16(int n, float * WSP_GGML_RESTRICT s, size_t bs, wsp_ggml_bf16_t * WSP_GGML_RESTRICT x, size_t bx, wsp_ggml_bf16_t * WSP_GGML_RESTRICT y, size_t by, int nrc);
 
 static const struct wsp_ggml_type_traits type_traits[WSP_GGML_TYPE_COUNT] = {
     [WSP_GGML_TYPE_I8] = {

@@ -667,6 +695,14 @@ static const struct wsp_ggml_type_traits type_traits[WSP_GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .from_float_ref = (wsp_ggml_from_float_t) wsp_quantize_row_q8_1_ref,
     },
+    [WSP_GGML_TYPE_MXFP4] = {
+        .type_name = "mxfp4",
+        .blck_size = QK_MXFP4,
+        .type_size = sizeof(block_mxfp4),
+        .is_quantized = true,
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_mxfp4,
+        .from_float_ref = (wsp_ggml_from_float_t)wsp_quantize_row_mxfp4_ref,
+    },
     [WSP_GGML_TYPE_Q2_K] = {
         .type_name = "q2_K",
         .blck_size = QK_K,

@@ -894,6 +930,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
 
     "DUP",
     "ADD",
+    "ADD_ID",
     "ADD1",
     "ACC",
     "SUB",

@@ -983,17 +1020,19 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
+    "OPT_STEP_SGD",
 
     "GLU",
 };
 
-static_assert(WSP_GGML_OP_COUNT == 86, "WSP_GGML_OP_COUNT != 86");
+static_assert(WSP_GGML_OP_COUNT == 88, "WSP_GGML_OP_COUNT != 88");
 
 static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "none",
 
     "x",
     "x+y",
+    "x[i]+y",
     "x+y",
     "view(x,nb,offset)+=y->x",
     "x-y",

@@ -1083,15 +1122,15 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
+    "sgd(x)",
 
     "glu(x)",
 };
 
-static_assert(WSP_GGML_OP_COUNT == 86, "WSP_GGML_OP_COUNT != 86");
+static_assert(WSP_GGML_OP_COUNT == 88, "WSP_GGML_OP_COUNT != 88");
 
 static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");
 
-
 static const char * WSP_GGML_UNARY_OP_NAME[WSP_GGML_UNARY_OP_COUNT] = {
     "ABS",
     "SGN",

@@ -1117,9 +1156,12 @@ static const char * WSP_GGML_GLU_OP_NAME[WSP_GGML_GLU_OP_COUNT] = {
     "REGLU",
     "GEGLU",
     "SWIGLU",
+    "SWIGLU_OAI",
+    "GEGLU_ERF",
+    "GEGLU_QUICK",
 };
 
-static_assert(WSP_GGML_GLU_OP_COUNT == 3, "WSP_GGML_GLU_OP_COUNT != 3");
+static_assert(WSP_GGML_GLU_OP_COUNT == 6, "WSP_GGML_GLU_OP_COUNT != 6");
 
 
 static_assert(sizeof(struct wsp_ggml_object)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_object size must be a multiple of WSP_GGML_MEM_ALIGN");

@@ -1287,6 +1329,7 @@ enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype) {
         case WSP_GGML_FTYPE_MOSTLY_Q5_0: wtype = WSP_GGML_TYPE_Q5_0; break;
         case WSP_GGML_FTYPE_MOSTLY_Q5_1: wtype = WSP_GGML_TYPE_Q5_1; break;
         case WSP_GGML_FTYPE_MOSTLY_Q8_0: wtype = WSP_GGML_TYPE_Q8_0; break;
+        case WSP_GGML_FTYPE_MOSTLY_MXFP4: wtype = WSP_GGML_TYPE_MXFP4; break;
         case WSP_GGML_FTYPE_MOSTLY_Q2_K: wtype = WSP_GGML_TYPE_Q2_K; break;
         case WSP_GGML_FTYPE_MOSTLY_Q3_K: wtype = WSP_GGML_TYPE_Q3_K; break;
         case WSP_GGML_FTYPE_MOSTLY_Q4_K: wtype = WSP_GGML_TYPE_Q4_K; break;

@@ -1937,6 +1980,27 @@ struct wsp_ggml_tensor * wsp_ggml_add_cast(
     return wsp_ggml_add_cast_impl(ctx, a, b, type);
 }
 
+struct wsp_ggml_tensor * wsp_ggml_add_id(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * b,
+        struct wsp_ggml_tensor * ids) {
+
+    WSP_GGML_ASSERT(a->ne[0] == b->ne[0]);
+    WSP_GGML_ASSERT(a->ne[1] == ids->ne[0]);
+    WSP_GGML_ASSERT(a->ne[2] == ids->ne[1]);
+    WSP_GGML_ASSERT(ids->type == WSP_GGML_TYPE_I32);
+
+    struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
+
+    result->op = WSP_GGML_OP_ADD_ID;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = ids;
+
+    return result;
+}
+
 // wsp_ggml_add1
 
 static struct wsp_ggml_tensor * wsp_ggml_add1_impl(
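Note: `wsp_ggml_add_id` builds the new ADD_ID op: for each row of `a` it adds the row of `b` selected by the matching I32 entry in `ids` (in upstream ggml this serves per-expert biases in MoE layers). A sketch under those assumptions; the shape names are illustrative:

    // a:   [n_embd, n_rows, n_batch]   input activations
    // b:   [n_embd, n_experts]         bias rows to choose from
    // ids: [n_rows, n_batch] (I32)     which row of b to add to each row of a
    struct wsp_ggml_tensor * out = wsp_ggml_add_id(ctx, a, b, ids);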
@@ -2745,6 +2809,61 @@ struct wsp_ggml_tensor * wsp_ggml_swiglu_split(
     return wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_SWIGLU, false);
 }
 
+// wsp_ggml_geglu_erf
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_erf(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_GEGLU_ERF, false);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_erf_swapped(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_GEGLU_ERF, true);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_erf_split(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * b) {
+    return wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_GEGLU_ERF, false);
+}
+
+// wsp_ggml_geglu_quick
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_quick(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_GEGLU_QUICK, false);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_quick_swapped(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a) {
+    return wsp_ggml_glu_impl(ctx, a, NULL, WSP_GGML_GLU_OP_GEGLU_QUICK, true);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_geglu_quick_split(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * b) {
+    return wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_GEGLU_QUICK, false);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_swiglu_oai(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * b,
+        float alpha,
+        float limit) {
+    struct wsp_ggml_tensor * result = wsp_ggml_glu_impl(ctx, a, b, WSP_GGML_GLU_OP_SWIGLU_OAI, false);
+    wsp_ggml_set_op_params_f32(result, 2, alpha);
+    wsp_ggml_set_op_params_f32(result, 3, limit);
+
+    return result;
+}
+
 // wsp_ggml_norm
 
 static struct wsp_ggml_tensor * wsp_ggml_norm_impl(
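Note: three GLU variants join REGLU/GEGLU/SWIGLU; only `wsp_ggml_swiglu_oai` takes extra scalars, which travel in op-param slots 2 and 3. A sketch with illustrative values; `gate` and `up` are assumed halves of a split FFN projection:

    struct wsp_ggml_tensor * cur =
        wsp_ggml_swiglu_oai(ctx, gate, up, /*alpha=*/1.702f, /*limit=*/7.0f);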
@@ -3002,12 +3121,14 @@ static struct wsp_ggml_tensor * wsp_ggml_scale_impl(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a,
         float s,
+        float b,
         bool inplace) {
     WSP_GGML_ASSERT(wsp_ggml_is_padded_1d(a));
 
     struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
 
-    wsp_ggml_set_op_params(result, &s, sizeof(s));
+    float params[2] = { s, b };
+    wsp_ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = WSP_GGML_OP_SCALE;
     result->src[0] = a;

@@ -3019,14 +3140,30 @@ struct wsp_ggml_tensor * wsp_ggml_scale(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a,
         float s) {
-    return wsp_ggml_scale_impl(ctx, a, s, false);
+    return wsp_ggml_scale_impl(ctx, a, s, 0.0, false);
 }
 
 struct wsp_ggml_tensor * wsp_ggml_scale_inplace(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a,
         float s) {
-    return wsp_ggml_scale_impl(ctx, a, s, true);
+    return wsp_ggml_scale_impl(ctx, a, s, 0.0, true);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_scale_bias(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        float s,
+        float b) {
+    return wsp_ggml_scale_impl(ctx, a, s, b, false);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_scale_bias_inplace(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        float s,
+        float b) {
+    return wsp_ggml_scale_impl(ctx, a, s, b, true);
 }
 
 // wsp_ggml_set
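Note: SCALE now carries an optional bias, so `wsp_ggml_scale_bias(ctx, a, s, b)` computes `s*a + b` as a single op instead of a scale followed by an add; the old entry points keep their behavior by passing `b = 0.0`. Sketch:

    struct wsp_ggml_tensor * y = wsp_ggml_scale_bias(ctx, x, 0.5f, 1.0f); // y = 0.5*x + 1.0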
@@ -3651,9 +3788,10 @@ static struct wsp_ggml_tensor * wsp_ggml_soft_max_impl(
     if (mask) {
         WSP_GGML_ASSERT(mask->type == WSP_GGML_TYPE_F16 || mask->type == WSP_GGML_TYPE_F32);
         WSP_GGML_ASSERT(wsp_ggml_is_contiguous(mask));
-        WSP_GGML_ASSERT(wsp_ggml_is_matrix(mask));
         WSP_GGML_ASSERT(mask->ne[0] == a->ne[0]);
         WSP_GGML_ASSERT(mask->ne[1] >= a->ne[1]);
+        WSP_GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
+        WSP_GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
     }
 
     if (max_bias > 0.0f) {

@@ -3693,6 +3831,22 @@ struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
     return wsp_ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
 }
 
+void wsp_ggml_soft_max_add_sinks(
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * sinks) {
+    if (!sinks) {
+        a->src[2] = NULL;
+        return;
+    }
+
+    WSP_GGML_ASSERT(a->op == WSP_GGML_OP_SOFT_MAX);
+    WSP_GGML_ASSERT(a->src[2] == NULL);
+    WSP_GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
+    WSP_GGML_ASSERT(sinks->type == WSP_GGML_TYPE_F32);
+
+    a->src[2] = sinks;
+}
+
 // wsp_ggml_soft_max_ext_back
 
 static struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_back_impl(
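Note: attention sinks attach to an already-built SOFT_MAX node as `src[2]` rather than through a new operator argument; per the asserts, `sinks` must be F32 with one value per head. Sketch:

    struct wsp_ggml_tensor * probs = wsp_ggml_soft_max_ext(ctx, scores, mask, scale, max_bias);
    wsp_ggml_soft_max_add_sinks(probs, sinks); // sinks: F32, ne[0] == scores->ne[2]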
@@ -3740,6 +3894,7 @@ static struct wsp_ggml_tensor * wsp_ggml_rope_impl(
         struct wsp_ggml_tensor * b,
         struct wsp_ggml_tensor * c,
         int n_dims,
+        int sections[WSP_GGML_MROPE_SECTIONS],
         int mode,
         int n_ctx_orig,
         float freq_base,

@@ -3753,15 +3908,19 @@ static struct wsp_ggml_tensor * wsp_ggml_rope_impl(
 
     WSP_GGML_ASSERT(wsp_ggml_is_vector(b));
     WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32);
-    WSP_GGML_ASSERT(a->ne[2] == b->ne[0]);
+
+    bool mrope_used = mode & WSP_GGML_ROPE_TYPE_MROPE;
+    if (mrope_used) {
+        WSP_GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
+    } else {
+        WSP_GGML_ASSERT(a->ne[2] == b->ne[0]);
+    }
 
     if (c) {
         WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_F32);
         WSP_GGML_ASSERT(c->ne[0] >= n_dims / 2);
     }
 
-    int sections[4] = {0, 0, 0, 0};
-
     struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
 
     int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };

@@ -3771,7 +3930,11 @@ static struct wsp_ggml_tensor * wsp_ggml_rope_impl(
     memcpy(params + 8, &attn_factor, sizeof(float));
     memcpy(params + 9, &beta_fast, sizeof(float));
     memcpy(params + 10, &beta_slow, sizeof(float));
-    memcpy(&params[11], sections, sizeof(int)*4);
+    if (mrope_used) {
+        memcpy(params + 11, sections, sizeof(int32_t) * WSP_GGML_MROPE_SECTIONS);
+    } else {
+        memset(params + 11, 0, sizeof(int32_t) * WSP_GGML_MROPE_SECTIONS);
+    }
     wsp_ggml_set_op_params(result, params, sizeof(params));
 
     result->op = WSP_GGML_OP_ROPE;

@@ -3789,7 +3952,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope(
     int n_dims,
     int mode) {
     return wsp_ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
+        ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
     );
 }

@@ -3799,7 +3962,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_multi(
         struct wsp_ggml_tensor * b,
         struct wsp_ggml_tensor * c,
         int n_dims,
-        int sections[4],
+        int sections[WSP_GGML_MROPE_SECTIONS],
         int mode,
         int n_ctx_orig,
         float freq_base,

@@ -3808,36 +3971,31 @@ struct wsp_ggml_tensor * wsp_ggml_rope_multi(
         float attn_factor,
         float beta_fast,
         float beta_slow) {
-
-
-
-
-
-    WSP_GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
-
-    if (c) {
-        WSP_GGML_ASSERT(c->type == WSP_GGML_TYPE_F32);
-        WSP_GGML_ASSERT(c->ne[0] >= n_dims / 2);
-    }
-
-    struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
-
-    int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
-    memcpy(params + 5, &freq_base, sizeof(float));
-    memcpy(params + 6, &freq_scale, sizeof(float));
-    memcpy(params + 7, &ext_factor, sizeof(float));
-    memcpy(params + 8, &attn_factor, sizeof(float));
-    memcpy(params + 9, &beta_fast, sizeof(float));
-    memcpy(params + 10, &beta_slow, sizeof(float));
-    memcpy(&params[11], sections, sizeof(int)*4);
-    wsp_ggml_set_op_params(result, params, sizeof(params));
-
-    result->op = WSP_GGML_OP_ROPE;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
+    return wsp_ggml_rope_impl(
+        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, false
+    );
+}
 
-    return result;
+struct wsp_ggml_tensor * wsp_ggml_rope_multi_inplace(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * b,
+        struct wsp_ggml_tensor * c,
+        int n_dims,
+        int sections[WSP_GGML_MROPE_SECTIONS],
+        int mode,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow) {
+    return wsp_ggml_rope_impl(
+        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, true
+    );
 }
 
 struct wsp_ggml_tensor * wsp_ggml_rope_inplace(
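Note: the multimodal (M-RoPE) path now flows through `wsp_ggml_rope_impl` via the new `sections` parameter, and `wsp_ggml_rope_multi_inplace` is added. A sketch of the multi call; the section split values are illustrative:

    int sections[4] = { 16, 24, 24, 0 }; // per-dimension-group split of n_dims, illustrative
    cur = wsp_ggml_rope_multi(ctx, cur, pos, NULL, n_dims, sections,
                              WSP_GGML_ROPE_TYPE_MROPE, n_ctx_orig, freq_base, freq_scale,
                              ext_factor, attn_factor, beta_fast, beta_slow);
    // pos must hold 4 position ids per token (see the assert above)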
@@ -3847,7 +4005,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_inplace(
     int n_dims,
     int mode) {
     return wsp_ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
+        ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
     );
 }

@@ -3866,7 +4024,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_ext(
     float beta_fast,
     float beta_slow) {
     return wsp_ggml_rope_impl(
-        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, false
     );
 }

@@ -3886,7 +4044,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_ext_inplace(
     float beta_fast,
     float beta_slow) {
     return wsp_ggml_rope_impl(
-        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, true
     );
 }

@@ -3905,7 +4063,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_custom(
     float beta_fast,
     float beta_slow) {
     return wsp_ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, false
     );
 }

@@ -3924,7 +4082,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
     float beta_fast,
     float beta_slow) {
     return wsp_ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, true
     );
 }

@@ -4122,14 +4280,13 @@ struct wsp_ggml_tensor * wsp_ggml_conv_1d_dw(
     int s0,
     int p0,
     int d0) {
-    struct wsp_ggml_tensor * new_a = wsp_ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
     struct wsp_ggml_tensor * new_b = wsp_ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
 
-    struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx,
+    struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, new_b, s0, 0, p0, 0, d0, 0, false, WSP_GGML_TYPE_F16);
 
     struct wsp_ggml_tensor * result = wsp_ggml_mul_mat(ctx, im2col, a);
 
-    result = wsp_ggml_reshape_3d(ctx, result,
+    result = wsp_ggml_reshape_3d(ctx, result, result->ne[0], result->ne[2], 1);
 
     return result;
 }

@@ -4674,13 +4831,17 @@ struct wsp_ggml_tensor * wsp_ggml_flash_attn_ext(
     WSP_GGML_ASSERT(wsp_ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
+    WSP_GGML_ASSERT(q->ne[3] == k->ne[3]);
+    WSP_GGML_ASSERT(q->ne[3] == v->ne[3]);
+
     if (mask) {
         WSP_GGML_ASSERT(wsp_ggml_is_contiguous(mask));
-        WSP_GGML_ASSERT(mask->ne[2] == 1);
-        WSP_GGML_ASSERT(mask->ne[3] == 1);
         WSP_GGML_ASSERT(mask->ne[1] >= WSP_GGML_PAD(q->ne[1], WSP_GGML_KQ_MASK_PAD) &&
                 "the Flash-Attention kernel requires the mask to be padded to WSP_GGML_KQ_MASK_PAD and at least n_queries big");
         //WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(mask, qk));
+
+        WSP_GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
+        WSP_GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
     }
 
     if (max_bias > 0.0f) {

@@ -4722,6 +4883,22 @@ enum wsp_ggml_prec wsp_ggml_flash_attn_ext_get_prec(
     return (enum wsp_ggml_prec) prec_i32;
 }
 
+void wsp_ggml_flash_attn_ext_add_sinks(
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * sinks) {
+    if (!sinks) {
+        a->src[4] = NULL;
+        return;
+    }
+
+    WSP_GGML_ASSERT(a->op == WSP_GGML_OP_FLASH_ATTN_EXT);
+    WSP_GGML_ASSERT(a->src[4] == NULL);
+    WSP_GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
+    WSP_GGML_ASSERT(sinks->type == WSP_GGML_TYPE_F32);
+
+    a->src[4] = sinks;
+}
+
 // wsp_ggml_flash_attn_back
 
 struct wsp_ggml_tensor * wsp_ggml_flash_attn_back(
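Note: the same sink mechanism as soft_max, but on FLASH_ATTN_EXT nodes via `src[4]`. Sketch; the `wsp_ggml_flash_attn_ext` argument list is assumed from upstream ggml and is not shown in this hunk:

    cur = wsp_ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap);
    wsp_ggml_flash_attn_ext_add_sinks(cur, sinks); // sinks: F32, one value per head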
@@ -4808,7 +4985,6 @@ struct wsp_ggml_tensor * wsp_ggml_ssm_conv(
     const int64_t n_s = sx->ne[2];
 
     // TODO: maybe support other strides than 1?
-    // FIXME: this is always true?
     WSP_GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
     WSP_GGML_ASSERT(sx->ne[1] == d_inner);
     WSP_GGML_ASSERT(n_t >= 0);

@@ -4831,36 +5007,49 @@ struct wsp_ggml_tensor * wsp_ggml_ssm_scan(
         struct wsp_ggml_tensor * dt,
         struct wsp_ggml_tensor * A,
         struct wsp_ggml_tensor * B,
-        struct wsp_ggml_tensor * C
+        struct wsp_ggml_tensor * C,
+        struct wsp_ggml_tensor * ids) {
     WSP_GGML_ASSERT(wsp_ggml_is_contiguous(s));
-    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(x));
     WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dt));
     WSP_GGML_ASSERT(wsp_ggml_is_contiguous(A));
-    WSP_GGML_ASSERT(
-    WSP_GGML_ASSERT(wsp_ggml_is_3d(B));
-    WSP_GGML_ASSERT(wsp_ggml_is_3d(s));
+    WSP_GGML_ASSERT(x->nb[0] == wsp_ggml_type_size(x->type));
     WSP_GGML_ASSERT(B->nb[0] == wsp_ggml_type_size(B->type));
     WSP_GGML_ASSERT(C->nb[0] == wsp_ggml_type_size(C->type));
-    WSP_GGML_ASSERT(
+    WSP_GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
+    WSP_GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
+    WSP_GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
     WSP_GGML_ASSERT(wsp_ggml_are_same_shape(B, C));
+    WSP_GGML_ASSERT(ids->type == WSP_GGML_TYPE_I32);
 
     {
         const int64_t d_state = s->ne[0];
-        const int64_t
-        const int64_t
-        const int64_t
-
-
-        WSP_GGML_ASSERT(
-        WSP_GGML_ASSERT(
-        WSP_GGML_ASSERT(
+        const int64_t head_dim = x->ne[0];
+        const int64_t n_head = x->ne[1];
+        const int64_t n_seq_tokens = x->ne[2];
+        const int64_t n_seqs = x->ne[3];
+
+        WSP_GGML_ASSERT(dt->ne[0] == n_head);
+        WSP_GGML_ASSERT(dt->ne[1] == n_seq_tokens);
+        WSP_GGML_ASSERT(dt->ne[2] == n_seqs);
+        WSP_GGML_ASSERT(wsp_ggml_is_3d(dt));
+        WSP_GGML_ASSERT(s->ne[1] == head_dim);
+        WSP_GGML_ASSERT(s->ne[2] == n_head);
         WSP_GGML_ASSERT(B->ne[0] == d_state);
-        WSP_GGML_ASSERT(B->ne[
-        WSP_GGML_ASSERT(B->ne[
+        WSP_GGML_ASSERT(B->ne[2] == n_seq_tokens);
+        WSP_GGML_ASSERT(B->ne[3] == n_seqs);
+        WSP_GGML_ASSERT(ids->ne[0] == n_seqs);
+        WSP_GGML_ASSERT(wsp_ggml_is_vector(ids));
+        WSP_GGML_ASSERT(A->ne[1] == n_head);
+        WSP_GGML_ASSERT(wsp_ggml_is_matrix(A));
+
+        if (A->ne[0] != 1) {
+            // Mamba-1 has more granular decay factors
+            WSP_GGML_ASSERT(A->ne[0] == d_state);
+        }
     }
 
     // concatenated y + ssm_states
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, wsp_ggml_nelements(x) +
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, wsp_ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
 
     result->op = WSP_GGML_OP_SSM_SCAN;
     result->src[0] = s;

@@ -4869,6 +5058,7 @@ struct wsp_ggml_tensor * wsp_ggml_ssm_scan(
     result->src[3] = A;
     result->src[4] = B;
     result->src[5] = C;
+    result->src[6] = ids;
 
     return result;
 }
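Note: `wsp_ggml_ssm_scan` gains an `ids` source (src[6]) and now organizes its inputs around heads. The layout implied by the asserts above, as a shape sketch:

    // x:   [head_dim, n_head, n_seq_tokens, n_seqs]
    // dt:  [n_head, n_seq_tokens, n_seqs]
    // s:   [d_state, head_dim, n_head, n_states]
    // B/C: [d_state, ..., n_seq_tokens, n_seqs]   (same shape as each other)
    // ids: [n_seqs] (I32) -- selects the recurrent state slot for each sequence
    struct wsp_ggml_tensor * y_and_states = wsp_ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);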
@@ -5424,6 +5614,28 @@ struct wsp_ggml_tensor * wsp_ggml_opt_step_adamw(
     return result;
 }
 
+// opt_step_sgd
+
+struct wsp_ggml_tensor * wsp_ggml_opt_step_sgd(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * grad,
+        struct wsp_ggml_tensor * params) {
+    WSP_GGML_ASSERT(a->flags & WSP_GGML_TENSOR_FLAG_PARAM);
+    WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, grad));
+    WSP_GGML_ASSERT(params->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(wsp_ggml_nelements(params) == 2);
+
+    struct wsp_ggml_tensor * result = wsp_ggml_view_tensor(ctx, a);
+
+    result->op = WSP_GGML_OP_OPT_STEP_SGD;
+    result->src[0] = a;
+    result->src[1] = grad;
+    result->src[2] = params;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct wsp_ggml_hash_set wsp_ggml_hash_set_new(size_t size) {
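Note: the new SGD step mirrors `wsp_ggml_opt_step_adamw`: hyperparameters travel in a 2-element F32 tensor. In upstream ggml those two values are the learning rate and weight decay; treat that order as an assumption here. Sketch:

    struct wsp_ggml_tensor * sgd_params = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 2);
    // sgd_params data = { lr, wd } -- order assumed from upstream ggml
    struct wsp_ggml_tensor * step = wsp_ggml_opt_step_sgd(ctx, weights, grad, sgd_params);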
@@ -5692,7 +5904,7 @@ static void wsp_ggml_compute_backward(
         } break;
         case WSP_GGML_OP_MEAN: {
             if (src0_needs_grads) {
-                wsp_ggml_add1_or_set(ctx, cgraph, isrc0, wsp_ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
+                wsp_ggml_add1_or_set(ctx, cgraph, isrc0, wsp_ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
             }
         } break;
         case WSP_GGML_OP_REPEAT: {

@@ -5769,7 +5981,7 @@ static void wsp_ggml_compute_backward(
             if (src0_needs_grads) {
                 float s;
                 memcpy(&s, tensor->op_params, sizeof(float));
-                wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_scale_impl(ctx, grad, s, false));
+                wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_scale_impl(ctx, grad, s, 0.0, false));
             }
         } break;
         case WSP_GGML_OP_SET: {

@@ -6009,13 +6221,28 @@ static void wsp_ggml_compute_backward(
             }
             WSP_GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
         } break;
+        case WSP_GGML_OP_GLU: {
+            switch (wsp_ggml_get_glu_op(tensor)) {
+                case WSP_GGML_GLU_OP_SWIGLU: {
+                    if (src0_needs_grads) {
+                        WSP_GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
+                        wsp_ggml_add_or_set(ctx, cgraph, isrc0, wsp_ggml_silu_back(ctx, wsp_ggml_mul(ctx, grad, src1), src0));
+                    }
+                    if (src1_needs_grads) {
+                        wsp_ggml_add_or_set(ctx, cgraph, isrc1, wsp_ggml_mul(ctx, wsp_ggml_silu(ctx, src0), grad));
+                    }
+                } break;
+                default: {
+                    WSP_GGML_ABORT("unsupported glu op for backward pass: %s", wsp_ggml_glu_op_name(wsp_ggml_get_glu_op(tensor)));
+                } //break;
+            }
+        } break;
         case WSP_GGML_OP_NONE: {
             // noop
         } break;
         case WSP_GGML_OP_COUNT:
         default: {
-
-            WSP_GGML_ABORT("fatal error");
+            WSP_GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, wsp_ggml_op_name(tensor->op));
         } //break;
     }
 

@@ -6522,20 +6749,18 @@ static struct wsp_ggml_tensor * wsp_ggml_graph_get_parent(const struct wsp_ggml_
 static void wsp_ggml_graph_dump_dot_node_edge(FILE * fp, const struct wsp_ggml_cgraph * gb, struct wsp_ggml_tensor * node, struct wsp_ggml_tensor * parent, const char * label) {
     struct wsp_ggml_tensor * gparent = wsp_ggml_graph_get_parent(gb, node);
     struct wsp_ggml_tensor * gparent0 = wsp_ggml_graph_get_parent(gb, parent);
-    fprintf(fp, "  \"%p\"
+    fprintf(fp, "  \"%p\" -> \"%p\" [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
             gparent0 ? (void *) gparent0 : (void *) parent,
-            gparent0 ? "g" : "x",
             gparent ? (void *) gparent : (void *) node,
-            gparent ? "g" : "x",
             gparent ? "empty" : "vee",
             gparent ? "dashed" : "solid",
             label);
 }
 
 static void wsp_ggml_graph_dump_dot_leaf_edge(FILE * fp, struct wsp_ggml_tensor * node, struct wsp_ggml_tensor * parent, const char * label) {
-    fprintf(fp, "  \"%p\"
-            (void *) parent,
-            (void *) node,
+    fprintf(fp, "  \"%p\" -> \"%p\" [ label = \"%s\"; ]\n",
+            (void *) parent,
+            (void *) node,
             label);
 }
 

@@ -6756,6 +6981,7 @@ size_t wsp_ggml_wsp_quantize_chunk(
         case WSP_GGML_TYPE_Q5_0: result = wsp_quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case WSP_GGML_TYPE_Q5_1: result = wsp_quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case WSP_GGML_TYPE_Q8_0: result = wsp_quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case WSP_GGML_TYPE_MXFP4: result = wsp_quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case WSP_GGML_TYPE_Q2_K: result = wsp_quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case WSP_GGML_TYPE_Q3_K: result = wsp_quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case WSP_GGML_TYPE_Q4_K: result = wsp_quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;