whisper.rn 0.4.0-rc.9 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +74 -1
- package/android/build.gradle +12 -3
- package/android/src/main/CMakeLists.txt +43 -13
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +211 -0
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +64 -36
- package/android/src/main/java/com/rnwhisper/WhisperVadContext.java +157 -0
- package/android/src/main/jni.cpp +205 -0
- package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +26 -0
- package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +26 -0
- package/cpp/coreml/whisper-compat.h +10 -0
- package/cpp/coreml/whisper-compat.m +35 -0
- package/cpp/coreml/whisper-decoder-impl.h +27 -15
- package/cpp/coreml/whisper-decoder-impl.m +36 -10
- package/cpp/coreml/whisper-encoder-impl.h +21 -9
- package/cpp/coreml/whisper-encoder-impl.m +29 -3
- package/cpp/ggml-alloc.c +39 -37
- package/cpp/ggml-alloc.h +1 -1
- package/cpp/ggml-backend-impl.h +55 -27
- package/cpp/ggml-backend-reg.cpp +591 -0
- package/cpp/ggml-backend.cpp +336 -955
- package/cpp/ggml-backend.h +70 -42
- package/cpp/ggml-common.h +57 -49
- package/cpp/ggml-cpp.h +39 -0
- package/cpp/ggml-cpu/amx/amx.cpp +221 -0
- package/cpp/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
- package/cpp/ggml-cpu/amx/mmq.h +10 -0
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +2162 -0
- package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
- package/cpp/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/ggml-cpu/arch/x86/repack.cpp +3284 -0
- package/cpp/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/ggml-cpu/binary-ops.cpp +158 -0
- package/cpp/ggml-cpu/binary-ops.h +16 -0
- package/cpp/ggml-cpu/common.h +72 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +511 -0
- package/cpp/ggml-cpu/ggml-cpu.c +3473 -0
- package/cpp/ggml-cpu/ggml-cpu.cpp +671 -0
- package/cpp/ggml-cpu/ops.cpp +9085 -0
- package/cpp/ggml-cpu/ops.h +111 -0
- package/cpp/ggml-cpu/quants.c +1157 -0
- package/cpp/ggml-cpu/quants.h +89 -0
- package/cpp/ggml-cpu/repack.cpp +1570 -0
- package/cpp/ggml-cpu/repack.h +98 -0
- package/cpp/ggml-cpu/simd-mappings.h +1006 -0
- package/cpp/ggml-cpu/traits.cpp +36 -0
- package/cpp/ggml-cpu/traits.h +38 -0
- package/cpp/ggml-cpu/unary-ops.cpp +186 -0
- package/cpp/ggml-cpu/unary-ops.h +28 -0
- package/cpp/ggml-cpu/vec.cpp +321 -0
- package/cpp/ggml-cpu/vec.h +973 -0
- package/cpp/ggml-cpu.h +143 -0
- package/cpp/ggml-impl.h +417 -23
- package/cpp/ggml-metal-impl.h +622 -0
- package/cpp/ggml-metal.h +9 -9
- package/cpp/ggml-metal.m +3451 -1344
- package/cpp/ggml-opt.cpp +1037 -0
- package/cpp/ggml-opt.h +237 -0
- package/cpp/ggml-quants.c +296 -10818
- package/cpp/ggml-quants.h +78 -125
- package/cpp/ggml-threading.cpp +12 -0
- package/cpp/ggml-threading.h +14 -0
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +4633 -21450
- package/cpp/ggml.h +320 -661
- package/cpp/gguf.cpp +1347 -0
- package/cpp/gguf.h +202 -0
- package/cpp/rn-whisper.cpp +4 -11
- package/cpp/whisper-arch.h +197 -0
- package/cpp/whisper.cpp +2022 -495
- package/cpp/whisper.h +75 -18
- package/ios/CMakeLists.txt +95 -0
- package/ios/RNWhisper.h +5 -0
- package/ios/RNWhisper.mm +147 -0
- package/ios/RNWhisperAudioUtils.m +4 -0
- package/ios/RNWhisperContext.h +5 -0
- package/ios/RNWhisperContext.mm +22 -26
- package/ios/RNWhisperVadContext.h +29 -0
- package/ios/RNWhisperVadContext.mm +152 -0
- package/ios/rnwhisper.xcframework/Info.plist +74 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/jest/mock.js +24 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +111 -1
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +112 -0
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +35 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +39 -3
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +10 -6
- package/src/NativeRNWhisper.ts +48 -0
- package/src/index.ts +132 -1
- package/src/version.json +1 -1
- package/whisper-rn.podspec +11 -18
- package/cpp/README.md +0 -4
- package/cpp/ggml-aarch64.c +0 -3209
- package/cpp/ggml-aarch64.h +0 -39
- package/cpp/ggml-cpu-impl.h +0 -614
package/cpp/whisper.cpp
CHANGED
|
@@ -1,75 +1,52 @@
|
|
|
1
1
|
#include "whisper.h"
|
|
2
|
+
#include "whisper-arch.h"
|
|
3
|
+
|
|
4
|
+
#include "ggml.h"
|
|
5
|
+
#include "ggml-cpp.h"
|
|
6
|
+
#include "ggml-alloc.h"
|
|
7
|
+
#include "ggml-backend.h"
|
|
2
8
|
|
|
3
9
|
#ifdef WHISPER_USE_COREML
|
|
4
10
|
#include "coreml/whisper-encoder.h"
|
|
5
11
|
#endif
|
|
6
12
|
|
|
7
|
-
#ifdef WSP_GGML_USE_METAL
|
|
8
|
-
#include "ggml-metal.h"
|
|
9
|
-
#endif
|
|
10
|
-
|
|
11
|
-
#ifdef WSP_GGML_USE_CUDA
|
|
12
|
-
#include "ggml-cuda.h"
|
|
13
|
-
#endif
|
|
14
|
-
|
|
15
|
-
#ifdef WSP_GGML_USE_SYCL
|
|
16
|
-
#include "ggml-sycl.h"
|
|
17
|
-
#endif
|
|
18
|
-
|
|
19
|
-
#ifdef WSP_GGML_USE_VULKAN
|
|
20
|
-
#include "ggml-vulkan.h"
|
|
21
|
-
#endif
|
|
22
|
-
|
|
23
|
-
#ifdef WSP_GGML_USE_BLAS
|
|
24
|
-
#include "ggml-blas.h"
|
|
25
|
-
#endif
|
|
26
|
-
|
|
27
13
|
#ifdef WHISPER_USE_OPENVINO
|
|
28
14
|
#include "openvino/whisper-openvino-encoder.h"
|
|
29
15
|
#endif
|
|
30
16
|
|
|
31
|
-
#ifdef WSP_GGML_USE_CANN
|
|
32
|
-
#include "ggml-cann.h"
|
|
33
|
-
#endif
|
|
34
|
-
|
|
35
|
-
#include "ggml.h"
|
|
36
|
-
#include "ggml-alloc.h"
|
|
37
|
-
#include "ggml-backend.h"
|
|
38
|
-
|
|
39
17
|
#include <atomic>
|
|
40
18
|
#include <algorithm>
|
|
41
19
|
#include <cassert>
|
|
20
|
+
#include <cfloat>
|
|
42
21
|
#define _USE_MATH_DEFINES
|
|
43
22
|
#include <cmath>
|
|
44
|
-
#include <
|
|
23
|
+
#include <climits>
|
|
24
|
+
#include <codecvt>
|
|
45
25
|
#include <cstdarg>
|
|
26
|
+
#include <cstdio>
|
|
46
27
|
#include <cstring>
|
|
47
28
|
#include <fstream>
|
|
29
|
+
#include <functional>
|
|
48
30
|
#include <map>
|
|
31
|
+
#include <mutex>
|
|
32
|
+
#include <random>
|
|
33
|
+
#include <regex>
|
|
49
34
|
#include <set>
|
|
50
35
|
#include <string>
|
|
51
36
|
#include <thread>
|
|
52
37
|
#include <vector>
|
|
53
|
-
#include <regex>
|
|
54
|
-
#include <random>
|
|
55
|
-
#include <functional>
|
|
56
|
-
#include <codecvt>
|
|
57
|
-
|
|
58
|
-
#if defined(_MSC_VER)
|
|
59
|
-
#pragma warning(disable: 4244 4267) // possible loss of data
|
|
60
|
-
#endif
|
|
61
|
-
|
|
62
|
-
#if defined(WSP_GGML_BIG_ENDIAN)
|
|
63
|
-
#include <bit>
|
|
64
38
|
|
|
39
|
+
#if defined(WHISPER_BIG_ENDIAN)
|
|
65
40
|
template<typename T>
|
|
66
41
|
static T byteswap(T value) {
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
42
|
+
T value_swapped;
|
|
43
|
+
char * source = reinterpret_cast<char *>(&value);
|
|
44
|
+
char * target = reinterpret_cast<char *>(&value_swapped);
|
|
45
|
+
int size = sizeof(T);
|
|
46
|
+
for (int i = 0; i < size; i++) {
|
|
47
|
+
target[size - 1 - i] = source[i];
|
|
48
|
+
}
|
|
49
|
+
return value_swapped;
|
|
73
50
|
}
|
|
74
51
|
|
|
75
52
|
template<typename T>
|
|
@@ -105,14 +82,14 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
|
|
|
105
82
|
}
|
|
106
83
|
|
|
107
84
|
#define BYTESWAP_VALUE(d) d = byteswap(d)
|
|
108
|
-
#define BYTESWAP_FILTERS(f)
|
|
85
|
+
#define BYTESWAP_FILTERS(f) \
|
|
109
86
|
do { \
|
|
110
87
|
for (auto & datum : f.data) { \
|
|
111
88
|
datum = byteswap(datum); \
|
|
112
89
|
} \
|
|
113
90
|
} while (0)
|
|
114
|
-
#define BYTESWAP_TENSOR(t)
|
|
115
|
-
do {
|
|
91
|
+
#define BYTESWAP_TENSOR(t) \
|
|
92
|
+
do { \
|
|
116
93
|
byteswap_tensor(t); \
|
|
117
94
|
} while (0)
|
|
118
95
|
#else
|
|
@@ -163,51 +140,118 @@ static void whisper_log_callback_default(wsp_ggml_log_level level, const char *
|
|
|
163
140
|
#define WHISPER_MAX_DECODERS 8
|
|
164
141
|
#define WHISPER_MAX_NODES 4096
|
|
165
142
|
|
|
143
|
+
static std::string format(const char * fmt, ...) {
|
|
144
|
+
va_list ap;
|
|
145
|
+
va_list ap2;
|
|
146
|
+
va_start(ap, fmt);
|
|
147
|
+
va_copy(ap2, ap);
|
|
148
|
+
int size = vsnprintf(NULL, 0, fmt, ap);
|
|
149
|
+
WSP_GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
|
|
150
|
+
std::vector<char> buf(size + 1);
|
|
151
|
+
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
|
152
|
+
WSP_GGML_ASSERT(size2 == size);
|
|
153
|
+
va_end(ap2);
|
|
154
|
+
va_end(ap);
|
|
155
|
+
return std::string(buf.data(), size);
|
|
156
|
+
}
|
|
157
|
+
|
|
166
158
|
//
|
|
167
159
|
// ggml helpers
|
|
168
160
|
//
|
|
169
161
|
|
|
170
162
|
static bool wsp_ggml_graph_compute_helper(
|
|
171
163
|
struct wsp_ggml_cgraph * graph,
|
|
172
|
-
std::vector<uint8_t> & buf,
|
|
173
164
|
int n_threads,
|
|
174
165
|
wsp_ggml_abort_callback abort_callback,
|
|
175
166
|
void * abort_callback_data) {
|
|
176
|
-
|
|
167
|
+
wsp_ggml_backend_ptr backend { wsp_ggml_backend_init_by_type(WSP_GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
|
|
177
168
|
|
|
178
|
-
|
|
179
|
-
plan.abort_callback_data = abort_callback_data;
|
|
169
|
+
auto * reg = wsp_ggml_backend_dev_backend_reg(wsp_ggml_backend_get_device(backend.get()));
|
|
180
170
|
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
171
|
+
auto * set_abort_callback_fn = (wsp_ggml_backend_set_abort_callback_t) wsp_ggml_backend_reg_get_proc_address(reg, "wsp_ggml_backend_set_abort_callback");
|
|
172
|
+
if (set_abort_callback_fn) {
|
|
173
|
+
set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data);
|
|
184
174
|
}
|
|
185
175
|
|
|
186
|
-
|
|
176
|
+
auto wsp_ggml_backend_set_n_threads_fn = (wsp_ggml_backend_set_n_threads_t) wsp_ggml_backend_reg_get_proc_address(reg, "wsp_ggml_backend_set_n_threads");
|
|
177
|
+
if (wsp_ggml_backend_set_n_threads_fn) {
|
|
178
|
+
wsp_ggml_backend_set_n_threads_fn(backend.get(), n_threads);
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
return wsp_ggml_backend_graph_compute(backend.get(), graph) == WSP_GGML_STATUS_SUCCESS;
|
|
187
182
|
}
|
|
188
183
|
|
|
189
184
|
static bool wsp_ggml_graph_compute_helper(
|
|
190
185
|
wsp_ggml_backend_sched_t sched,
|
|
191
186
|
struct wsp_ggml_cgraph * graph,
|
|
192
|
-
int n_threads
|
|
193
|
-
|
|
187
|
+
int n_threads,
|
|
188
|
+
bool sched_reset = true) {
|
|
194
189
|
for (int i = 0; i < wsp_ggml_backend_sched_get_n_backends(sched); ++i) {
|
|
195
190
|
wsp_ggml_backend_t backend = wsp_ggml_backend_sched_get_backend(sched, i);
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
if (
|
|
201
|
-
|
|
191
|
+
wsp_ggml_backend_dev_t dev = wsp_ggml_backend_get_device(backend);
|
|
192
|
+
wsp_ggml_backend_reg_t reg = dev ? wsp_ggml_backend_dev_backend_reg(dev) : nullptr;
|
|
193
|
+
|
|
194
|
+
auto * fn_set_n_threads = (wsp_ggml_backend_set_n_threads_t) wsp_ggml_backend_reg_get_proc_address(reg, "wsp_ggml_backend_set_n_threads");
|
|
195
|
+
if (fn_set_n_threads) {
|
|
196
|
+
fn_set_n_threads(backend, n_threads);
|
|
202
197
|
}
|
|
203
|
-
#endif
|
|
204
198
|
}
|
|
205
199
|
|
|
206
|
-
bool t = wsp_ggml_backend_sched_graph_compute(sched, graph) == WSP_GGML_STATUS_SUCCESS;
|
|
207
|
-
|
|
200
|
+
const bool t = (wsp_ggml_backend_sched_graph_compute(sched, graph) == WSP_GGML_STATUS_SUCCESS);
|
|
201
|
+
|
|
202
|
+
if (!t || sched_reset) {
|
|
203
|
+
wsp_ggml_backend_sched_reset(sched);
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
return t;
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
// TODO: move these functions to ggml-base with support for ggml-backend?
|
|
210
|
+
|
|
211
|
+
static wsp_ggml_tensor * whisper_set_f32(struct wsp_ggml_tensor * t, float v) {
|
|
212
|
+
WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_F32);
|
|
213
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(t));
|
|
214
|
+
size_t nels = wsp_ggml_nelements(t);
|
|
215
|
+
for (size_t i = 0; i < nels; ++i) {
|
|
216
|
+
((float *) t->data)[i] = v;
|
|
217
|
+
}
|
|
218
|
+
return t;
|
|
219
|
+
}
|
|
220
|
+
|
|
221
|
+
static wsp_ggml_tensor * whisper_set_i32(struct wsp_ggml_tensor * t, int32_t v) {
|
|
222
|
+
WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_I32);
|
|
223
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(t));
|
|
224
|
+
size_t nels = wsp_ggml_nelements(t);
|
|
225
|
+
for (size_t i = 0; i < nels; ++i) {
|
|
226
|
+
((int32_t *) t->data)[i] = v;
|
|
227
|
+
}
|
|
208
228
|
return t;
|
|
209
229
|
}
|
|
210
230
|
|
|
231
|
+
static float whisper_get_f32_nd(const struct wsp_ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
|
232
|
+
WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_F32);
|
|
233
|
+
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
|
|
234
|
+
return *(float *) data;
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
static void whisper_set_f32_nd(struct wsp_ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) {
|
|
238
|
+
WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_F32);
|
|
239
|
+
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
|
|
240
|
+
*(float *) data = v;
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
static int32_t whisper_get_i32_nd(const struct wsp_ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
|
244
|
+
WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_I32);
|
|
245
|
+
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
|
|
246
|
+
return *(int32_t *) data;
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
static void whisper_set_i32_nd(struct wsp_ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) {
|
|
250
|
+
WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_I32);
|
|
251
|
+
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
|
|
252
|
+
*(int32_t *) data = v;
|
|
253
|
+
}
|
|
254
|
+
|
|
211
255
|
// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
|
|
212
256
|
// the idea is to represent the original matrix multiplication:
|
|
213
257
|
//
|
|
@@ -451,6 +495,7 @@ struct whisper_segment {
|
|
|
451
495
|
int64_t t1;
|
|
452
496
|
|
|
453
497
|
std::string text;
|
|
498
|
+
float no_speech_prob;
|
|
454
499
|
|
|
455
500
|
std::vector<whisper_token_data> tokens;
|
|
456
501
|
|
|
@@ -543,7 +588,7 @@ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<
|
|
|
543
588
|
auto & sched = allocr.sched;
|
|
544
589
|
auto & meta = allocr.meta;
|
|
545
590
|
|
|
546
|
-
sched = wsp_ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
|
|
591
|
+
sched = wsp_ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false, true);
|
|
547
592
|
|
|
548
593
|
meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
|
|
549
594
|
|
|
@@ -739,10 +784,10 @@ struct whisper_model {
|
|
|
739
784
|
std::vector<whisper_layer_decoder> layers_decoder;
|
|
740
785
|
|
|
741
786
|
// ggml context that contains all the meta information about the model tensors
|
|
742
|
-
|
|
787
|
+
std::vector<wsp_ggml_context *> ctxs;
|
|
743
788
|
|
|
744
789
|
// the model backend data is read-only and can be shared between processors
|
|
745
|
-
wsp_ggml_backend_buffer_t
|
|
790
|
+
std::vector<wsp_ggml_backend_buffer_t> buffers;
|
|
746
791
|
|
|
747
792
|
// tensors
|
|
748
793
|
int n_loaded;
|
|
@@ -814,6 +859,11 @@ struct whisper_aheads_masks {
|
|
|
814
859
|
wsp_ggml_backend_buffer_t buffer = nullptr;
|
|
815
860
|
};
|
|
816
861
|
|
|
862
|
+
struct vad_time_mapping {
|
|
863
|
+
int64_t processed_time; // Time in processed (VAD) audio
|
|
864
|
+
int64_t original_time; // Corresponding time in original audio
|
|
865
|
+
};
|
|
866
|
+
|
|
817
867
|
struct whisper_state {
|
|
818
868
|
int64_t t_sample_us = 0;
|
|
819
869
|
int64_t t_encode_us = 0;
|
|
@@ -890,6 +940,7 @@ struct whisper_state {
|
|
|
890
940
|
whisper_token tid_last;
|
|
891
941
|
|
|
892
942
|
std::vector<float> energy; // PCM signal energy
|
|
943
|
+
float no_speech_prob = 0.0f;
|
|
893
944
|
|
|
894
945
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
895
946
|
whisper_aheads_masks aheads_masks;
|
|
@@ -898,6 +949,19 @@ struct whisper_state {
|
|
|
898
949
|
|
|
899
950
|
// [EXPERIMENTAL] speed-up techniques
|
|
900
951
|
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
|
952
|
+
|
|
953
|
+
whisper_vad_context * vad_context = nullptr;
|
|
954
|
+
|
|
955
|
+
struct vad_segment_info {
|
|
956
|
+
int64_t orig_start;
|
|
957
|
+
int64_t orig_end;
|
|
958
|
+
int64_t vad_start;
|
|
959
|
+
int64_t vad_end;
|
|
960
|
+
};
|
|
961
|
+
std::vector<vad_segment_info> vad_segments;
|
|
962
|
+
bool has_vad_segments = false;
|
|
963
|
+
|
|
964
|
+
std::vector<vad_time_mapping> vad_mapping_table;
|
|
901
965
|
};
|
|
902
966
|
|
|
903
967
|
struct whisper_context {
|
|
@@ -1254,65 +1318,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
|
|
|
1254
1318
|
}
|
|
1255
1319
|
|
|
1256
1320
|
static wsp_ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
|
|
1257
|
-
wsp_ggml_backend_t result = NULL;
|
|
1258
|
-
|
|
1259
1321
|
wsp_ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
|
|
1260
1322
|
|
|
1261
|
-
|
|
1262
|
-
if (params.use_gpu) {
|
|
1263
|
-
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
|
|
1264
|
-
result = wsp_ggml_backend_cuda_init(params.gpu_device);
|
|
1265
|
-
if (!result) {
|
|
1266
|
-
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cuda_init() failed\n", __func__);
|
|
1267
|
-
}
|
|
1268
|
-
}
|
|
1269
|
-
#endif
|
|
1323
|
+
wsp_ggml_backend_dev_t dev = nullptr;
|
|
1270
1324
|
|
|
1271
|
-
|
|
1325
|
+
int cnt = 0;
|
|
1272
1326
|
if (params.use_gpu) {
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
wsp_ggml_backend_free(result);
|
|
1280
|
-
result = NULL;
|
|
1281
|
-
}
|
|
1282
|
-
}
|
|
1283
|
-
#endif
|
|
1327
|
+
for (size_t i = 0; i < wsp_ggml_backend_dev_count(); ++i) {
|
|
1328
|
+
wsp_ggml_backend_dev_t dev_cur = wsp_ggml_backend_dev_get(i);
|
|
1329
|
+
if (wsp_ggml_backend_dev_type(dev_cur) == WSP_GGML_BACKEND_DEVICE_TYPE_GPU) {
|
|
1330
|
+
if (cnt == 0 || cnt == params.gpu_device) {
|
|
1331
|
+
dev = dev_cur;
|
|
1332
|
+
}
|
|
1284
1333
|
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
if (!result) {
|
|
1290
|
-
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_sycl_init() failed\n", __func__);
|
|
1334
|
+
if (++cnt > params.gpu_device) {
|
|
1335
|
+
break;
|
|
1336
|
+
}
|
|
1337
|
+
}
|
|
1291
1338
|
}
|
|
1292
1339
|
}
|
|
1293
|
-
#endif
|
|
1294
1340
|
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
result = wsp_ggml_backend_vk_init(params.gpu_device);
|
|
1299
|
-
if (!result) {
|
|
1300
|
-
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_vk_init() failed\n", __func__);
|
|
1301
|
-
}
|
|
1341
|
+
if (dev == nullptr) {
|
|
1342
|
+
WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
|
|
1343
|
+
return nullptr;
|
|
1302
1344
|
}
|
|
1303
|
-
#endif
|
|
1304
1345
|
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
if (!result) {
|
|
1310
|
-
WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cann_init() failed\n", __func__);
|
|
1311
|
-
}
|
|
1346
|
+
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, wsp_ggml_backend_dev_name(dev));
|
|
1347
|
+
wsp_ggml_backend_t result = wsp_ggml_backend_dev_init(dev, nullptr);
|
|
1348
|
+
if (!result) {
|
|
1349
|
+
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, wsp_ggml_backend_dev_name(dev));
|
|
1312
1350
|
}
|
|
1313
|
-
#endif
|
|
1314
|
-
|
|
1315
|
-
WSP_GGML_UNUSED(params);
|
|
1316
1351
|
|
|
1317
1352
|
return result;
|
|
1318
1353
|
}
|
|
@@ -1326,53 +1361,132 @@ static std::vector<wsp_ggml_backend_t> whisper_backend_init(const whisper_contex
|
|
|
1326
1361
|
result.push_back(backend_gpu);
|
|
1327
1362
|
}
|
|
1328
1363
|
|
|
1329
|
-
|
|
1330
|
-
{
|
|
1331
|
-
|
|
1332
|
-
|
|
1333
|
-
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
|
|
1364
|
+
// ACCEL backends
|
|
1365
|
+
for (size_t i = 0; i < wsp_ggml_backend_dev_count(); ++i) {
|
|
1366
|
+
wsp_ggml_backend_dev_t dev = wsp_ggml_backend_dev_get(i);
|
|
1367
|
+
if (wsp_ggml_backend_dev_type(dev) == WSP_GGML_BACKEND_DEVICE_TYPE_ACCEL) {
|
|
1368
|
+
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, wsp_ggml_backend_dev_name(dev));
|
|
1369
|
+
wsp_ggml_backend_t backend = wsp_ggml_backend_dev_init(dev, nullptr);
|
|
1370
|
+
if (!backend) {
|
|
1371
|
+
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, wsp_ggml_backend_dev_name(dev));
|
|
1372
|
+
continue;
|
|
1373
|
+
}
|
|
1374
|
+
result.push_back(backend);
|
|
1337
1375
|
}
|
|
1338
1376
|
}
|
|
1339
|
-
#endif
|
|
1340
1377
|
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1378
|
+
wsp_ggml_backend_t backend_cpu = wsp_ggml_backend_init_by_type(WSP_GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
|
1379
|
+
if (backend_cpu == nullptr) {
|
|
1380
|
+
throw std::runtime_error("failed to initialize CPU backend");
|
|
1381
|
+
}
|
|
1382
|
+
result.push_back(backend_cpu);
|
|
1344
1383
|
|
|
1345
1384
|
return result;
|
|
1346
1385
|
}
|
|
1347
1386
|
|
|
1348
|
-
|
|
1349
|
-
wsp_ggml_backend_buffer_type_t result = nullptr;
|
|
1387
|
+
using buft_list_t = std::vector<std::pair<wsp_ggml_backend_dev_t, wsp_ggml_backend_buffer_type_t>>;
|
|
1350
1388
|
|
|
1351
|
-
|
|
1389
|
+
static buft_list_t make_buft_list(whisper_context_params & params) {
|
|
1390
|
+
// Prio order: GPU -> CPU Extra -> CPU
|
|
1391
|
+
buft_list_t buft_list;
|
|
1352
1392
|
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
|
|
1393
|
+
// GPU
|
|
1394
|
+
if (params.use_gpu) {
|
|
1395
|
+
int cnt = 0;
|
|
1396
|
+
for (size_t i = 0; i < wsp_ggml_backend_dev_count(); ++i) {
|
|
1397
|
+
wsp_ggml_backend_dev_t dev = wsp_ggml_backend_dev_get(i);
|
|
1398
|
+
if (wsp_ggml_backend_dev_type(dev) == WSP_GGML_BACKEND_DEVICE_TYPE_GPU) {
|
|
1399
|
+
if (cnt == 0 || cnt == params.gpu_device) {
|
|
1400
|
+
auto * buft = wsp_ggml_backend_dev_buffer_type(dev);
|
|
1401
|
+
if (buft) {
|
|
1402
|
+
buft_list.emplace_back(dev, buft);
|
|
1403
|
+
}
|
|
1404
|
+
}
|
|
1356
1405
|
|
|
1357
|
-
|
|
1358
|
-
|
|
1359
|
-
|
|
1406
|
+
if (++cnt > params.gpu_device) {
|
|
1407
|
+
break;
|
|
1408
|
+
}
|
|
1409
|
+
}
|
|
1410
|
+
}
|
|
1411
|
+
}
|
|
1360
1412
|
|
|
1361
|
-
|
|
1362
|
-
|
|
1363
|
-
|
|
1413
|
+
// CPU Extra
|
|
1414
|
+
auto * cpu_dev = wsp_ggml_backend_dev_by_type(WSP_GGML_BACKEND_DEVICE_TYPE_CPU);
|
|
1415
|
+
auto * cpu_reg = wsp_ggml_backend_dev_backend_reg(cpu_dev);
|
|
1416
|
+
auto get_extra_bufts_fn = (wsp_ggml_backend_dev_get_extra_bufts_t)
|
|
1417
|
+
wsp_ggml_backend_reg_get_proc_address(cpu_reg, "wsp_ggml_backend_dev_get_extra_bufts");
|
|
1418
|
+
if (get_extra_bufts_fn) {
|
|
1419
|
+
wsp_ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
|
|
1420
|
+
while (extra_bufts && *extra_bufts) {
|
|
1421
|
+
buft_list.emplace_back(cpu_dev, *extra_bufts);
|
|
1422
|
+
++extra_bufts;
|
|
1423
|
+
}
|
|
1424
|
+
}
|
|
1364
1425
|
|
|
1365
|
-
|
|
1366
|
-
|
|
1367
|
-
#endif
|
|
1426
|
+
// CPU
|
|
1427
|
+
buft_list.emplace_back(cpu_dev, wsp_ggml_backend_cpu_buffer_type());
|
|
1368
1428
|
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
#endif
|
|
1429
|
+
return buft_list;
|
|
1430
|
+
}
|
|
1372
1431
|
|
|
1373
|
-
|
|
1432
|
+
static bool weight_buft_supported(const whisper_hparams & hparams, wsp_ggml_tensor * w, wsp_ggml_op op, wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_dev_t dev) {
|
|
1433
|
+
bool op_supported = true;
|
|
1374
1434
|
|
|
1375
|
-
|
|
1435
|
+
if (wsp_ggml_backend_dev_type(dev) == WSP_GGML_BACKEND_DEVICE_TYPE_GPU ||
|
|
1436
|
+
(wsp_ggml_backend_dev_type(dev) == WSP_GGML_BACKEND_DEVICE_TYPE_CPU && buft == wsp_ggml_backend_cpu_buffer_type())) {
|
|
1437
|
+
// GPU and default CPU backend support all operators
|
|
1438
|
+
op_supported = true;
|
|
1439
|
+
} else {
|
|
1440
|
+
switch (op) {
|
|
1441
|
+
// The current extra_buffer_type implementations only support WSP_GGML_OP_MUL_MAT
|
|
1442
|
+
case WSP_GGML_OP_MUL_MAT: {
|
|
1443
|
+
wsp_ggml_init_params params = {
|
|
1444
|
+
/*.mem_size =*/ 2 * wsp_ggml_tensor_overhead(),
|
|
1445
|
+
/*.mem_buffer =*/ nullptr,
|
|
1446
|
+
/*.no_alloc =*/ true,
|
|
1447
|
+
};
|
|
1448
|
+
|
|
1449
|
+
wsp_ggml_context_ptr ctx_ptr { wsp_ggml_init(params) };
|
|
1450
|
+
if (!ctx_ptr) {
|
|
1451
|
+
throw std::runtime_error("failed to create ggml context");
|
|
1452
|
+
}
|
|
1453
|
+
wsp_ggml_context * ctx = ctx_ptr.get();
|
|
1454
|
+
|
|
1455
|
+
wsp_ggml_tensor * op_tensor = nullptr;
|
|
1456
|
+
|
|
1457
|
+
int64_t n_ctx = hparams.n_audio_ctx;
|
|
1458
|
+
wsp_ggml_tensor * b = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
|
|
1459
|
+
op_tensor = wsp_ggml_mul_mat(ctx, w, b);
|
|
1460
|
+
|
|
1461
|
+
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
|
|
1462
|
+
WSP_GGML_ASSERT(w->buffer == nullptr);
|
|
1463
|
+
w->buffer = wsp_ggml_backend_buft_alloc_buffer(buft, 0);
|
|
1464
|
+
op_supported = wsp_ggml_backend_dev_supports_op(dev, op_tensor);
|
|
1465
|
+
wsp_ggml_backend_buffer_free(w->buffer);
|
|
1466
|
+
w->buffer = nullptr;
|
|
1467
|
+
break;
|
|
1468
|
+
}
|
|
1469
|
+
default: {
|
|
1470
|
+
op_supported = false;
|
|
1471
|
+
break;
|
|
1472
|
+
}
|
|
1473
|
+
};
|
|
1474
|
+
}
|
|
1475
|
+
|
|
1476
|
+
return op_supported;
|
|
1477
|
+
}
|
|
1478
|
+
|
|
1479
|
+
static wsp_ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, wsp_ggml_tensor * w, wsp_ggml_op op, buft_list_t buft_list) {
|
|
1480
|
+
WSP_GGML_ASSERT(!buft_list.empty());
|
|
1481
|
+
for (const auto & p : buft_list) {
|
|
1482
|
+
wsp_ggml_backend_dev_t dev = p.first;
|
|
1483
|
+
wsp_ggml_backend_buffer_type_t buft = p.second;
|
|
1484
|
+
if (weight_buft_supported(hparams, w, op, buft, dev)) {
|
|
1485
|
+
return buft;
|
|
1486
|
+
}
|
|
1487
|
+
}
|
|
1488
|
+
|
|
1489
|
+
return nullptr;
|
|
1376
1490
|
}
|
|
1377
1491
|
|
|
1378
1492
|
// load the model from a ggml file
|
|
@@ -1581,31 +1695,65 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1581
1695
|
const wsp_ggml_type wtype = wctx.wtype;
|
|
1582
1696
|
const wsp_ggml_type vtype = wctx.wtype == WSP_GGML_TYPE_F32 ? WSP_GGML_TYPE_F32 : WSP_GGML_TYPE_F16; // conv type
|
|
1583
1697
|
|
|
1584
|
-
|
|
1585
|
-
{
|
|
1586
|
-
const auto & hparams = model.hparams;
|
|
1698
|
+
const auto & hparams = model.hparams;
|
|
1587
1699
|
|
|
1588
|
-
|
|
1589
|
-
|
|
1700
|
+
const int n_audio_layer = hparams.n_audio_layer;
|
|
1701
|
+
const int n_text_layer = hparams.n_text_layer;
|
|
1590
1702
|
|
|
1591
|
-
|
|
1703
|
+
const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
|
|
1592
1704
|
|
|
1593
|
-
|
|
1594
|
-
|
|
1595
|
-
|
|
1596
|
-
|
|
1597
|
-
|
|
1705
|
+
std::map<wsp_ggml_backend_buffer_type_t, wsp_ggml_context *> ctx_map;
|
|
1706
|
+
auto get_ctx = [&](wsp_ggml_backend_buffer_type_t buft) -> wsp_ggml_context * {
|
|
1707
|
+
auto it = ctx_map.find(buft);
|
|
1708
|
+
if (it == ctx_map.end()) {
|
|
1709
|
+
wsp_ggml_init_params params = {
|
|
1710
|
+
/*.mem_size =*/ n_tensors * wsp_ggml_tensor_overhead(),
|
|
1711
|
+
/*.mem_buffer =*/ nullptr,
|
|
1712
|
+
/*.no_alloc =*/ true,
|
|
1713
|
+
};
|
|
1598
1714
|
|
|
1599
|
-
|
|
1600
|
-
|
|
1601
|
-
|
|
1602
|
-
|
|
1715
|
+
wsp_ggml_context * ctx = wsp_ggml_init(params);
|
|
1716
|
+
if (!ctx) {
|
|
1717
|
+
throw std::runtime_error("failed to create ggml context");
|
|
1718
|
+
}
|
|
1719
|
+
|
|
1720
|
+
ctx_map[buft] = ctx;
|
|
1721
|
+
model.ctxs.emplace_back(ctx);
|
|
1722
|
+
|
|
1723
|
+
return ctx;
|
|
1603
1724
|
}
|
|
1604
|
-
|
|
1725
|
+
|
|
1726
|
+
return it->second;
|
|
1727
|
+
};
|
|
1728
|
+
|
|
1729
|
+
// Create a list of available bufts, in priority order
|
|
1730
|
+
buft_list_t buft_list = make_buft_list(wctx.params);
|
|
1731
|
+
|
|
1732
|
+
auto create_tensor = [&](asr_tensor type, asr_system system, wsp_ggml_tensor * meta, int layer = 0) -> wsp_ggml_tensor * {
|
|
1733
|
+
wsp_ggml_op op = ASR_TENSOR_INFO.at(type);
|
|
1734
|
+
wsp_ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
|
|
1735
|
+
if (!buft) {
|
|
1736
|
+
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type)));
|
|
1737
|
+
}
|
|
1738
|
+
|
|
1739
|
+
wsp_ggml_context * ctx = get_ctx(buft);
|
|
1740
|
+
wsp_ggml_tensor * tensor = wsp_ggml_dup_tensor(ctx, meta);
|
|
1741
|
+
|
|
1742
|
+
model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
|
|
1743
|
+
|
|
1744
|
+
return tensor;
|
|
1745
|
+
};
|
|
1746
|
+
|
|
1605
1747
|
|
|
1606
1748
|
// prepare tensors for the weights
|
|
1607
1749
|
{
|
|
1608
|
-
|
|
1750
|
+
wsp_ggml_init_params params = {
|
|
1751
|
+
/*.mem_size =*/ n_tensors * wsp_ggml_tensor_overhead(),
|
|
1752
|
+
/*.mem_buffer =*/ nullptr,
|
|
1753
|
+
/*.no_alloc =*/ true,
|
|
1754
|
+
};
|
|
1755
|
+
|
|
1756
|
+
wsp_ggml_context * ctx = wsp_ggml_init(params);
|
|
1609
1757
|
|
|
1610
1758
|
const auto & hparams = model.hparams;
|
|
1611
1759
|
|
|
@@ -1625,189 +1773,108 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1625
1773
|
model.layers_decoder.resize(n_text_layer);
|
|
1626
1774
|
|
|
1627
1775
|
// encoder
|
|
1628
|
-
|
|
1629
|
-
model.e_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx);
|
|
1630
|
-
|
|
1631
|
-
model.e_conv_1_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
|
|
1632
|
-
model.e_conv_1_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
|
|
1633
|
-
|
|
1634
|
-
model.e_conv_2_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
|
|
1635
|
-
model.e_conv_2_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
|
|
1776
|
+
model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx));
|
|
1636
1777
|
|
|
1637
|
-
|
|
1638
|
-
|
|
1778
|
+
model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state));
|
|
1779
|
+
model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state));
|
|
1639
1780
|
|
|
1640
|
-
|
|
1641
|
-
|
|
1781
|
+
model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state));
|
|
1782
|
+
model.e_conv_2_b = create_tensor(ASR_TENSOR_CONV2_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state));
|
|
1642
1783
|
|
|
1643
|
-
|
|
1644
|
-
|
|
1784
|
+
model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state));
|
|
1785
|
+
model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state));
|
|
1645
1786
|
|
|
1646
|
-
|
|
1647
|
-
model.
|
|
1787
|
+
for (int i = 0; i < n_audio_layer; ++i) {
|
|
1788
|
+
auto & layer = model.layers_encoder[i];
|
|
1648
1789
|
|
|
1649
|
-
|
|
1650
|
-
|
|
1790
|
+
layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
|
|
1791
|
+
layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
|
|
1651
1792
|
|
|
1652
|
-
|
|
1653
|
-
|
|
1793
|
+
layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
|
|
1794
|
+
layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 4*n_audio_state), i);
|
|
1654
1795
|
|
|
1655
|
-
|
|
1656
|
-
|
|
1796
|
+
layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
|
|
1797
|
+
layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
|
|
1657
1798
|
|
|
1658
|
-
|
|
1659
|
-
|
|
1799
|
+
layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
|
|
1800
|
+
layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
|
|
1660
1801
|
|
|
1661
|
-
|
|
1662
|
-
|
|
1802
|
+
layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
|
1803
|
+
layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
|
|
1663
1804
|
|
|
1664
|
-
|
|
1665
|
-
layer.attn_ln_0_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
|
|
1805
|
+
layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
|
1666
1806
|
|
|
1667
|
-
|
|
1668
|
-
|
|
1807
|
+
layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
|
1808
|
+
layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
|
|
1669
1809
|
|
|
1670
|
-
|
|
1671
|
-
|
|
1672
|
-
layer.attn_v_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
|
|
1673
|
-
layer.attn_v_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
|
|
1674
|
-
|
|
1675
|
-
layer.attn_ln_1_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
|
|
1676
|
-
layer.attn_ln_1_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
|
|
1677
|
-
|
|
1678
|
-
// map by name
|
|
1679
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
|
|
1680
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
|
|
1681
|
-
|
|
1682
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
|
|
1683
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
|
|
1684
|
-
|
|
1685
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
|
|
1686
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
|
|
1687
|
-
|
|
1688
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
|
|
1689
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
|
|
1690
|
-
|
|
1691
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
|
|
1692
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
|
|
1693
|
-
|
|
1694
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
|
|
1695
|
-
|
|
1696
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
|
|
1697
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
|
|
1698
|
-
|
|
1699
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
|
|
1700
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
|
|
1701
|
-
}
|
|
1810
|
+
layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
|
1811
|
+
layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
|
|
1702
1812
|
}
|
|
1703
1813
|
|
|
1704
1814
|
// decoder
|
|
1705
|
-
|
|
1706
|
-
model.d_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_text_state, n_text_ctx);
|
|
1707
|
-
|
|
1708
|
-
model.d_te = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
|
|
1709
|
-
|
|
1710
|
-
model.d_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
|
|
1711
|
-
model.d_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
|
|
1712
|
-
|
|
1713
|
-
// map by name
|
|
1714
|
-
model.tensors["decoder.positional_embedding"] = model.d_pe;
|
|
1715
|
-
|
|
1716
|
-
model.tensors["decoder.token_embedding.weight"] = model.d_te;
|
|
1815
|
+
model.d_pe = create_tensor(ASR_TENSOR_DEC_POS_EMBD, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_text_state, n_text_ctx));
|
|
1717
1816
|
|
|
1718
|
-
|
|
1719
|
-
model.tensors["decoder.ln.bias"] = model.d_ln_b;
|
|
1817
|
+
model.d_te = create_tensor(ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab));
|
|
1720
1818
|
|
|
1721
|
-
|
|
1722
|
-
|
|
1819
|
+
model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state));
|
|
1820
|
+
model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state));
|
|
1723
1821
|
|
|
1724
|
-
|
|
1725
|
-
|
|
1822
|
+
for (int i = 0; i < n_text_layer; ++i) {
|
|
1823
|
+
auto & layer = model.layers_decoder[i];
|
|
1726
1824
|
|
|
1727
|
-
|
|
1728
|
-
|
|
1825
|
+
layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1826
|
+
layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1729
1827
|
|
|
1730
|
-
|
|
1731
|
-
|
|
1828
|
+
layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i);
|
|
1829
|
+
layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 4*n_text_state), i);
|
|
1732
1830
|
|
|
1733
|
-
|
|
1734
|
-
|
|
1831
|
+
layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i);
|
|
1832
|
+
layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1735
1833
|
|
|
1736
|
-
|
|
1737
|
-
|
|
1834
|
+
layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1835
|
+
layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1738
1836
|
|
|
1739
|
-
|
|
1837
|
+
layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
|
1838
|
+
layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1740
1839
|
|
|
1741
|
-
|
|
1742
|
-
layer.attn_v_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
|
|
1840
|
+
layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
|
1743
1841
|
|
|
1744
|
-
|
|
1745
|
-
|
|
1842
|
+
layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
|
1843
|
+
layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1746
1844
|
|
|
1747
|
-
|
|
1748
|
-
|
|
1845
|
+
layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
|
1846
|
+
layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1749
1847
|
|
|
1750
|
-
|
|
1751
|
-
|
|
1848
|
+
layer.cross_attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1849
|
+
layer.cross_attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1752
1850
|
|
|
1753
|
-
|
|
1851
|
+
layer.cross_attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
|
1852
|
+
layer.cross_attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1754
1853
|
|
|
1755
|
-
|
|
1756
|
-
layer.cross_attn_v_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
|
|
1854
|
+
layer.cross_attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
|
1757
1855
|
|
|
1758
|
-
|
|
1759
|
-
|
|
1856
|
+
layer.cross_attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
|
1857
|
+
layer.cross_attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1760
1858
|
|
|
1761
|
-
|
|
1762
|
-
|
|
1763
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
|
|
1764
|
-
|
|
1765
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
|
|
1766
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
|
|
1767
|
-
|
|
1768
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
|
|
1769
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
|
|
1770
|
-
|
|
1771
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
|
|
1772
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
|
|
1773
|
-
|
|
1774
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
|
|
1775
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
|
|
1776
|
-
|
|
1777
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
|
|
1778
|
-
|
|
1779
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
|
|
1780
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
|
|
1781
|
-
|
|
1782
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
|
|
1783
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
|
|
1784
|
-
|
|
1785
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
|
|
1786
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
|
|
1787
|
-
|
|
1788
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
|
|
1789
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
|
|
1790
|
-
|
|
1791
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
|
|
1792
|
-
|
|
1793
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
|
|
1794
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
|
|
1795
|
-
|
|
1796
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
|
|
1797
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
|
|
1798
|
-
}
|
|
1859
|
+
layer.cross_attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
|
1860
|
+
layer.cross_attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
|
|
1799
1861
|
}
|
|
1862
|
+
|
|
1863
|
+
wsp_ggml_free(ctx);
|
|
1800
1864
|
}
|
|
1801
1865
|
|
|
1802
1866
|
// allocate tensors in the backend buffers
|
|
1803
|
-
|
|
1804
|
-
|
|
1805
|
-
|
|
1806
|
-
|
|
1807
|
-
|
|
1867
|
+
for (auto & p : ctx_map) {
|
|
1868
|
+
wsp_ggml_backend_buffer_type_t buft = p.first;
|
|
1869
|
+
wsp_ggml_context * ctx = p.second;
|
|
1870
|
+
wsp_ggml_backend_buffer_t buf = wsp_ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
|
1871
|
+
if (buf) {
|
|
1872
|
+
model.buffers.emplace_back(buf);
|
|
1808
1873
|
|
|
1809
|
-
|
|
1810
|
-
|
|
1874
|
+
size_t size_main = wsp_ggml_backend_buffer_get_size(buf);
|
|
1875
|
+
WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, wsp_ggml_backend_buffer_name(buf), size_main / 1e6);
|
|
1876
|
+
}
|
|
1877
|
+
}
|
|
1811
1878
|
|
|
1812
1879
|
// load weights
|
|
1813
1880
|
{
|
|
@@ -1870,11 +1937,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1870
1937
|
return false;
|
|
1871
1938
|
}
|
|
1872
1939
|
|
|
1873
|
-
|
|
1874
|
-
|
|
1875
|
-
//printf("%s: [%5.5s] %s\n", __func__, wsp_ggml_backend_name(backend), name.c_str());
|
|
1876
|
-
|
|
1877
|
-
if (wsp_ggml_backend_buffer_is_host(model.buffer)) {
|
|
1940
|
+
if (wsp_ggml_backend_buffer_is_host(tensor->buffer)) {
|
|
1878
1941
|
// for the CPU and Metal backend, we can read directly into the tensor
|
|
1879
1942
|
loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
|
|
1880
1943
|
BYTESWAP_TENSOR(tensor);
|
|
@@ -1887,7 +1950,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1887
1950
|
wsp_ggml_backend_tensor_set(tensor, read_buf.data(), 0, wsp_ggml_nbytes(tensor));
|
|
1888
1951
|
}
|
|
1889
1952
|
|
|
1890
|
-
//printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], wsp_ggml_type_name((wsp_ggml_type) ttype), wsp_ggml_nbytes(tensor)/1e6);
|
|
1891
1953
|
total_size += wsp_ggml_nbytes(tensor);
|
|
1892
1954
|
model.n_loaded++;
|
|
1893
1955
|
}
|
|
@@ -1902,7 +1964,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1902
1964
|
}
|
|
1903
1965
|
}
|
|
1904
1966
|
|
|
1905
|
-
|
|
1967
|
+
for (auto & buf : model.buffers) {
|
|
1968
|
+
wsp_ggml_backend_buffer_set_usage(buf, WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
|
1969
|
+
}
|
|
1906
1970
|
|
|
1907
1971
|
wctx.t_load_us = wsp_ggml_time_us() - t_start_us;
|
|
1908
1972
|
|
|
@@ -3164,7 +3228,7 @@ static bool log_mel_spectrogram(
|
|
|
3164
3228
|
std::vector<std::thread> workers(n_threads - 1);
|
|
3165
3229
|
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
|
3166
3230
|
workers[iw] = std::thread(
|
|
3167
|
-
log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
|
|
3231
|
+
log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
|
|
3168
3232
|
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
|
|
3169
3233
|
std::cref(filters), std::ref(mel));
|
|
3170
3234
|
}
|
|
@@ -3670,8 +3734,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
|
|
|
3670
3734
|
WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
|
|
3671
3735
|
WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
|
|
3672
3736
|
WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
|
|
3673
|
-
|
|
3674
|
-
// TODO: temporary call to force backend registry initialization
|
|
3737
|
+
WHISPER_LOG_INFO("%s: devices = %zu\n", __func__, wsp_ggml_backend_dev_count());
|
|
3675
3738
|
WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, wsp_ggml_backend_reg_count());
|
|
3676
3739
|
|
|
3677
3740
|
whisper_context * ctx = new whisper_context;
|
|
@@ -3792,15 +3855,24 @@ void whisper_free_state(struct whisper_state * state) {
|
|
|
3792
3855
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
3793
3856
|
aheads_masks_free(state->aheads_masks);
|
|
3794
3857
|
|
|
3858
|
+
if (state->vad_context != nullptr) {
|
|
3859
|
+
whisper_vad_free(state->vad_context);
|
|
3860
|
+
state->vad_context = nullptr;
|
|
3861
|
+
}
|
|
3862
|
+
|
|
3795
3863
|
delete state;
|
|
3796
3864
|
}
|
|
3797
3865
|
}
|
|
3798
3866
|
|
|
3799
3867
|
void whisper_free(struct whisper_context * ctx) {
|
|
3800
3868
|
if (ctx) {
|
|
3801
|
-
|
|
3869
|
+
for (wsp_ggml_context * context : ctx->model.ctxs) {
|
|
3870
|
+
wsp_ggml_free(context);
|
|
3871
|
+
}
|
|
3802
3872
|
|
|
3803
|
-
|
|
3873
|
+
for (wsp_ggml_backend_buffer_t buf : ctx->model.buffers) {
|
|
3874
|
+
wsp_ggml_backend_buffer_free(buf);
|
|
3875
|
+
}
|
|
3804
3876
|
|
|
3805
3877
|
whisper_free_state(ctx->state);
|
|
3806
3878
|
|
|
@@ -4194,47 +4266,37 @@ struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
|
|
|
4194
4266
|
if (ctx->state == nullptr) {
|
|
4195
4267
|
return nullptr;
|
|
4196
4268
|
}
|
|
4197
|
-
|
|
4198
|
-
|
|
4199
|
-
|
|
4200
|
-
|
|
4201
|
-
|
|
4202
|
-
|
|
4203
|
-
|
|
4204
|
-
.n_encode = ctx->state->n_encode,
|
|
4205
|
-
.n_decode = ctx->state->n_decode,
|
|
4206
|
-
.n_batchd = ctx->state->n_batchd,
|
|
4207
|
-
.n_prompt = ctx->state->n_prompt,
|
|
4208
|
-
.t_sample_us = ctx->state->t_sample_us,
|
|
4209
|
-
.t_encode_us = ctx->state->t_encode_us,
|
|
4210
|
-
.t_decode_us = ctx->state->t_decode_us,
|
|
4211
|
-
.t_batchd_us = ctx->state->t_batchd_us,
|
|
4212
|
-
.t_prompt_us = ctx->state->t_prompt_us,
|
|
4213
|
-
};
|
|
4269
|
+
whisper_timings * timings = new whisper_timings;
|
|
4270
|
+
timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample);
|
|
4271
|
+
timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode);
|
|
4272
|
+
timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode);
|
|
4273
|
+
timings->batchd_ms = 1e-3f * ctx->state->t_batchd_us / std::max(1, ctx->state->n_batchd);
|
|
4274
|
+
timings->prompt_ms = 1e-3f * ctx->state->t_prompt_us / std::max(1, ctx->state->n_prompt);
|
|
4275
|
+
return timings;
|
|
4214
4276
|
}
|
|
4215
4277
|
|
|
4216
4278
|
void whisper_print_timings(struct whisper_context * ctx) {
|
|
4217
4279
|
const int64_t t_end_us = wsp_ggml_time_us();
|
|
4218
|
-
const struct whisper_timings * timings = whisper_get_timings(ctx);
|
|
4219
4280
|
|
|
4220
4281
|
WHISPER_LOG_INFO("\n");
|
|
4221
|
-
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__,
|
|
4282
|
+
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
|
|
4222
4283
|
if (ctx->state != nullptr) {
|
|
4284
|
+
|
|
4223
4285
|
const int32_t n_sample = std::max(1, ctx->state->n_sample);
|
|
4224
4286
|
const int32_t n_encode = std::max(1, ctx->state->n_encode);
|
|
4225
4287
|
const int32_t n_decode = std::max(1, ctx->state->n_decode);
|
|
4226
4288
|
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
|
|
4227
4289
|
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
|
|
4228
4290
|
|
|
4229
|
-
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__,
|
|
4230
|
-
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__,
|
|
4231
|
-
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
4232
|
-
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
4233
|
-
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
4234
|
-
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
4235
|
-
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f *
|
|
4291
|
+
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
|
|
4292
|
+
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
|
|
4293
|
+
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
|
|
4294
|
+
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
|
|
4295
|
+
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
|
4296
|
+
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
|
|
4297
|
+
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
|
|
4236
4298
|
}
|
|
4237
|
-
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us -
|
|
4299
|
+
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
|
|
4238
4300
|
}
|
|
4239
4301
|
|
|
4240
4302
|
void whisper_reset_timings(struct whisper_context * ctx) {
|
|
@@ -4274,64 +4336,1186 @@ const char * whisper_print_system_info(void) {
|
|
|
4274
4336
|
static std::string s;
|
|
4275
4337
|
|
|
4276
4338
|
s = "";
|
|
4277
|
-
s += "
|
|
4278
|
-
s += "AVX2 = " + std::to_string(wsp_ggml_cpu_has_avx2()) + " | ";
|
|
4279
|
-
s += "AVX512 = " + std::to_string(wsp_ggml_cpu_has_avx512()) + " | ";
|
|
4280
|
-
s += "FMA = " + std::to_string(wsp_ggml_cpu_has_fma()) + " | ";
|
|
4281
|
-
s += "NEON = " + std::to_string(wsp_ggml_cpu_has_neon()) + " | ";
|
|
4282
|
-
s += "ARM_FMA = " + std::to_string(wsp_ggml_cpu_has_arm_fma()) + " | ";
|
|
4283
|
-
s += "METAL = " + std::to_string(wsp_ggml_cpu_has_metal()) + " | ";
|
|
4284
|
-
s += "F16C = " + std::to_string(wsp_ggml_cpu_has_f16c()) + " | ";
|
|
4285
|
-
s += "FP16_VA = " + std::to_string(wsp_ggml_cpu_has_fp16_va()) + " | ";
|
|
4286
|
-
s += "WASM_SIMD = " + std::to_string(wsp_ggml_cpu_has_wasm_simd()) + " | ";
|
|
4287
|
-
s += "BLAS = " + std::to_string(wsp_ggml_cpu_has_blas()) + " | ";
|
|
4288
|
-
s += "SSE3 = " + std::to_string(wsp_ggml_cpu_has_sse3()) + " | ";
|
|
4289
|
-
s += "SSSE3 = " + std::to_string(wsp_ggml_cpu_has_ssse3()) + " | ";
|
|
4290
|
-
s += "VSX = " + std::to_string(wsp_ggml_cpu_has_vsx()) + " | ";
|
|
4291
|
-
s += "CUDA = " + std::to_string(wsp_ggml_cpu_has_cuda()) + " | ";
|
|
4339
|
+
s += "WHISPER : ";
|
|
4292
4340
|
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
|
|
4293
4341
|
s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
|
|
4294
|
-
|
|
4342
|
+
|
|
4343
|
+
for (size_t i = 0; i < wsp_ggml_backend_reg_count(); i++) {
|
|
4344
|
+
auto * reg = wsp_ggml_backend_reg_get(i);
|
|
4345
|
+
auto * get_features_fn = (wsp_ggml_backend_get_features_t) wsp_ggml_backend_reg_get_proc_address(reg, "wsp_ggml_backend_get_features");
|
|
4346
|
+
if (get_features_fn) {
|
|
4347
|
+
wsp_ggml_backend_feature * features = get_features_fn(reg);
|
|
4348
|
+
s += wsp_ggml_backend_reg_name(reg);
|
|
4349
|
+
s += " : ";
|
|
4350
|
+
for (; features->name; features++) {
|
|
4351
|
+
s += features->name;
|
|
4352
|
+
s += " = ";
|
|
4353
|
+
s += features->value;
|
|
4354
|
+
s += " | ";
|
|
4355
|
+
}
|
|
4356
|
+
}
|
|
4357
|
+
}
|
|
4295
4358
|
return s.c_str();
|
|
4296
4359
|
}
|
|
4297
4360
|
|
|
4298
4361
|
//////////////////////////////////
|
|
4299
|
-
//
|
|
4362
|
+
// Voice Activity Detection (VAD)
|
|
4300
4363
|
//////////////////////////////////
|
|
4301
4364
|
|
|
4302
|
-
|
|
4303
|
-
|
|
4304
|
-
|
|
4305
|
-
|
|
4306
|
-
|
|
4307
|
-
|
|
4308
|
-
|
|
4309
|
-
|
|
4310
|
-
|
|
4311
|
-
|
|
4365
|
+
// Hyper-parameters of the VAD model.
//
// All members carry default initializers so that a freshly constructed object
// never exposes indeterminate values: the three per-layer arrays are raw
// pointers, and a null pointer is safe to test or delete[] even if loading the
// model aborts before they are allocated.
struct whisper_vad_hparams {
    int32_t   n_encoder_layers     = 0;       // number of encoder layers; length of the three arrays below
    int32_t * encoder_in_channels  = nullptr; // per-layer input channel count
    int32_t * encoder_out_channels = nullptr; // per-layer output channel count
    int32_t * kernel_sizes         = nullptr; // per-layer convolution kernel size
    int32_t   lstm_input_size      = 0;
    int32_t   lstm_hidden_size     = 0;
    int32_t   final_conv_in        = 0;
    int32_t   final_conv_out       = 0;
};
|
|
4312
4375
|
|
|
4313
|
-
|
|
4314
|
-
|
|
4315
|
-
|
|
4316
|
-
|
|
4317
|
-
// invalid sequence, abort
|
|
4318
|
-
code_points.push_back(0);
|
|
4319
|
-
return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
|
|
4320
|
-
}
|
|
4321
|
-
value = (value << 6) + (next_byte & 0x3F);
|
|
4322
|
-
++pos;
|
|
4323
|
-
--n_remain;
|
|
4324
|
-
}
|
|
4376
|
+
struct whisper_vad_model {
|
|
4377
|
+
std::string type;
|
|
4378
|
+
std::string version;
|
|
4379
|
+
whisper_vad_hparams hparams;
|
|
4325
4380
|
|
|
4326
|
-
|
|
4327
|
-
code_points.push_back(value);
|
|
4328
|
-
}
|
|
4381
|
+
struct wsp_ggml_tensor * stft_forward_basis; // [256, 1, 258]
|
|
4329
4382
|
|
|
4330
|
-
//
|
|
4331
|
-
|
|
4332
|
-
|
|
4333
|
-
|
|
4334
|
-
|
|
4383
|
+
// Encoder tensors - 4 convolutional layers
|
|
4384
|
+
struct wsp_ggml_tensor * encoder_0_weight; // [3, 129, 128]
|
|
4385
|
+
struct wsp_ggml_tensor * encoder_0_bias; // [128]
|
|
4386
|
+
|
|
4387
|
+
// Second encoder layer
|
|
4388
|
+
struct wsp_ggml_tensor * encoder_1_weight; // [3, 128, 64]
|
|
4389
|
+
struct wsp_ggml_tensor * encoder_1_bias; // [64]
|
|
4390
|
+
|
|
4391
|
+
// Third encoder layer
|
|
4392
|
+
struct wsp_ggml_tensor * encoder_2_weight; // [3, 64, 64]
|
|
4393
|
+
struct wsp_ggml_tensor * encoder_2_bias; // [64]
|
|
4394
|
+
|
|
4395
|
+
// Fourth encoder layer
|
|
4396
|
+
struct wsp_ggml_tensor * encoder_3_weight; // [3, 64, 128]
|
|
4397
|
+
struct wsp_ggml_tensor * encoder_3_bias; // [128]
|
|
4398
|
+
|
|
4399
|
+
// LSTM decoder tensors
|
|
4400
|
+
struct wsp_ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden
|
|
4401
|
+
struct wsp_ggml_tensor * lstm_ih_bias; // [512]
|
|
4402
|
+
struct wsp_ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden
|
|
4403
|
+
struct wsp_ggml_tensor * lstm_hh_bias; // [512]
|
|
4404
|
+
|
|
4405
|
+
// Final conv layer
|
|
4406
|
+
struct wsp_ggml_tensor * final_conv_weight; // [128]
|
|
4407
|
+
struct wsp_ggml_tensor * final_conv_bias; // [1]
|
|
4408
|
+
|
|
4409
|
+
// ggml contexts
|
|
4410
|
+
std::vector<wsp_ggml_context *> ctxs;
|
|
4411
|
+
|
|
4412
|
+
// buffer for the model tensors
|
|
4413
|
+
std::vector<wsp_ggml_backend_buffer_t> buffers;
|
|
4414
|
+
|
|
4415
|
+
// tensors
|
|
4416
|
+
int n_loaded;
|
|
4417
|
+
std::map<std::string, struct wsp_ggml_tensor *> tensors;
|
|
4418
|
+
};
|
|
4419
|
+
|
|
4420
|
+
// One detected speech segment, described by its two boundaries.
// NOTE(review): units are not visible here - presumably centiseconds, given
// the cs_to_samples()/samples_to_cs() helpers in this file; confirm at the use sites.
struct whisper_vad_segment {
    int64_t start; // where the segment begins
    int64_t end;   // where the segment ends
};
|
|
4424
|
+
|
|
4425
|
+
struct whisper_vad_segments {
|
|
4426
|
+
std::vector<whisper_vad_segment> data;
|
|
4427
|
+
};
|
|
4428
|
+
|
|
4429
|
+
struct whisper_vad_context {
|
|
4430
|
+
int64_t t_vad_us = 0;
|
|
4431
|
+
|
|
4432
|
+
int n_window;
|
|
4433
|
+
int n_context;
|
|
4434
|
+
int n_threads;
|
|
4435
|
+
|
|
4436
|
+
std::vector<wsp_ggml_backend_t> backends;
|
|
4437
|
+
wsp_ggml_backend_buffer_t buffer = nullptr;
|
|
4438
|
+
whisper_context_params params;
|
|
4439
|
+
std::vector<uint8_t> ctx_buf;
|
|
4440
|
+
whisper_sched sched;
|
|
4441
|
+
|
|
4442
|
+
whisper_vad_model model;
|
|
4443
|
+
std::string path_model;
|
|
4444
|
+
struct wsp_ggml_tensor * h_state;
|
|
4445
|
+
struct wsp_ggml_tensor * c_state;
|
|
4446
|
+
std::vector<float> probs;
|
|
4447
|
+
};
|
|
4448
|
+
|
|
4449
|
+
struct whisper_vad_context_params whisper_vad_default_context_params(void) {
|
|
4450
|
+
whisper_vad_context_params result = {
|
|
4451
|
+
/*.n_thread = */ 4,
|
|
4452
|
+
/*.use_gpu = */ false,
|
|
4453
|
+
/*.gpu_device = */ 0,
|
|
4454
|
+
};
|
|
4455
|
+
return result;
|
|
4456
|
+
}
|
|
4457
|
+
|
|
4458
|
+
struct whisper_vad_params whisper_vad_default_params(void) {
|
|
4459
|
+
whisper_vad_params result = {
|
|
4460
|
+
/* threshold = */ 0.5f,
|
|
4461
|
+
/* min_speech_duration_ms = */ 250,
|
|
4462
|
+
/* min_silence_duration_ms = */ 100,
|
|
4463
|
+
/* max_speech_duration_s = */ FLT_MAX,
|
|
4464
|
+
/* speech_pad_ms = */ 30,
|
|
4465
|
+
/* samples_overlap = */ 0.1,
|
|
4466
|
+
};
|
|
4467
|
+
return result;
|
|
4468
|
+
}
|
|
4469
|
+
|
|
4470
|
+
// Time conversion utility functions for whisper VAD
|
|
4471
|
+
static int cs_to_samples(int64_t cs) {
|
|
4472
|
+
return (int)((cs / 100.0) * WHISPER_SAMPLE_RATE + 0.5);
|
|
4473
|
+
}
|
|
4474
|
+
|
|
4475
|
+
static int64_t samples_to_cs(int samples) {
|
|
4476
|
+
return (int64_t)((samples / (double)WHISPER_SAMPLE_RATE) * 100.0 + 0.5);
|
|
4477
|
+
}
|
|
4478
|
+
|
|
4479
|
+
static bool weight_buft_supported(const whisper_vad_hparams & hparams, wsp_ggml_tensor * w, wsp_ggml_op op, wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_dev_t dev) {
|
|
4480
|
+
bool op_supported = true;
|
|
4481
|
+
|
|
4482
|
+
if (wsp_ggml_backend_dev_type(dev) == WSP_GGML_BACKEND_DEVICE_TYPE_GPU ||
|
|
4483
|
+
(wsp_ggml_backend_dev_type(dev) == WSP_GGML_BACKEND_DEVICE_TYPE_CPU && buft == wsp_ggml_backend_cpu_buffer_type())) {
|
|
4484
|
+
// GPU and default CPU backend support all operators
|
|
4485
|
+
op_supported = true;
|
|
4486
|
+
} else {
|
|
4487
|
+
switch (op) {
|
|
4488
|
+
// The current extra_buffer_type implementations only support WSP_GGML_OP_MUL_MAT
|
|
4489
|
+
case WSP_GGML_OP_MUL_MAT: {
|
|
4490
|
+
wsp_ggml_init_params params = {
|
|
4491
|
+
/*.mem_size =*/ 2 * wsp_ggml_tensor_overhead(),
|
|
4492
|
+
/*.mem_buffer =*/ nullptr,
|
|
4493
|
+
/*.no_alloc =*/ true,
|
|
4494
|
+
};
|
|
4495
|
+
|
|
4496
|
+
wsp_ggml_context_ptr ctx_ptr { wsp_ggml_init(params) };
|
|
4497
|
+
if (!ctx_ptr) {
|
|
4498
|
+
throw std::runtime_error("failed to create ggml context");
|
|
4499
|
+
}
|
|
4500
|
+
wsp_ggml_context * ctx = ctx_ptr.get();
|
|
4501
|
+
|
|
4502
|
+
wsp_ggml_tensor * op_tensor = nullptr;
|
|
4503
|
+
|
|
4504
|
+
int64_t n_ctx = hparams.lstm_hidden_size;
|
|
4505
|
+
wsp_ggml_tensor * b = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
|
|
4506
|
+
op_tensor = wsp_ggml_mul_mat(ctx, w, b);
|
|
4507
|
+
|
|
4508
|
+
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
|
|
4509
|
+
WSP_GGML_ASSERT(w->buffer == nullptr);
|
|
4510
|
+
w->buffer = wsp_ggml_backend_buft_alloc_buffer(buft, 0);
|
|
4511
|
+
op_supported = wsp_ggml_backend_dev_supports_op(dev, op_tensor);
|
|
4512
|
+
wsp_ggml_backend_buffer_free(w->buffer);
|
|
4513
|
+
w->buffer = nullptr;
|
|
4514
|
+
break;
|
|
4515
|
+
}
|
|
4516
|
+
default: {
|
|
4517
|
+
op_supported = false;
|
|
4518
|
+
break;
|
|
4519
|
+
}
|
|
4520
|
+
};
|
|
4521
|
+
}
|
|
4522
|
+
return op_supported;
|
|
4523
|
+
}
|
|
4524
|
+
|
|
4525
|
+
// Walks `buft_list` in order and returns the first buffer type for which
// weight_buft_supported() accepts tensor `w` with operator `op`.
// Returns nullptr when no entry qualifies; the list must not be empty.
static wsp_ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, wsp_ggml_tensor * w, wsp_ggml_op op, buft_list_t buft_list) {
    WSP_GGML_ASSERT(!buft_list.empty());

    wsp_ggml_backend_buffer_type_t chosen = nullptr;
    for (const auto & candidate : buft_list) {
        // each entry pairs a device (first) with a buffer type (second)
        if (weight_buft_supported(hparams, w, op, candidate.second, candidate.first)) {
            chosen = candidate.second;
            break;
        }
    }

    return chosen;
}
|
|
4537
|
+
|
|
4538
|
+
// Adds the STFT stage to the graph: reflect-pads the input, convolves it with
// the precomputed forward basis, and returns the magnitude
// sqrt(real^2 + imag^2) of the resulting spectrum.
static wsp_ggml_tensor * whisper_vad_build_stft_layer(wsp_ggml_context * ctx0,
        const whisper_vad_model & model, wsp_ggml_tensor * cur) {
    wsp_ggml_tensor * basis = model.stft_forward_basis;

    // reflective padding of 64 on both sides of the input
    wsp_ggml_tensor * padded = wsp_ggml_pad_reflect_1d(ctx0, cur, 64, 64);

    wsp_ggml_tensor * spectrum = wsp_ggml_conv_1d(ctx0, basis, padded, model.hparams.lstm_input_size, 0, 1);

    // the first half of the basis outputs is the real part, the second half the imaginary part
    const int    n_half     = basis->ne[2] / 2;
    const size_t row_stride = spectrum->nb[1];

    wsp_ggml_tensor * re = wsp_ggml_view_2d(ctx0, spectrum, 4, n_half, row_stride, 0);
    wsp_ggml_tensor * im = wsp_ggml_view_2d(ctx0, spectrum, 4, n_half, row_stride, n_half * row_stride);

    wsp_ggml_tensor * re_sq = wsp_ggml_mul(ctx0, re, re);
    wsp_ggml_tensor * im_sq = wsp_ggml_mul(ctx0, im, im);
    wsp_ggml_tensor * power = wsp_ggml_add(ctx0, re_sq, im_sq);

    return wsp_ggml_sqrt(ctx0, power);
}
|
|
4560
|
+
|
|
4561
|
+
// Adds the four-stage convolutional encoder to the graph.
// Every stage is Conv1D (padding 1, dilation 1) + bias + ReLU; the stages
// differ only in their weights, stride and output channel count.
static wsp_ggml_tensor * whisper_vad_build_encoder_layer(wsp_ggml_context * ctx0,
        const whisper_vad_model & model, wsp_ggml_tensor * cur) {
    const auto conv_bias_relu = [ctx0](wsp_ggml_tensor * x, wsp_ggml_tensor * weight, wsp_ggml_tensor * bias,
                                       int stride, int n_out) -> wsp_ggml_tensor * {
        x = wsp_ggml_conv_1d(ctx0, weight, x, stride, 1, 1);
        // bias is reshaped to [1, n_out, 1] so it is added per output channel
        x = wsp_ggml_add(ctx0, x, wsp_ggml_reshape_3d(ctx0, bias, 1, n_out, 1));
        return wsp_ggml_relu(ctx0, x);
    };

    cur = conv_bias_relu(cur, model.encoder_0_weight, model.encoder_0_bias, 1, 128); // expands to 128 channels
    cur = conv_bias_relu(cur, model.encoder_1_weight, model.encoder_1_bias, 2,  64); // reduces to 64 channels
    cur = conv_bias_relu(cur, model.encoder_2_weight, model.encoder_2_bias, 2,  64); // maintains 64 channels
    cur = conv_bias_relu(cur, model.encoder_3_weight, model.encoder_3_bias, 1, 128); // expands to 128 channels

    return cur;
}
|
|
4585
|
+
|
|
4586
|
+
// Adds a single LSTM step to the graph.
//
// The gate pre-activations are W_ih*x + b_ih + W_hh*h + b_hh, laid out as four
// consecutive slices of lstm_hidden_size elements each: input, forget, cell and
// output gate. The new cell and hidden states are copied back into
// vctx.c_state / vctx.h_state via nodes appended to `gf`, and the new hidden
// state is returned.
static wsp_ggml_tensor * whisper_vad_build_lstm_layer(wsp_ggml_context * ctx0,
        const whisper_vad_context & vctx, wsp_ggml_tensor * cur, wsp_ggml_cgraph * gf) {
    const whisper_vad_model & model = vctx.model;
    const int n_hidden = model.hparams.lstm_hidden_size;

    wsp_ggml_tensor * x_t = wsp_ggml_transpose(ctx0, cur);

    // input-to-hidden contribution
    wsp_ggml_tensor * ih = wsp_ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
    ih = wsp_ggml_add(ctx0, ih, model.lstm_ih_bias);

    // hidden-to-hidden contribution
    wsp_ggml_tensor * hh = wsp_ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
    hh = wsp_ggml_add(ctx0, hh, model.lstm_hh_bias);

    // pre-activations for all four gates
    wsp_ggml_tensor * gates = wsp_ggml_add(ctx0, ih, hh);

    // byte size of one gate slice (n_hidden elements)
    const size_t gate_bytes = wsp_ggml_row_size(gates->type, n_hidden);
    const auto gate_slice = [&](int idx) -> wsp_ggml_tensor * {
        return wsp_ggml_view_1d(ctx0, gates, n_hidden, idx * gate_bytes);
    };

    wsp_ggml_tensor * i_t = wsp_ggml_sigmoid(ctx0, gate_slice(0)); // input gate
    wsp_ggml_tensor * f_t = wsp_ggml_sigmoid(ctx0, gate_slice(1)); // forget gate
    wsp_ggml_tensor * g_t = wsp_ggml_tanh   (ctx0, gate_slice(2)); // cell gate (tanh, not sigmoid)
    wsp_ggml_tensor * o_t = wsp_ggml_sigmoid(ctx0, gate_slice(3)); // output gate

    // c' = f * c + i * g
    wsp_ggml_tensor * kept  = wsp_ggml_mul(ctx0, f_t, vctx.c_state);
    wsp_ggml_tensor * added = wsp_ggml_mul(ctx0, i_t, g_t);
    wsp_ggml_tensor * c_out = wsp_ggml_add(ctx0, kept, added);
    wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, c_out, vctx.c_state));

    // h' = o * tanh(c')
    wsp_ggml_tensor * h_out = wsp_ggml_mul(ctx0, o_t, wsp_ggml_tanh(ctx0, c_out));
    wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, h_out, vctx.h_state));

    return h_out;
}
|
|
4630
|
+
|
|
4631
|
+
// Builds the VAD compute graph for one audio frame:
//   frame -> STFT -> encoder -> LSTM step -> ReLU -> final conv + bias -> sigmoid
// The input tensor is named "frame" and the output tensor "prob".
// Graph metadata is placed in vctx.sched.meta, so the returned graph remains
// usable after the temporary ggml context is released.
static struct wsp_ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
    const whisper_vad_model & model = vctx.model;

    struct wsp_ggml_init_params init_params = {
        /*.mem_size   =*/ vctx.sched.meta.size(),
        /*.mem_buffer =*/ vctx.sched.meta.data(),
        /*.no_alloc   =*/ true,
    };

    wsp_ggml_context * ctx0  = wsp_ggml_init(init_params);
    wsp_ggml_cgraph  * graph = wsp_ggml_new_graph(ctx0);

    wsp_ggml_tensor * frame = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, vctx.n_window, 1);
    wsp_ggml_set_name(frame, "frame");
    wsp_ggml_set_input(frame);

    wsp_ggml_tensor * feat = whisper_vad_build_stft_layer(ctx0, model, frame);
    feat = whisper_vad_build_encoder_layer(ctx0, model, feat);

    // Extract the first element of the first dimension
    // (equivalent to pytorch's [:, :, 0])
    feat = wsp_ggml_view_2d(ctx0, feat, 1, 128, feat->nb[1], 0);

    wsp_ggml_tensor * prob = whisper_vad_build_lstm_layer(ctx0, vctx, feat, graph);
    prob = wsp_ggml_relu(ctx0, prob);
    prob = wsp_ggml_conv_1d(ctx0, model.final_conv_weight, prob, 1, 0, 1);
    prob = wsp_ggml_add(ctx0, prob, model.final_conv_bias);
    prob = wsp_ggml_sigmoid(ctx0, prob);
    wsp_ggml_set_name(prob, "prob");
    wsp_ggml_set_output(prob);

    wsp_ggml_build_forward_expand(graph, prob);

    wsp_ggml_free(ctx0);

    return graph;
}
|
|
4673
|
+
|
|
4674
|
+
// Prepares a VAD context for inference: initializes the backends, creates and
// allocates the LSTM hidden/cell state tensors, and initializes the graph
// scheduler for the VAD graph.
//
// Returns false (after logging) if any step fails. The temporary ggml context
// used to create the state tensors is released as soon as the backend
// allocation has run, on both the success and the failure path; the tensor
// metadata itself lives in vctx->ctx_buf and therefore stays valid.
static bool whisper_vad_init_context(whisper_vad_context * vctx) {
    auto cparams = whisper_context_default_params();
    // TODO: GPU VAD is forced disabled until the performance is improved
    //cparams.use_gpu = vctx->params.use_gpu;
    cparams.use_gpu    = false;
    cparams.gpu_device = vctx->params.gpu_device;

    vctx->backends = whisper_backend_init(cparams);
    if (vctx->backends.empty()) {
        WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
        return false;
    }

    const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;

    // metadata for exactly two tensors: h_state and c_state
    vctx->ctx_buf.resize(2u*wsp_ggml_tensor_overhead());

    struct wsp_ggml_init_params params = {
        /*.mem_size   =*/ vctx->ctx_buf.size(),
        /*.mem_buffer =*/ vctx->ctx_buf.data(),
        /*.no_alloc   =*/ true,
    };

    wsp_ggml_context * ctx = wsp_ggml_init(params);
    if (!ctx) {
        WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
        return false;
    }

    // LSTM Hidden state
    vctx->h_state = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, lstm_hidden_size);
    wsp_ggml_set_name(vctx->h_state, "h_state");

    // LSTM Cell state
    vctx->c_state = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, lstm_hidden_size);
    wsp_ggml_set_name(vctx->c_state, "c_state");

    vctx->buffer = wsp_ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);

    // The context object is no longer needed once its tensors are allocated;
    // freeing it does not touch ctx_buf because that memory is owned by vctx.
    wsp_ggml_free(ctx);

    if (!vctx->buffer) {
        WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
        return false;
    }

    {
        bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
                [&]() {
                    return whisper_vad_build_graph(*vctx);
                });

        if (!ok) {
            WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
            return false;
        }

        WHISPER_LOG_INFO("%s: compute buffer (VAD) = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
    }

    return true;
}
|
|
4734
|
+
|
|
4735
|
+
struct whisper_vad_context * whisper_vad_init_from_file_with_params(
|
|
4736
|
+
const char * path_model,
|
|
4737
|
+
struct whisper_vad_context_params params) {
|
|
4738
|
+
WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
|
|
4739
|
+
#ifdef _MSC_VER
|
|
4740
|
+
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
|
4741
|
+
std::wstring path_model_wide = converter.from_bytes(path_model);
|
|
4742
|
+
auto fin = std::ifstream(path_model_wide, std::ios::binary);
|
|
4743
|
+
#else
|
|
4744
|
+
auto fin = std::ifstream(path_model, std::ios::binary);
|
|
4745
|
+
#endif
|
|
4746
|
+
if (!fin) {
|
|
4747
|
+
WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
|
|
4748
|
+
return nullptr;
|
|
4749
|
+
}
|
|
4750
|
+
|
|
4751
|
+
whisper_model_loader loader = {};
|
|
4752
|
+
loader.context = &fin;
|
|
4753
|
+
|
|
4754
|
+
loader.read = [](void * ctx, void * output, size_t read_size) {
|
|
4755
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
|
4756
|
+
fin->read((char *)output, read_size);
|
|
4757
|
+
return read_size;
|
|
4758
|
+
};
|
|
4759
|
+
|
|
4760
|
+
loader.eof = [](void * ctx) {
|
|
4761
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
|
4762
|
+
return fin->eof();
|
|
4763
|
+
};
|
|
4764
|
+
|
|
4765
|
+
loader.close = [](void * ctx) {
|
|
4766
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
|
4767
|
+
fin->close();
|
|
4768
|
+
};
|
|
4769
|
+
|
|
4770
|
+
auto ctx = whisper_vad_init_with_params(&loader, params);
|
|
4771
|
+
if (!ctx) {
|
|
4772
|
+
whisper_vad_free(ctx);
|
|
4773
|
+
return nullptr;
|
|
4774
|
+
}
|
|
4775
|
+
ctx->path_model = path_model;
|
|
4776
|
+
return ctx;
|
|
4777
|
+
}
|
|
4778
|
+
|
|
4779
|
+
struct whisper_vad_context * whisper_vad_init_with_params(
|
|
4780
|
+
struct whisper_model_loader * loader,
|
|
4781
|
+
struct whisper_vad_context_params params) {
|
|
4782
|
+
// Read the VAD model
|
|
4783
|
+
{
|
|
4784
|
+
uint32_t magic;
|
|
4785
|
+
read_safe(loader, magic);
|
|
4786
|
+
if (magic != WSP_GGML_FILE_MAGIC) {
|
|
4787
|
+
WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
|
|
4788
|
+
return nullptr;
|
|
4789
|
+
}
|
|
4790
|
+
}
|
|
4791
|
+
|
|
4792
|
+
whisper_vad_context * vctx = new whisper_vad_context;
|
|
4793
|
+
vctx->n_threads = params.n_threads;
|
|
4794
|
+
vctx->params.use_gpu = params.use_gpu;
|
|
4795
|
+
vctx->params.gpu_device = params.gpu_device;
|
|
4796
|
+
|
|
4797
|
+
auto & model = vctx->model;
|
|
4798
|
+
auto & hparams = model.hparams;
|
|
4799
|
+
|
|
4800
|
+
// load model context params.
|
|
4801
|
+
{
|
|
4802
|
+
int32_t str_len;
|
|
4803
|
+
read_safe(loader, str_len);
|
|
4804
|
+
std::vector<char> buffer(str_len + 1, 0);
|
|
4805
|
+
loader->read(loader->context, buffer.data(), str_len);
|
|
4806
|
+
std::string model_type(buffer.data(), str_len);
|
|
4807
|
+
model.type = model_type;
|
|
4808
|
+
WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());
|
|
4809
|
+
|
|
4810
|
+
int32_t major, minor, patch;
|
|
4811
|
+
read_safe(loader, major);
|
|
4812
|
+
read_safe(loader, minor);
|
|
4813
|
+
read_safe(loader, patch);
|
|
4814
|
+
std::string version_str = std::to_string(major) + "." +
|
|
4815
|
+
std::to_string(minor) + "." +
|
|
4816
|
+
std::to_string(patch);
|
|
4817
|
+
model.version = version_str;
|
|
4818
|
+
WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());
|
|
4819
|
+
|
|
4820
|
+
read_safe(loader, vctx->n_window);
|
|
4821
|
+
read_safe(loader, vctx->n_context);
|
|
4822
|
+
}
|
|
4823
|
+
|
|
4824
|
+
// load model hyper params (hparams).
|
|
4825
|
+
{
|
|
4826
|
+
read_safe(loader, hparams.n_encoder_layers);
|
|
4827
|
+
|
|
4828
|
+
hparams.encoder_in_channels = new int32_t[hparams.n_encoder_layers];
|
|
4829
|
+
hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
|
|
4830
|
+
hparams.kernel_sizes = new int32_t[hparams.n_encoder_layers];
|
|
4831
|
+
|
|
4832
|
+
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
|
|
4833
|
+
read_safe(loader, hparams.encoder_in_channels[i]);
|
|
4834
|
+
read_safe(loader, hparams.encoder_out_channels[i]);
|
|
4835
|
+
read_safe(loader, hparams.kernel_sizes[i]);
|
|
4836
|
+
}
|
|
4837
|
+
|
|
4838
|
+
read_safe(loader, hparams.lstm_input_size);
|
|
4839
|
+
read_safe(loader, hparams.lstm_hidden_size);
|
|
4840
|
+
read_safe(loader, hparams.final_conv_in);
|
|
4841
|
+
read_safe(loader, hparams.final_conv_out);
|
|
4842
|
+
|
|
4843
|
+
WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
|
|
4844
|
+
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
|
|
4845
|
+
WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
|
|
4846
|
+
}
|
|
4847
|
+
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
|
|
4848
|
+
WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
|
|
4849
|
+
}
|
|
4850
|
+
WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size);
|
|
4851
|
+
WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
|
|
4852
|
+
WHISPER_LOG_INFO("%s: final_conv_in = %d\n", __func__, hparams.final_conv_in);
|
|
4853
|
+
WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
|
|
4854
|
+
}
|
|
4855
|
+
|
|
4856
|
+
// 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
|
|
4857
|
+
const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
|
|
4858
|
+
|
|
4859
|
+
std::map<wsp_ggml_backend_buffer_type_t, wsp_ggml_context *> ctx_map;
|
|
4860
|
+
auto get_ctx = [&](wsp_ggml_backend_buffer_type_t buft) -> wsp_ggml_context * {
|
|
4861
|
+
auto it = ctx_map.find(buft);
|
|
4862
|
+
if (it == ctx_map.end()) {
|
|
4863
|
+
wsp_ggml_init_params params = {
|
|
4864
|
+
/*.mem_size =*/ n_tensors * wsp_ggml_tensor_overhead(),
|
|
4865
|
+
/*.mem_buffer =*/ nullptr,
|
|
4866
|
+
/*.no_alloc =*/ true,
|
|
4867
|
+
};
|
|
4868
|
+
|
|
4869
|
+
wsp_ggml_context * ctx = wsp_ggml_init(params);
|
|
4870
|
+
if (!ctx) {
|
|
4871
|
+
throw std::runtime_error("failed to create ggml context");
|
|
4872
|
+
}
|
|
4873
|
+
|
|
4874
|
+
ctx_map[buft] = ctx;
|
|
4875
|
+
model.ctxs.emplace_back(ctx);
|
|
4876
|
+
|
|
4877
|
+
return ctx;
|
|
4878
|
+
}
|
|
4879
|
+
|
|
4880
|
+
return it->second;
|
|
4881
|
+
};
|
|
4882
|
+
|
|
4883
|
+
whisper_context_params wparams = whisper_context_default_params();
|
|
4884
|
+
wparams.use_gpu = params.use_gpu;
|
|
4885
|
+
wparams.gpu_device = params.gpu_device;
|
|
4886
|
+
buft_list_t buft_list = make_buft_list(wparams);
|
|
4887
|
+
|
|
4888
|
+
auto create_tensor = [&](vad_tensor type, wsp_ggml_tensor * meta) -> wsp_ggml_tensor * {
|
|
4889
|
+
wsp_ggml_op op = VAD_TENSOR_OPS.at(type);
|
|
4890
|
+
wsp_ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
|
|
4891
|
+
if (!buft) {
|
|
4892
|
+
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
|
|
4893
|
+
}
|
|
4894
|
+
wsp_ggml_context * ctx = get_ctx(buft);
|
|
4895
|
+
wsp_ggml_tensor * tensor = wsp_ggml_dup_tensor(ctx, meta);
|
|
4896
|
+
model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;
|
|
4897
|
+
|
|
4898
|
+
return tensor;
|
|
4899
|
+
};
|
|
4900
|
+
|
|
4901
|
+
// create tensors
|
|
4902
|
+
{
|
|
4903
|
+
wsp_ggml_init_params params = {
|
|
4904
|
+
/*.mem_size =*/ n_tensors * wsp_ggml_tensor_overhead(),
|
|
4905
|
+
/*.mem_buffer =*/ nullptr,
|
|
4906
|
+
/*.no_alloc =*/ true,
|
|
4907
|
+
};
|
|
4908
|
+
|
|
4909
|
+
wsp_ggml_context * ctx = wsp_ggml_init(params);
|
|
4910
|
+
const auto & hparams = model.hparams;
|
|
4911
|
+
|
|
4912
|
+
// STFT precomputed basis matrix
|
|
4913
|
+
model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
|
|
4914
|
+
wsp_ggml_new_tensor_3d(ctx, WSP_GGML_TYPE_F16, 256, 1, 258));
|
|
4915
|
+
|
|
4916
|
+
model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
|
|
4917
|
+
wsp_ggml_new_tensor_3d(
|
|
4918
|
+
ctx,
|
|
4919
|
+
WSP_GGML_TYPE_F16,
|
|
4920
|
+
hparams.kernel_sizes[0],
|
|
4921
|
+
hparams.encoder_in_channels[0],
|
|
4922
|
+
hparams.encoder_out_channels[0]
|
|
4923
|
+
));
|
|
4924
|
+
model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS,
|
|
4925
|
+
wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, hparams.encoder_out_channels[0]));
|
|
4926
|
+
|
|
4927
|
+
model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT,
|
|
4928
|
+
wsp_ggml_new_tensor_3d(
|
|
4929
|
+
ctx,
|
|
4930
|
+
WSP_GGML_TYPE_F16,
|
|
4931
|
+
hparams.kernel_sizes[1],
|
|
4932
|
+
hparams.encoder_in_channels[1],
|
|
4933
|
+
hparams.encoder_out_channels[1]
|
|
4934
|
+
));
|
|
4935
|
+
model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS,
|
|
4936
|
+
wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, hparams.encoder_out_channels[1]));
|
|
4937
|
+
|
|
4938
|
+
model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT,
|
|
4939
|
+
wsp_ggml_new_tensor_3d(
|
|
4940
|
+
ctx,
|
|
4941
|
+
WSP_GGML_TYPE_F16,
|
|
4942
|
+
hparams.kernel_sizes[2],
|
|
4943
|
+
hparams.encoder_in_channels[2],
|
|
4944
|
+
hparams.encoder_out_channels[2]
|
|
4945
|
+
));
|
|
4946
|
+
model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS,
|
|
4947
|
+
wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, hparams.encoder_out_channels[2]));
|
|
4948
|
+
|
|
4949
|
+
model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT,
|
|
4950
|
+
wsp_ggml_new_tensor_3d(
|
|
4951
|
+
ctx,
|
|
4952
|
+
WSP_GGML_TYPE_F16,
|
|
4953
|
+
hparams.kernel_sizes[3],
|
|
4954
|
+
hparams.encoder_in_channels[3],
|
|
4955
|
+
hparams.encoder_out_channels[3]
|
|
4956
|
+
));
|
|
4957
|
+
model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
|
|
4958
|
+
wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, hparams.encoder_out_channels[3]));
|
|
4959
|
+
|
|
4960
|
+
// Hidden State dimension (input gate, forget gate, cell gate, output gate)
|
|
4961
|
+
const int hstate_dim = hparams.lstm_hidden_size * 4;
|
|
4962
|
+
|
|
4963
|
+
// LSTM weights - input to hidden
|
|
4964
|
+
model.lstm_ih_weight = create_tensor(
|
|
4965
|
+
VAD_TENSOR_LSTM_WEIGHT_IH,
|
|
4966
|
+
wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
|
|
4967
|
+
);
|
|
4968
|
+
model.lstm_ih_bias = create_tensor(
|
|
4969
|
+
VAD_TENSOR_LSTM_BIAS_IH,
|
|
4970
|
+
wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, hstate_dim)
|
|
4971
|
+
);
|
|
4972
|
+
|
|
4973
|
+
// LSTM weights - hidden to hidden
|
|
4974
|
+
model.lstm_hh_weight = create_tensor(
|
|
4975
|
+
VAD_TENSOR_LSTM_WEIGHT_HH,
|
|
4976
|
+
wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
|
|
4977
|
+
);
|
|
4978
|
+
model.lstm_hh_bias = create_tensor(
|
|
4979
|
+
VAD_TENSOR_LSTM_BIAS_HH,
|
|
4980
|
+
wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, hstate_dim)
|
|
4981
|
+
);
|
|
4982
|
+
|
|
4983
|
+
// Final conv layer weight
|
|
4984
|
+
model.final_conv_weight = create_tensor(
|
|
4985
|
+
VAD_TENSOR_FINAL_CONV_WEIGHT,
|
|
4986
|
+
wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F16, hparams.final_conv_in, 1)
|
|
4987
|
+
);
|
|
4988
|
+
model.final_conv_bias = create_tensor(
|
|
4989
|
+
VAD_TENSOR_FINAL_CONV_BIAS,
|
|
4990
|
+
wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 1)
|
|
4991
|
+
);
|
|
4992
|
+
|
|
4993
|
+
wsp_ggml_free(ctx);
|
|
4994
|
+
}
|
|
4995
|
+
|
|
4996
|
+
// allocate tensors in the backend buffers
|
|
4997
|
+
for (auto & p : ctx_map) {
|
|
4998
|
+
wsp_ggml_backend_buffer_type_t buft = p.first;
|
|
4999
|
+
wsp_ggml_context * ctx = p.second;
|
|
5000
|
+
wsp_ggml_backend_buffer_t buf = wsp_ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
|
5001
|
+
if (buf) {
|
|
5002
|
+
model.buffers.emplace_back(buf);
|
|
5003
|
+
|
|
5004
|
+
size_t size_main = wsp_ggml_backend_buffer_get_size(buf);
|
|
5005
|
+
WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, wsp_ggml_backend_buffer_name(buf), size_main / 1e6);
|
|
5006
|
+
}
|
|
5007
|
+
}
|
|
5008
|
+
|
|
5009
|
+
// load weights
|
|
5010
|
+
{
|
|
5011
|
+
size_t total_size = 0;
|
|
5012
|
+
model.n_loaded = 0;
|
|
5013
|
+
std::vector<char> read_buf;
|
|
5014
|
+
|
|
5015
|
+
while (true) {
|
|
5016
|
+
int32_t n_dims;
|
|
5017
|
+
int32_t length;
|
|
5018
|
+
int32_t ttype;
|
|
5019
|
+
|
|
5020
|
+
read_safe(loader, n_dims);
|
|
5021
|
+
read_safe(loader, length);
|
|
5022
|
+
read_safe(loader, ttype);
|
|
5023
|
+
|
|
5024
|
+
if (loader->eof(loader->context)) {
|
|
5025
|
+
break;
|
|
5026
|
+
}
|
|
5027
|
+
|
|
5028
|
+
int32_t nelements = 1;
|
|
5029
|
+
int32_t ne[4] = { 1, 1, 1, 1 };
|
|
5030
|
+
for (int i = 0; i < n_dims; ++i) {
|
|
5031
|
+
read_safe(loader, ne[i]);
|
|
5032
|
+
nelements *= ne[i];
|
|
5033
|
+
}
|
|
5034
|
+
|
|
5035
|
+
std::string name;
|
|
5036
|
+
std::vector<char> tmp(length);
|
|
5037
|
+
loader->read(loader->context, &tmp[0], tmp.size());
|
|
5038
|
+
name.assign(&tmp[0], tmp.size());
|
|
5039
|
+
|
|
5040
|
+
if (model.tensors.find(name) == model.tensors.end()) {
|
|
5041
|
+
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
|
5042
|
+
return nullptr;
|
|
5043
|
+
}
|
|
5044
|
+
|
|
5045
|
+
auto tensor = model.tensors[name.data()];
|
|
5046
|
+
|
|
5047
|
+
if (wsp_ggml_nelements(tensor) != nelements) {
|
|
5048
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
|
5049
|
+
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
|
5050
|
+
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
|
|
5051
|
+
return nullptr;
|
|
5052
|
+
}
|
|
5053
|
+
|
|
5054
|
+
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
|
|
5055
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
|
|
5056
|
+
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
|
|
5057
|
+
return nullptr;
|
|
5058
|
+
}
|
|
5059
|
+
|
|
5060
|
+
const size_t bpe = wsp_ggml_type_size(wsp_ggml_type(ttype));
|
|
5061
|
+
|
|
5062
|
+
if ((nelements*bpe)/wsp_ggml_blck_size(tensor->type) != wsp_ggml_nbytes(tensor)) {
|
|
5063
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
|
5064
|
+
__func__, name.data(), wsp_ggml_nbytes(tensor), nelements*bpe);
|
|
5065
|
+
return nullptr;
|
|
5066
|
+
}
|
|
5067
|
+
|
|
5068
|
+
if (wsp_ggml_backend_buffer_is_host(tensor->buffer)) {
|
|
5069
|
+
// for the CPU and Metal backend, we can read directly into the tensor
|
|
5070
|
+
loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
|
|
5071
|
+
BYTESWAP_TENSOR(tensor);
|
|
5072
|
+
} else {
|
|
5073
|
+
// read into a temporary buffer first, then copy to device memory
|
|
5074
|
+
read_buf.resize(wsp_ggml_nbytes(tensor));
|
|
5075
|
+
|
|
5076
|
+
loader->read(loader->context, read_buf.data(), read_buf.size());
|
|
5077
|
+
|
|
5078
|
+
wsp_ggml_backend_tensor_set(tensor, read_buf.data(), 0, wsp_ggml_nbytes(tensor));
|
|
5079
|
+
}
|
|
5080
|
+
|
|
5081
|
+
total_size += wsp_ggml_nbytes(tensor);
|
|
5082
|
+
model.n_loaded++;
|
|
5083
|
+
}
|
|
5084
|
+
|
|
5085
|
+
WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
|
|
5086
|
+
|
|
5087
|
+
if (model.n_loaded == 0) {
|
|
5088
|
+
WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
|
5089
|
+
} else if (model.n_loaded != (int) model.tensors.size()) {
|
|
5090
|
+
WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
|
|
5091
|
+
return nullptr;
|
|
5092
|
+
}
|
|
5093
|
+
|
|
5094
|
+
}
|
|
5095
|
+
|
|
5096
|
+
if (!whisper_vad_init_context(vctx)) {
|
|
5097
|
+
whisper_vad_free(vctx);
|
|
5098
|
+
return nullptr;
|
|
5099
|
+
}
|
|
5100
|
+
|
|
5101
|
+
return vctx;
|
|
5102
|
+
}
|
|
5103
|
+
|
|
5104
|
+
bool whisper_vad_detect_speech(
|
|
5105
|
+
struct whisper_vad_context * vctx,
|
|
5106
|
+
const float * samples,
|
|
5107
|
+
int n_samples) {
|
|
5108
|
+
int n_chunks = n_samples / vctx->n_window;
|
|
5109
|
+
if (n_samples % vctx->n_window != 0) {
|
|
5110
|
+
n_chunks += 1; // Add one more chunk for remaining samples.
|
|
5111
|
+
}
|
|
5112
|
+
|
|
5113
|
+
WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
|
|
5114
|
+
WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
|
|
5115
|
+
|
|
5116
|
+
// Reset LSTM hidden/cell states
|
|
5117
|
+
wsp_ggml_backend_buffer_clear(vctx->buffer, 0);
|
|
5118
|
+
|
|
5119
|
+
vctx->probs.resize(n_chunks);
|
|
5120
|
+
WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
|
|
5121
|
+
|
|
5122
|
+
std::vector<float> window(vctx->n_window, 0.0f);
|
|
5123
|
+
|
|
5124
|
+
auto & sched = vctx->sched.sched;
|
|
5125
|
+
|
|
5126
|
+
wsp_ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
|
|
5127
|
+
|
|
5128
|
+
if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
|
|
5129
|
+
WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
|
|
5130
|
+
return false;
|
|
5131
|
+
}
|
|
5132
|
+
|
|
5133
|
+
struct wsp_ggml_tensor * frame = wsp_ggml_graph_get_tensor(gf, "frame");
|
|
5134
|
+
struct wsp_ggml_tensor * prob = wsp_ggml_graph_get_tensor(gf, "prob");
|
|
5135
|
+
|
|
5136
|
+
// we are going to reuse the graph multiple times for each chunk
|
|
5137
|
+
const int64_t t_start_vad_us = wsp_ggml_time_us();
|
|
5138
|
+
|
|
5139
|
+
for (int i = 0; i < n_chunks; i++) {
|
|
5140
|
+
const int idx_start = i * vctx->n_window;
|
|
5141
|
+
const int idx_end = std::min(idx_start + vctx->n_window, n_samples);
|
|
5142
|
+
|
|
5143
|
+
const int chunk_len = idx_end - idx_start;
|
|
5144
|
+
|
|
5145
|
+
if (chunk_len < vctx->n_window) {
|
|
5146
|
+
WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
|
|
5147
|
+
std::vector<float> partial_chunk(vctx->n_window, 0.0f);
|
|
5148
|
+
std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin());
|
|
5149
|
+
|
|
5150
|
+
// Copy the zero-padded chunk to the window.
|
|
5151
|
+
const int samples_to_copy_max = vctx->n_window;
|
|
5152
|
+
const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size());
|
|
5153
|
+
std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin());
|
|
5154
|
+
if (samples_to_copy_cur < samples_to_copy_max) {
|
|
5155
|
+
std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f);
|
|
5156
|
+
}
|
|
5157
|
+
} else {
|
|
5158
|
+
// Copy current frame samples to the window.
|
|
5159
|
+
const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
|
|
5160
|
+
std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
|
|
5161
|
+
}
|
|
5162
|
+
|
|
5163
|
+
// Set the frame tensor data with the samples.
|
|
5164
|
+
wsp_ggml_backend_tensor_set(frame, window.data(), 0, wsp_ggml_nelements(frame) * sizeof(float));
|
|
5165
|
+
|
|
5166
|
+
// do not reset the scheduler - we will reuse the graph in the next chunk
|
|
5167
|
+
if (!wsp_ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
|
|
5168
|
+
WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
|
|
5169
|
+
break;
|
|
5170
|
+
}
|
|
5171
|
+
|
|
5172
|
+
// Get the probability for this chunk.
|
|
5173
|
+
wsp_ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
|
|
5174
|
+
|
|
5175
|
+
//WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
|
|
5176
|
+
}
|
|
5177
|
+
|
|
5178
|
+
vctx->t_vad_us += wsp_ggml_time_us() - t_start_vad_us;
|
|
5179
|
+
WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
|
|
5180
|
+
|
|
5181
|
+
wsp_ggml_backend_sched_reset(sched);
|
|
5182
|
+
|
|
5183
|
+
return true;
|
|
5184
|
+
}
|
|
5185
|
+
|
|
5186
|
+
int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
|
|
5187
|
+
return segments->data.size();
|
|
5188
|
+
}
|
|
5189
|
+
|
|
5190
|
+
float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
|
|
5191
|
+
return segments->data[i_segment].start;
|
|
5192
|
+
}
|
|
5193
|
+
|
|
5194
|
+
float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
|
|
5195
|
+
return segments->data[i_segment].end;
|
|
5196
|
+
}
|
|
5197
|
+
|
|
5198
|
+
int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
|
|
5199
|
+
return vctx->probs.size();
|
|
5200
|
+
}
|
|
5201
|
+
|
|
5202
|
+
float * whisper_vad_probs(struct whisper_vad_context * vctx) {
|
|
5203
|
+
return vctx->probs.data();
|
|
5204
|
+
}
|
|
5205
|
+
|
|
5206
|
+
struct whisper_vad_segments * whisper_vad_segments_from_probs(
|
|
5207
|
+
struct whisper_vad_context * vctx,
|
|
5208
|
+
whisper_vad_params params) {
|
|
5209
|
+
WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));
|
|
5210
|
+
|
|
5211
|
+
int n_probs = whisper_vad_n_probs(vctx);
|
|
5212
|
+
float * probs = whisper_vad_probs(vctx);
|
|
5213
|
+
float threshold = params.threshold;
|
|
5214
|
+
int min_speech_duration_ms = params.min_speech_duration_ms;
|
|
5215
|
+
int min_silence_duration_ms = params.min_silence_duration_ms;
|
|
5216
|
+
float max_speech_duration_s = params.max_speech_duration_s;
|
|
5217
|
+
int speech_pad_ms = params.speech_pad_ms;
|
|
5218
|
+
int n_window = vctx->n_window;
|
|
5219
|
+
int sample_rate = WHISPER_SAMPLE_RATE;
|
|
5220
|
+
int min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
|
|
5221
|
+
int audio_length_samples = n_probs * n_window;
|
|
5222
|
+
|
|
5223
|
+
// Min number of samples to be considered valid speech.
|
|
5224
|
+
int min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
|
|
5225
|
+
int speech_pad_samples = sample_rate * speech_pad_ms / 1000;
|
|
5226
|
+
|
|
5227
|
+
// Max number of samples that a speech segment can contain before it is
|
|
5228
|
+
// split into multiple segments.
|
|
5229
|
+
int max_speech_samples;
|
|
5230
|
+
if (max_speech_duration_s > 100000.0f) {
|
|
5231
|
+
max_speech_samples = INT_MAX / 2;
|
|
5232
|
+
} else {
|
|
5233
|
+
int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples;
|
|
5234
|
+
max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp;
|
|
5235
|
+
if (max_speech_samples < 0) {
|
|
5236
|
+
max_speech_samples = INT_MAX / 2;
|
|
5237
|
+
}
|
|
5238
|
+
}
|
|
5239
|
+
// Detect silence period that exceeds this value, then that location (sample)
|
|
5240
|
+
// is marked as a potential place where the segment could be split if
|
|
5241
|
+
// max_speech_samples is reached. The value 98 was taken from the original
|
|
5242
|
+
// silaro-vad python implementation:
|
|
5243
|
+
//https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
|
|
5244
|
+
int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
|
|
5245
|
+
|
|
5246
|
+
// Calculate lower threshold for detecting end of speech segments.
|
|
5247
|
+
float neg_threshold = threshold - 0.15f;
|
|
5248
|
+
if (neg_threshold < 0.01f) {
|
|
5249
|
+
neg_threshold = 0.01f;
|
|
5250
|
+
}
|
|
5251
|
+
|
|
5252
|
+
struct speech_segment_t {
|
|
5253
|
+
int start;
|
|
5254
|
+
int end;
|
|
5255
|
+
};
|
|
5256
|
+
|
|
5257
|
+
std::vector<speech_segment_t> speeches;
|
|
5258
|
+
speeches.reserve(256);
|
|
5259
|
+
|
|
5260
|
+
bool is_speech_segment = false;
|
|
5261
|
+
int temp_end = 0;
|
|
5262
|
+
int prev_end = 0;
|
|
5263
|
+
int next_start = 0;
|
|
5264
|
+
int curr_speech_start = 0;
|
|
5265
|
+
bool has_curr_speech = false;
|
|
5266
|
+
|
|
5267
|
+
for (int i = 0; i < n_probs; i++) {
|
|
5268
|
+
float curr_prob = probs[i];
|
|
5269
|
+
int curr_sample = n_window * i;
|
|
5270
|
+
|
|
5271
|
+
// Reset temp_end when we get back to speech
|
|
5272
|
+
if ((curr_prob >= threshold) && temp_end) {
|
|
5273
|
+
temp_end = 0;
|
|
5274
|
+
if (next_start < prev_end) {
|
|
5275
|
+
next_start = curr_sample;
|
|
5276
|
+
}
|
|
5277
|
+
}
|
|
5278
|
+
|
|
5279
|
+
// Start a new speech segment when probability exceeds threshold and not already in speech
|
|
5280
|
+
if ((curr_prob >= threshold) && !is_speech_segment) {
|
|
5281
|
+
is_speech_segment = true;
|
|
5282
|
+
curr_speech_start = curr_sample;
|
|
5283
|
+
has_curr_speech = true;
|
|
5284
|
+
continue;
|
|
5285
|
+
}
|
|
5286
|
+
|
|
5287
|
+
// Handle maximum speech duration
|
|
5288
|
+
if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) {
|
|
5289
|
+
if (prev_end) {
|
|
5290
|
+
speeches.push_back({ curr_speech_start, prev_end });
|
|
5291
|
+
has_curr_speech = true;
|
|
5292
|
+
|
|
5293
|
+
if (next_start < prev_end) { // Previously reached silence and is still not speech
|
|
5294
|
+
is_speech_segment = false;
|
|
5295
|
+
has_curr_speech = false;
|
|
5296
|
+
} else {
|
|
5297
|
+
curr_speech_start = next_start;
|
|
5298
|
+
}
|
|
5299
|
+
prev_end = next_start = temp_end = 0;
|
|
5300
|
+
} else {
|
|
5301
|
+
speeches.push_back({ curr_speech_start, curr_sample });
|
|
5302
|
+
|
|
5303
|
+
prev_end = next_start = temp_end = 0;
|
|
5304
|
+
is_speech_segment = false;
|
|
5305
|
+
has_curr_speech = false;
|
|
5306
|
+
continue;
|
|
5307
|
+
}
|
|
5308
|
+
}
|
|
5309
|
+
|
|
5310
|
+
// Handle silence after speech
|
|
5311
|
+
if ((curr_prob < neg_threshold) && is_speech_segment) {
|
|
5312
|
+
if (!temp_end) {
|
|
5313
|
+
temp_end = curr_sample;
|
|
5314
|
+
}
|
|
5315
|
+
|
|
5316
|
+
// Track potential segment ends for max_speech handling
|
|
5317
|
+
if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) {
|
|
5318
|
+
prev_end = temp_end;
|
|
5319
|
+
}
|
|
5320
|
+
|
|
5321
|
+
// Check if silence is long enough to end the segment
|
|
5322
|
+
if ((curr_sample - temp_end) < min_silence_samples) {
|
|
5323
|
+
continue;
|
|
5324
|
+
} else {
|
|
5325
|
+
// End the segment if it's long enough
|
|
5326
|
+
if ((temp_end - curr_speech_start) > min_speech_samples) {
|
|
5327
|
+
speeches.push_back({ curr_speech_start, temp_end });
|
|
5328
|
+
}
|
|
5329
|
+
|
|
5330
|
+
prev_end = next_start = temp_end = 0;
|
|
5331
|
+
is_speech_segment = false;
|
|
5332
|
+
has_curr_speech = false;
|
|
5333
|
+
continue;
|
|
5334
|
+
}
|
|
5335
|
+
}
|
|
5336
|
+
}
|
|
5337
|
+
|
|
5338
|
+
// Handle the case if we're still in a speech segment at the end
|
|
5339
|
+
if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) {
|
|
5340
|
+
speeches.push_back({ curr_speech_start, audio_length_samples });
|
|
5341
|
+
}
|
|
5342
|
+
|
|
5343
|
+
// Merge adjacent segments with small gaps in between (post-processing)
|
|
5344
|
+
if (speeches.size() > 1) {
|
|
5345
|
+
int merged_count = 0;
|
|
5346
|
+
for (int i = 0; i < (int) speeches.size() - 1; i++) {
|
|
5347
|
+
// Define maximum gap allowed for merging (e.g., 200ms converted to samples)
|
|
5348
|
+
int max_merge_gap_samples = sample_rate * 200 / 1000;
|
|
5349
|
+
|
|
5350
|
+
// If the gap between this segment and the next is small enough
|
|
5351
|
+
if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) {
|
|
5352
|
+
// Merge by extending current segment to the end of next segment
|
|
5353
|
+
speeches[i].end = speeches[i+1].end;
|
|
5354
|
+
speeches.erase(speeches.begin() + i + 1);
|
|
5355
|
+
|
|
5356
|
+
i--;
|
|
5357
|
+
merged_count++;
|
|
5358
|
+
}
|
|
5359
|
+
}
|
|
5360
|
+
WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n",
|
|
5361
|
+
__func__, merged_count, (int) speeches.size());
|
|
5362
|
+
}
|
|
5363
|
+
|
|
5364
|
+
// Double-check for minimum speech duration
|
|
5365
|
+
for (int i = 0; i < (int) speeches.size(); i++) {
|
|
5366
|
+
if (speeches[i].end - speeches[i].start < min_speech_samples) {
|
|
5367
|
+
WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n",
|
|
5368
|
+
__func__, i, speeches[i].end - speeches[i].start);
|
|
5369
|
+
|
|
5370
|
+
speeches.erase(speeches.begin() + i);
|
|
5371
|
+
i--;
|
|
5372
|
+
}
|
|
5373
|
+
}
|
|
5374
|
+
|
|
5375
|
+
WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size());
|
|
5376
|
+
|
|
5377
|
+
// Allocate final segments
|
|
5378
|
+
std::vector<whisper_vad_segment> segments;
|
|
5379
|
+
if (speeches.size() > 0) {
|
|
5380
|
+
try {
|
|
5381
|
+
segments.resize(speeches.size());
|
|
5382
|
+
} catch (const std::bad_alloc &) {
|
|
5383
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__);
|
|
5384
|
+
return nullptr;
|
|
5385
|
+
}
|
|
5386
|
+
}
|
|
5387
|
+
|
|
5388
|
+
// Apply padding to segments and copy to final segments
|
|
5389
|
+
for (int i = 0; i < (int) speeches.size(); i++) {
|
|
5390
|
+
// Apply padding to the start of the first segment
|
|
5391
|
+
if (i == 0) {
|
|
5392
|
+
speeches[i].start =
|
|
5393
|
+
(speeches[i].start > speech_pad_samples) ?
|
|
5394
|
+
(speeches[i].start - speech_pad_samples) : 0;
|
|
5395
|
+
}
|
|
5396
|
+
|
|
5397
|
+
// Handle spacing between segments
|
|
5398
|
+
if (i < (int) speeches.size() - 1) {
|
|
5399
|
+
int silence_duration = speeches[i+1].start - speeches[i].end;
|
|
5400
|
+
|
|
5401
|
+
if (silence_duration < 2 * speech_pad_samples) {
|
|
5402
|
+
// If segments are close, split the difference
|
|
5403
|
+
speeches[i].end += silence_duration / 2;
|
|
5404
|
+
speeches[i+1].start =
|
|
5405
|
+
(speeches[i+1].start > silence_duration / 2) ?
|
|
5406
|
+
(speeches[i+1].start - silence_duration / 2) : 0;
|
|
5407
|
+
} else {
|
|
5408
|
+
// Otherwise, apply full padding to both
|
|
5409
|
+
speeches[i].end =
|
|
5410
|
+
(speeches[i].end + speech_pad_samples < audio_length_samples) ?
|
|
5411
|
+
(speeches[i].end + speech_pad_samples) : audio_length_samples;
|
|
5412
|
+
speeches[i+1].start =
|
|
5413
|
+
(speeches[i+1].start > speech_pad_samples) ?
|
|
5414
|
+
(speeches[i+1].start - speech_pad_samples) : 0;
|
|
5415
|
+
}
|
|
5416
|
+
} else {
|
|
5417
|
+
// Apply padding to the end of the last segment
|
|
5418
|
+
speeches[i].end =
|
|
5419
|
+
(speeches[i].end + speech_pad_samples < audio_length_samples) ?
|
|
5420
|
+
(speeches[i].end + speech_pad_samples) : audio_length_samples;
|
|
5421
|
+
}
|
|
5422
|
+
|
|
5423
|
+
// Convert from samples to centiseconds
|
|
5424
|
+
segments[i].start = samples_to_cs(speeches[i].start);
|
|
5425
|
+
segments[i].end = samples_to_cs(speeches[i].end);
|
|
5426
|
+
|
|
5427
|
+
WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n",
|
|
5428
|
+
__func__, i, segments[i].start/100.0, segments[i].end/100.0, (segments[i].end - segments[i].start)/100.0);
|
|
5429
|
+
}
|
|
5430
|
+
|
|
5431
|
+
whisper_vad_segments * vad_segments = new whisper_vad_segments;
|
|
5432
|
+
if (vad_segments == NULL) {
|
|
5433
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__);
|
|
5434
|
+
return nullptr;
|
|
5435
|
+
}
|
|
5436
|
+
|
|
5437
|
+
vad_segments->data = std::move(segments);
|
|
5438
|
+
|
|
5439
|
+
return vad_segments;
|
|
5440
|
+
}
|
|
5441
|
+
|
|
5442
|
+
// Convenience wrapper: computes the speech probabilities for `samples` with
// whisper_vad_detect_speech() and then turns them into segments with
// whisper_vad_segments_from_probs().
//
// Returns nullptr if speech detection reports a failure; otherwise returns
// whatever whisper_vad_segments_from_probs() returns.
struct whisper_vad_segments * whisper_vad_segments_from_samples(
        whisper_vad_context * vctx,
        whisper_vad_params params,
        const float * samples,
        int n_samples) {
    WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples);

    const bool detected = whisper_vad_detect_speech(vctx, samples, n_samples);
    if (detected) {
        return whisper_vad_segments_from_probs(vctx, params);
    }

    WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__);
    return nullptr;
}
|
|
5454
|
+
|
|
5455
|
+
// Releases a VAD context: the model's ggml contexts, the model's backend
// buffers, the scheduler, the backends and finally the context object itself.
// Passing a null pointer is a no-op.
void whisper_vad_free(whisper_vad_context * ctx) {
    if (ctx == nullptr) {
        return;
    }

    auto & model = ctx->model;

    for (wsp_ggml_context * ggml_ctx : model.ctxs) {
        wsp_ggml_free(ggml_ctx);
    }

    for (wsp_ggml_backend_buffer_t buffer : model.buffers) {
        wsp_ggml_backend_buffer_free(buffer);
    }

    wsp_ggml_backend_sched_free(ctx->sched.sched);

    for (auto & be : ctx->backends) {
        wsp_ggml_backend_free(be);
    }

    delete ctx;
}
|
|
5475
|
+
|
|
5476
|
+
// Destroys a segments object previously returned by the VAD API.
// Passing a null pointer is a no-op (deleting a null pointer does nothing).
void whisper_vad_free_segments(whisper_vad_segments * segments) {
    delete segments;
}
|
|
5481
|
+
|
|
5482
|
+
//////////////////////////////////
|
|
5483
|
+
// Grammar - ported from llama.cpp
|
|
5484
|
+
//////////////////////////////////
|
|
5485
|
+
|
|
5486
|
+
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
|
|
5487
|
+
// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
|
|
5488
|
+
static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
|
|
5489
|
+
const char * src,
|
|
5490
|
+
whisper_partial_utf8 partial_start) {
|
|
5491
|
+
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
|
|
5492
|
+
const char * pos = src;
|
|
5493
|
+
std::vector<uint32_t> code_points;
|
|
5494
|
+
uint32_t value = partial_start.value;
|
|
5495
|
+
int n_remain = partial_start.n_remain;
|
|
5496
|
+
|
|
5497
|
+
// continue previous decode, if applicable
|
|
5498
|
+
while (*pos != 0 && n_remain > 0) {
|
|
5499
|
+
uint8_t next_byte = static_cast<uint8_t>(*pos);
|
|
5500
|
+
if ((next_byte >> 6) != 2) {
|
|
5501
|
+
// invalid sequence, abort
|
|
5502
|
+
code_points.push_back(0);
|
|
5503
|
+
return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
|
|
5504
|
+
}
|
|
5505
|
+
value = (value << 6) + (next_byte & 0x3F);
|
|
5506
|
+
++pos;
|
|
5507
|
+
--n_remain;
|
|
5508
|
+
}
|
|
5509
|
+
|
|
5510
|
+
if (partial_start.n_remain > 0 && n_remain == 0) {
|
|
5511
|
+
code_points.push_back(value);
|
|
5512
|
+
}
|
|
5513
|
+
|
|
5514
|
+
// decode any subsequent utf-8 sequences, which may end in an incomplete one
|
|
5515
|
+
while (*pos != 0) {
|
|
5516
|
+
uint8_t first_byte = static_cast<uint8_t>(*pos);
|
|
5517
|
+
uint8_t highbits = first_byte >> 4;
|
|
5518
|
+
n_remain = lookup[highbits] - 1;
|
|
4335
5519
|
|
|
4336
5520
|
if (n_remain < 0) {
|
|
4337
5521
|
// invalid sequence, abort
|
|
@@ -4450,7 +5634,7 @@ static void whisper_grammar_advance_stack(
|
|
|
4450
5634
|
std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
|
|
4451
5635
|
|
|
4452
5636
|
if (stack.empty()) {
|
|
4453
|
-
new_stacks.
|
|
5637
|
+
new_stacks.emplace_back();
|
|
4454
5638
|
return;
|
|
4455
5639
|
}
|
|
4456
5640
|
|
|
@@ -4771,7 +5955,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
4771
5955
|
/*.detect_language =*/ false,
|
|
4772
5956
|
|
|
4773
5957
|
/*.suppress_blank =*/ true,
|
|
4774
|
-
/*.
|
|
5958
|
+
/*.suppress_nst =*/ false,
|
|
4775
5959
|
|
|
4776
5960
|
/*.temperature =*/ 0.0f,
|
|
4777
5961
|
/*.max_initial_ts =*/ 1.0f,
|
|
@@ -4811,6 +5995,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
4811
5995
|
/*.n_grammar_rules =*/ 0,
|
|
4812
5996
|
/*.i_start_rule =*/ 0,
|
|
4813
5997
|
/*.grammar_penalty =*/ 100.0f,
|
|
5998
|
+
|
|
5999
|
+
/*.vad =*/ false,
|
|
6000
|
+
/*.vad_model_path =*/ nullptr,
|
|
6001
|
+
|
|
6002
|
+
/* vad_params =*/ whisper_vad_default_params(),
|
|
4814
6003
|
};
|
|
4815
6004
|
|
|
4816
6005
|
switch (strategy) {
|
|
@@ -4921,6 +6110,42 @@ static const std::vector<std::string> non_speech_tokens = {
|
|
|
4921
6110
|
"♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
|
|
4922
6111
|
};
|
|
4923
6112
|
|
|
6113
|
+
// Computes the log-softmax of the first `n_logits` entries of `logits` and
// writes the result to `logprobs`.
//
// The maximum over the whole `logits` vector is subtracted before
// exponentiating for numerical stability. Entries that are not greater than
// -INFINITY are left out of the normalization sum and get a log-probability
// of -INFINITY.
static void whisper_compute_logprobs(
        const std::vector<float> & logits,
                         const int   n_logits,
              std::vector<float> & logprobs) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());

    // sum of exp(logit - max) over the non-suppressed entries
    float sum_exp = 0.0f;
    for (int k = 0; k < n_logits; ++k) {
        const float v = logits[k];
        if (v > -INFINITY) {
            sum_exp += expf(v - max_logit);
        }
    }

    const float log_norm = logf(sum_exp) + max_logit;

    for (int k = 0; k < n_logits; ++k) {
        const float v = logits[k];
        logprobs[k] = (v > -INFINITY) ? (v - log_norm) : -INFINITY;
    }
}
|
|
6134
|
+
|
|
6135
|
+
// Converts log-probabilities to probabilities for the first `n_logits`
// entries: probs[i] = exp(logprobs[i]), except that entries whose logit is
// exactly -INFINITY get a probability of 0.
static void whisper_compute_probs(
        const std::vector<float> & logits,
                         const int   n_logits,
        const std::vector<float> & logprobs,
              std::vector<float> & probs) {
    for (int k = 0; k < n_logits; ++k) {
        const bool suppressed = (logits[k] == -INFINITY);
        probs[k] = suppressed ? 0.0f : expf(logprobs[k]);
    }
}
|
|
6148
|
+
|
|
4924
6149
|
// process the logits for the selected decoder
|
|
4925
6150
|
// - applies logit filters
|
|
4926
6151
|
// - computes logprobs and probs
|
|
@@ -4982,7 +6207,7 @@ static void whisper_process_logits(
|
|
|
4982
6207
|
|
|
4983
6208
|
// suppress sot and nosp tokens
|
|
4984
6209
|
logits[vocab.token_sot] = -INFINITY;
|
|
4985
|
-
logits[vocab.token_nosp] = -INFINITY;
|
|
6210
|
+
logits[vocab.token_nosp] = -INFINITY;
|
|
4986
6211
|
|
|
4987
6212
|
// [TDRZ] when tinydiarize is disabled, suppress solm token
|
|
4988
6213
|
if (params.tdrz_enable == false) {
|
|
@@ -5019,7 +6244,7 @@ static void whisper_process_logits(
|
|
|
5019
6244
|
|
|
5020
6245
|
// suppress non-speech tokens
|
|
5021
6246
|
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
|
5022
|
-
if (params.
|
|
6247
|
+
if (params.suppress_nst) {
|
|
5023
6248
|
for (const std::string & token : non_speech_tokens) {
|
|
5024
6249
|
const std::string suppress_tokens[] = {token, " " + token};
|
|
5025
6250
|
for (const std::string & suppress_token : suppress_tokens) {
|
|
@@ -5081,24 +6306,7 @@ static void whisper_process_logits(
|
|
|
5081
6306
|
}
|
|
5082
6307
|
|
|
5083
6308
|
// populate the logprobs array (log_softmax)
|
|
5084
|
-
|
|
5085
|
-
const float logit_max = *std::max_element(logits.begin(), logits.end());
|
|
5086
|
-
float logsumexp = 0.0f;
|
|
5087
|
-
for (int i = 0; i < n_logits; ++i) {
|
|
5088
|
-
if (logits[i] > -INFINITY) {
|
|
5089
|
-
logsumexp += expf(logits[i] - logit_max);
|
|
5090
|
-
}
|
|
5091
|
-
}
|
|
5092
|
-
logsumexp = logf(logsumexp) + logit_max;
|
|
5093
|
-
|
|
5094
|
-
for (int i = 0; i < n_logits; ++i) {
|
|
5095
|
-
if (logits[i] > -INFINITY) {
|
|
5096
|
-
logprobs[i] = logits[i] - logsumexp;
|
|
5097
|
-
} else {
|
|
5098
|
-
logprobs[i] = -INFINITY;
|
|
5099
|
-
}
|
|
5100
|
-
}
|
|
5101
|
-
}
|
|
6309
|
+
whisper_compute_logprobs(logits, n_logits, logprobs);
|
|
5102
6310
|
|
|
5103
6311
|
// if sum of probability over timestamps is above any other token, sample timestamp
|
|
5104
6312
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
|
|
@@ -5156,15 +6364,7 @@ static void whisper_process_logits(
|
|
|
5156
6364
|
}
|
|
5157
6365
|
|
|
5158
6366
|
// compute probs
|
|
5159
|
-
|
|
5160
|
-
for (int i = 0; i < n_logits; ++i) {
|
|
5161
|
-
if (logits[i] == -INFINITY) {
|
|
5162
|
-
probs[i] = 0.0f;
|
|
5163
|
-
} else {
|
|
5164
|
-
probs[i] = expf(logprobs[i]);
|
|
5165
|
-
}
|
|
5166
|
-
}
|
|
5167
|
-
}
|
|
6367
|
+
whisper_compute_probs(logits, n_logits, logprobs, probs);
|
|
5168
6368
|
|
|
5169
6369
|
#if 0
|
|
5170
6370
|
// print first 100 logits - token string : logit
|
|
@@ -5416,6 +6616,186 @@ static void whisper_sequence_score(
|
|
|
5416
6616
|
}
|
|
5417
6617
|
}
|
|
5418
6618
|
|
|
6619
|
+
static bool whisper_vad(
|
|
6620
|
+
struct whisper_context * ctx,
|
|
6621
|
+
struct whisper_state * state,
|
|
6622
|
+
struct whisper_full_params params,
|
|
6623
|
+
const float * samples,
|
|
6624
|
+
int n_samples,
|
|
6625
|
+
std::vector<float> & filtered_samples) {
|
|
6626
|
+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
|
|
6627
|
+
int filtered_n_samples = 0;
|
|
6628
|
+
|
|
6629
|
+
// Clear any existing mapping table
|
|
6630
|
+
state->vad_mapping_table.clear();
|
|
6631
|
+
state->has_vad_segments = false;
|
|
6632
|
+
|
|
6633
|
+
if (state->vad_context == nullptr) {
|
|
6634
|
+
struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
|
|
6635
|
+
struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
|
|
6636
|
+
if (vctx == nullptr) {
|
|
6637
|
+
WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
|
|
6638
|
+
return false;
|
|
6639
|
+
}
|
|
6640
|
+
state->vad_context = vctx;
|
|
6641
|
+
}
|
|
6642
|
+
auto vctx = state->vad_context;
|
|
6643
|
+
|
|
6644
|
+
const whisper_vad_params & vad_params = params.vad_params;
|
|
6645
|
+
|
|
6646
|
+
whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
|
|
6647
|
+
|
|
6648
|
+
if (vad_segments->data.size() > 0) {
|
|
6649
|
+
state->has_vad_segments = true;
|
|
6650
|
+
ctx->state->vad_segments.clear();
|
|
6651
|
+
ctx->state->vad_segments.reserve(vad_segments->data.size());
|
|
6652
|
+
|
|
6653
|
+
// Initialize the time mapping table
|
|
6654
|
+
state->vad_mapping_table.clear();
|
|
6655
|
+
state->vad_mapping_table.reserve(vad_segments->data.size() * 4);
|
|
6656
|
+
|
|
6657
|
+
WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
|
|
6658
|
+
float overlap_seconds = vad_params.samples_overlap;
|
|
6659
|
+
int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
|
|
6660
|
+
|
|
6661
|
+
for (int i = 0; i < (int)vad_segments->data.size(); i++) {
|
|
6662
|
+
int segment_start_samples = cs_to_samples(vad_segments->data[i].start);
|
|
6663
|
+
int segment_end_samples = cs_to_samples(vad_segments->data[i].end);
|
|
6664
|
+
|
|
6665
|
+
if (i < (int)vad_segments->data.size() - 1) {
|
|
6666
|
+
segment_end_samples += overlap_samples;
|
|
6667
|
+
}
|
|
6668
|
+
segment_end_samples = std::min(segment_end_samples, n_samples - 1);
|
|
6669
|
+
filtered_n_samples += (segment_end_samples - segment_start_samples);
|
|
6670
|
+
|
|
6671
|
+
WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
|
|
6672
|
+
__func__, i, vad_segments->data[i].start/100.0,
|
|
6673
|
+
(vad_segments->data[i].end/100.0 + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0)),
|
|
6674
|
+
(vad_segments->data[i].end - vad_segments->data[i].start)/100.0 +
|
|
6675
|
+
(i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
|
|
6676
|
+
}
|
|
6677
|
+
|
|
6678
|
+
int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
|
|
6679
|
+
int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
|
|
6680
|
+
int total_samples_needed = filtered_n_samples + total_silence_samples;
|
|
6681
|
+
|
|
6682
|
+
WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
|
|
6683
|
+
__func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
|
|
6684
|
+
|
|
6685
|
+
try {
|
|
6686
|
+
filtered_samples.resize(total_samples_needed);
|
|
6687
|
+
} catch (const std::bad_alloc & /* e */) {
|
|
6688
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
|
|
6689
|
+
whisper_vad_free_segments(vad_segments);
|
|
6690
|
+
whisper_vad_free(vctx);
|
|
6691
|
+
return false;
|
|
6692
|
+
}
|
|
6693
|
+
|
|
6694
|
+
int offset = 0;
|
|
6695
|
+
for (int i = 0; i < (int)vad_segments->data.size(); i++) {
|
|
6696
|
+
int segment_start_samples = cs_to_samples(vad_segments->data[i].start);
|
|
6697
|
+
int segment_end_samples = cs_to_samples(vad_segments->data[i].end);
|
|
6698
|
+
|
|
6699
|
+
if (i < (int)vad_segments->data.size() - 1) {
|
|
6700
|
+
segment_end_samples += overlap_samples;
|
|
6701
|
+
}
|
|
6702
|
+
|
|
6703
|
+
segment_start_samples = std::min(segment_start_samples, n_samples - 1);
|
|
6704
|
+
segment_end_samples = std::min(segment_end_samples, n_samples);
|
|
6705
|
+
int segment_length = segment_end_samples - segment_start_samples;
|
|
6706
|
+
if (segment_length > 0) {
|
|
6707
|
+
whisper_state::vad_segment_info segment;
|
|
6708
|
+
|
|
6709
|
+
segment.orig_start = vad_segments->data[i].start;
|
|
6710
|
+
segment.orig_end = vad_segments->data[i].end;
|
|
6711
|
+
|
|
6712
|
+
segment.vad_start = samples_to_cs(offset);
|
|
6713
|
+
segment.vad_end = samples_to_cs(offset + segment_length);
|
|
6714
|
+
|
|
6715
|
+
// Add segment boundaries to mapping table
|
|
6716
|
+
vad_time_mapping start_mapping = {segment.vad_start, segment.orig_start};
|
|
6717
|
+
vad_time_mapping end_mapping = {segment.vad_end, segment.orig_end};
|
|
6718
|
+
|
|
6719
|
+
state->vad_mapping_table.push_back(start_mapping);
|
|
6720
|
+
state->vad_mapping_table.push_back(end_mapping);
|
|
6721
|
+
|
|
6722
|
+
// Add intermediate points for longer segments to improve interpolation accuracy
|
|
6723
|
+
const int64_t min_segment_length = 100; // 1 second
|
|
6724
|
+
const int64_t point_interval = 20; // Add a point every 200ms
|
|
6725
|
+
|
|
6726
|
+
if (segment.vad_end - segment.vad_start > min_segment_length) {
|
|
6727
|
+
int64_t segment_duration = segment.vad_end - segment.vad_start;
|
|
6728
|
+
int num_points = (int)(segment_duration / point_interval) - 1;
|
|
6729
|
+
|
|
6730
|
+
for (int j = 1; j <= num_points; j++) {
|
|
6731
|
+
int64_t vad_time = segment.vad_start + j * point_interval;
|
|
6732
|
+
|
|
6733
|
+
if (vad_time >= segment.vad_end) continue;
|
|
6734
|
+
|
|
6735
|
+
int64_t vad_elapsed = vad_time - segment.vad_start;
|
|
6736
|
+
int64_t vad_total = segment.vad_end - segment.vad_start;
|
|
6737
|
+
int64_t orig_total = segment.orig_end - segment.orig_start;
|
|
6738
|
+
int64_t orig_time = segment.orig_start + (vad_elapsed * orig_total) / vad_total;
|
|
6739
|
+
|
|
6740
|
+
vad_time_mapping intermediate_mapping = {vad_time, orig_time};
|
|
6741
|
+
state->vad_mapping_table.push_back(intermediate_mapping);
|
|
6742
|
+
}
|
|
6743
|
+
}
|
|
6744
|
+
|
|
6745
|
+
WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
|
|
6746
|
+
__func__, segment.orig_start/100.0, segment.orig_end/100.0, segment.vad_start/100.0, segment.vad_end/100.0);
|
|
6747
|
+
ctx->state->vad_segments.push_back(segment);
|
|
6748
|
+
|
|
6749
|
+
// Copy this speech segment
|
|
6750
|
+
memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
|
|
6751
|
+
offset += segment_length;
|
|
6752
|
+
|
|
6753
|
+
// Add silence after this segment (except after the last segment)
|
|
6754
|
+
if (i < (int)vad_segments->data.size() - 1) {
|
|
6755
|
+
// Calculate the start and end time of the silence gap in processed audio
|
|
6756
|
+
int64_t silence_start_vad = samples_to_cs(offset);
|
|
6757
|
+
int64_t silence_end_vad = samples_to_cs(offset + silence_samples);
|
|
6758
|
+
// Calculate the corresponding original times
|
|
6759
|
+
int64_t orig_silence_start = segment.orig_end;
|
|
6760
|
+
int64_t orig_silence_end = vad_segments->data[i+1].start;
|
|
6761
|
+
|
|
6762
|
+
// Add mapping points for silence boundaries
|
|
6763
|
+
state->vad_mapping_table.push_back({silence_start_vad, orig_silence_start});
|
|
6764
|
+
state->vad_mapping_table.push_back({silence_end_vad, orig_silence_end});
|
|
6765
|
+
|
|
6766
|
+
// Fill with zeros (silence)
|
|
6767
|
+
memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
|
|
6768
|
+
offset += silence_samples;
|
|
6769
|
+
}
|
|
6770
|
+
}
|
|
6771
|
+
}
|
|
6772
|
+
|
|
6773
|
+
// Sort the mapping table by processed time
|
|
6774
|
+
std::sort(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
|
|
6775
|
+
[](const vad_time_mapping& a, const vad_time_mapping& b) {
|
|
6776
|
+
return a.processed_time < b.processed_time;
|
|
6777
|
+
});
|
|
6778
|
+
|
|
6779
|
+
// Remove any duplicate processed times to ensure monotonicity which is
|
|
6780
|
+
// needed for binary search and interpolation later.
|
|
6781
|
+
if (!state->vad_mapping_table.empty()) {
|
|
6782
|
+
auto last = std::unique(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
|
|
6783
|
+
[](const vad_time_mapping& a, const vad_time_mapping& b) {
|
|
6784
|
+
return a.processed_time == b.processed_time;
|
|
6785
|
+
});
|
|
6786
|
+
state->vad_mapping_table.erase(last, state->vad_mapping_table.end());
|
|
6787
|
+
}
|
|
6788
|
+
|
|
6789
|
+
WHISPER_LOG_INFO("%s: Created time mapping table with %d points\n", __func__, (int)state->vad_mapping_table.size());
|
|
6790
|
+
|
|
6791
|
+
filtered_n_samples = offset;
|
|
6792
|
+
WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
|
|
6793
|
+
__func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
|
|
6794
|
+
}
|
|
6795
|
+
|
|
6796
|
+
return true;
|
|
6797
|
+
}
|
|
6798
|
+
|
|
5419
6799
|
int whisper_full_with_state(
|
|
5420
6800
|
struct whisper_context * ctx,
|
|
5421
6801
|
struct whisper_state * state,
|
|
@@ -5465,11 +6845,13 @@ int whisper_full_with_state(
|
|
|
5465
6845
|
const int seek_start = params.offset_ms/10;
|
|
5466
6846
|
const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
|
|
5467
6847
|
|
|
5468
|
-
// if length of spectrogram is less than
|
|
5469
|
-
// basically don't process anything that is less than
|
|
5470
|
-
//
|
|
5471
|
-
|
|
5472
|
-
|
|
6848
|
+
// if length of spectrogram is less than 100ms (10 frames), then return
|
|
6849
|
+
// basically don't process anything that is less than 100ms
|
|
6850
|
+
// ref: https://github.com/ggml-org/whisper.cpp/issues/2065
|
|
6851
|
+
const int delta_min = 10;
|
|
6852
|
+
|
|
6853
|
+
if (seek_end < seek_start + delta_min) {
|
|
6854
|
+
WHISPER_LOG_WARN("%s: input is too short - %d ms < 100 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
|
|
5473
6855
|
return 0;
|
|
5474
6856
|
}
|
|
5475
6857
|
|
|
@@ -5516,7 +6898,7 @@ int whisper_full_with_state(
|
|
|
5516
6898
|
decoder.logprobs.resize(ctx->vocab.n_vocab);
|
|
5517
6899
|
decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
|
|
5518
6900
|
|
|
5519
|
-
decoder.rng = std::mt19937(
|
|
6901
|
+
decoder.rng = std::mt19937(j);
|
|
5520
6902
|
}
|
|
5521
6903
|
|
|
5522
6904
|
// the accumulated text context so far
|
|
@@ -5613,8 +6995,8 @@ int whisper_full_with_state(
|
|
|
5613
6995
|
ctx, state, progress_cur, params.progress_callback_user_data);
|
|
5614
6996
|
}
|
|
5615
6997
|
|
|
5616
|
-
// if only
|
|
5617
|
-
if (seek +
|
|
6998
|
+
// if only 100ms left, then stop
|
|
6999
|
+
if (seek + delta_min >= seek_end) {
|
|
5618
7000
|
break;
|
|
5619
7001
|
}
|
|
5620
7002
|
|
|
@@ -5743,6 +7125,18 @@ int whisper_full_with_state(
|
|
|
5743
7125
|
return -8;
|
|
5744
7126
|
}
|
|
5745
7127
|
|
|
7128
|
+
// Calculate no_speech probability after first decode.
|
|
7129
|
+
// This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
|
|
7130
|
+
{
|
|
7131
|
+
const int n_logits = ctx->vocab.id_to_token.size();
|
|
7132
|
+
std::vector<float> logprobs(n_logits);
|
|
7133
|
+
std::vector<float> probs(n_logits);
|
|
7134
|
+
|
|
7135
|
+
whisper_compute_logprobs(state->logits, n_logits, logprobs);
|
|
7136
|
+
whisper_compute_probs(state->logits, n_logits, logprobs, probs);
|
|
7137
|
+
state->no_speech_prob = probs[whisper_token_nosp(ctx)];
|
|
7138
|
+
}
|
|
7139
|
+
|
|
5746
7140
|
{
|
|
5747
7141
|
const int64_t t_start_sample_us = wsp_ggml_time_us();
|
|
5748
7142
|
|
|
@@ -5949,10 +7343,10 @@ int whisper_full_with_state(
|
|
|
5949
7343
|
// end of segment
|
|
5950
7344
|
if (token.id == whisper_token_eot(ctx) || // end of text token
|
|
5951
7345
|
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
|
|
5952
|
-
(has_ts && seek + seek_delta +
|
|
7346
|
+
(has_ts && seek + seek_delta + delta_min >= seek_end) // end of audio reached (100ms)
|
|
5953
7347
|
) {
|
|
5954
7348
|
if (result_len == 0 && !params.no_timestamps) {
|
|
5955
|
-
if (seek + seek_delta +
|
|
7349
|
+
if (seek + seek_delta + delta_min >= seek_end) {
|
|
5956
7350
|
result_len = i + 1;
|
|
5957
7351
|
} else {
|
|
5958
7352
|
WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
|
|
@@ -6134,8 +7528,9 @@ int whisper_full_with_state(
|
|
|
6134
7528
|
if (it != (int) temperatures.size() - 1) {
|
|
6135
7529
|
const auto & decoder = state->decoders[best_decoder_id];
|
|
6136
7530
|
|
|
6137
|
-
if (decoder.failed ||
|
|
6138
|
-
|
|
7531
|
+
if (decoder.failed ||
|
|
7532
|
+
(decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
|
|
7533
|
+
WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
|
|
6139
7534
|
success = false;
|
|
6140
7535
|
state->n_fail_p++;
|
|
6141
7536
|
}
|
|
@@ -6156,7 +7551,7 @@ int whisper_full_with_state(
|
|
|
6156
7551
|
{
|
|
6157
7552
|
const auto & best_decoder = state->decoders[best_decoder_id];
|
|
6158
7553
|
|
|
6159
|
-
|
|
7554
|
+
auto seek_delta = best_decoder.seek_delta;
|
|
6160
7555
|
const auto result_len = best_decoder.sequence.result_len;
|
|
6161
7556
|
|
|
6162
7557
|
const auto & tokens_cur = best_decoder.sequence.tokens;
|
|
@@ -6164,6 +7559,9 @@ int whisper_full_with_state(
|
|
|
6164
7559
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
6165
7560
|
const auto n_segments_before = state->result_all.size();
|
|
6166
7561
|
|
|
7562
|
+
const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
|
|
7563
|
+
best_decoder.sequence.avg_logprobs < params.logprob_thold);
|
|
7564
|
+
|
|
6167
7565
|
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
|
|
6168
7566
|
|
|
6169
7567
|
// update prompt_past
|
|
@@ -6172,11 +7570,11 @@ int whisper_full_with_state(
|
|
|
6172
7570
|
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
|
|
6173
7571
|
}
|
|
6174
7572
|
|
|
6175
|
-
for (int i = 0; i < result_len; ++i) {
|
|
7573
|
+
for (int i = 0; i < result_len && !is_no_speech; ++i) {
|
|
6176
7574
|
prompt_past.push_back(tokens_cur[i].id);
|
|
6177
7575
|
}
|
|
6178
7576
|
|
|
6179
|
-
if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
|
|
7577
|
+
if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
|
|
6180
7578
|
int i0 = 0;
|
|
6181
7579
|
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
|
|
6182
7580
|
|
|
@@ -6215,7 +7613,7 @@ int whisper_full_with_state(
|
|
|
6215
7613
|
|
|
6216
7614
|
//printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
|
|
6217
7615
|
|
|
6218
|
-
result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
|
|
7616
|
+
result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
|
|
6219
7617
|
for (int j = i0; j <= i; j++) {
|
|
6220
7618
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
|
6221
7619
|
}
|
|
@@ -6260,7 +7658,7 @@ int whisper_full_with_state(
|
|
|
6260
7658
|
}
|
|
6261
7659
|
}
|
|
6262
7660
|
|
|
6263
|
-
result_all.push_back({ tt0, tt1, text, {}
|
|
7661
|
+
result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
|
|
6264
7662
|
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
|
6265
7663
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
|
6266
7664
|
}
|
|
@@ -6297,6 +7695,15 @@ int whisper_full_with_state(
|
|
|
6297
7695
|
}
|
|
6298
7696
|
}
|
|
6299
7697
|
|
|
7698
|
+
// ref: https://github.com/ggml-org/whisper.cpp/pull/2629
|
|
7699
|
+
const bool single_timestamp_ending = tokens_cur.size() > 1 &&
|
|
7700
|
+
tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
|
|
7701
|
+
tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
|
|
7702
|
+
if (single_timestamp_ending) {
|
|
7703
|
+
WHISPER_LOG_DEBUG("single timestamp ending - skip entire chunk\n");
|
|
7704
|
+
seek_delta = std::min(seek_end - seek, WHISPER_CHUNK_SIZE * 100);
|
|
7705
|
+
}
|
|
7706
|
+
|
|
6300
7707
|
// update audio window
|
|
6301
7708
|
seek += seek_delta;
|
|
6302
7709
|
|
|
@@ -6312,6 +7719,21 @@ int whisper_full(
|
|
|
6312
7719
|
struct whisper_full_params params,
|
|
6313
7720
|
const float * samples,
|
|
6314
7721
|
int n_samples) {
|
|
7722
|
+
|
|
7723
|
+
std::vector<float> vad_samples;
|
|
7724
|
+
if (params.vad) {
|
|
7725
|
+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
|
|
7726
|
+
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
|
|
7727
|
+
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
|
|
7728
|
+
return -1;
|
|
7729
|
+
}
|
|
7730
|
+
if (vad_samples.empty()) {
|
|
7731
|
+
ctx->state->result_all.clear();
|
|
7732
|
+
return 0;
|
|
7733
|
+
}
|
|
7734
|
+
samples = vad_samples.data();
|
|
7735
|
+
n_samples = vad_samples.size();
|
|
7736
|
+
}
|
|
6315
7737
|
return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
|
|
6316
7738
|
}
|
|
6317
7739
|
|
|
@@ -6321,9 +7743,24 @@ int whisper_full_parallel(
|
|
|
6321
7743
|
const float * samples,
|
|
6322
7744
|
int n_samples,
|
|
6323
7745
|
int n_processors) {
|
|
7746
|
+
|
|
6324
7747
|
if (n_processors == 1) {
|
|
6325
7748
|
return whisper_full(ctx, params, samples, n_samples);
|
|
6326
7749
|
}
|
|
7750
|
+
|
|
7751
|
+
std::vector<float> vad_samples;
|
|
7752
|
+
if (params.vad) {
|
|
7753
|
+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
|
|
7754
|
+
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
|
|
7755
|
+
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
|
|
7756
|
+
return -1;
|
|
7757
|
+
}
|
|
7758
|
+
if (vad_samples.empty()) {
|
|
7759
|
+
return 0;
|
|
7760
|
+
}
|
|
7761
|
+
samples = vad_samples.data();
|
|
7762
|
+
n_samples = vad_samples.size();
|
|
7763
|
+
}
|
|
6327
7764
|
int ret = 0;
|
|
6328
7765
|
|
|
6329
7766
|
// prepare separate states for each thread
|
|
@@ -6446,20 +7883,93 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
|
|
|
6446
7883
|
return ctx->state->lang_id;
|
|
6447
7884
|
}
|
|
6448
7885
|
|
|
6449
|
-
int64_t
|
|
6450
|
-
|
|
7886
|
+
static int64_t map_processed_to_original_time(int64_t processed_time, const std::vector<vad_time_mapping> & mapping_table) {
|
|
7887
|
+
if (mapping_table.empty()) {
|
|
7888
|
+
return processed_time;
|
|
7889
|
+
}
|
|
7890
|
+
|
|
7891
|
+
if (processed_time <= mapping_table.front().processed_time) {
|
|
7892
|
+
return mapping_table.front().original_time; // Before first mapping point
|
|
7893
|
+
}
|
|
7894
|
+
|
|
7895
|
+
if (processed_time >= mapping_table.back().processed_time) {
|
|
7896
|
+
return mapping_table.back().original_time; // After last mapping point
|
|
7897
|
+
}
|
|
7898
|
+
|
|
7899
|
+
// Binary search over the time map that finds the first entry that has a
|
|
7900
|
+
// processed time greater than or equal to the current processed time.
|
|
7901
|
+
auto upper = std::lower_bound(mapping_table.begin(), mapping_table.end(), processed_time,
|
|
7902
|
+
[](const vad_time_mapping & entry, int64_t time) {
|
|
7903
|
+
return entry.processed_time < time;
|
|
7904
|
+
}
|
|
7905
|
+
);
|
|
7906
|
+
|
|
7907
|
+
// If exact match found
|
|
7908
|
+
if (upper->processed_time == processed_time) {
|
|
7909
|
+
return upper->original_time;
|
|
7910
|
+
}
|
|
7911
|
+
|
|
7912
|
+
// Need to interpolate between two points
|
|
7913
|
+
auto lower = upper - 1;
|
|
7914
|
+
|
|
7915
|
+
int64_t processed_diff = upper->processed_time - lower->processed_time;
|
|
7916
|
+
int64_t original_diff = upper->original_time - lower->original_time;
|
|
7917
|
+
int64_t offset = processed_time - lower->processed_time;
|
|
7918
|
+
|
|
7919
|
+
if (processed_diff == 0) {
|
|
7920
|
+
return lower->original_time;
|
|
7921
|
+
}
|
|
7922
|
+
|
|
7923
|
+
// Perform linear interpolation
|
|
7924
|
+
return lower->original_time + (offset * original_diff) / processed_diff;
|
|
6451
7925
|
}
|
|
6452
7926
|
|
|
6453
|
-
|
|
6454
|
-
|
|
7927
|
+
// Function to get the starting timestamp of a segment
|
|
7928
|
+
int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
|
|
7929
|
+
// If VAD wasn't used, return the original timestamp
|
|
7930
|
+
if (!state->has_vad_segments || state->vad_mapping_table.empty()) {
|
|
7931
|
+
return state->result_all[i_segment].t0;
|
|
7932
|
+
}
|
|
7933
|
+
|
|
7934
|
+
// Get the processed timestamp
|
|
7935
|
+
int64_t t0 = state->result_all[i_segment].t0;
|
|
7936
|
+
|
|
7937
|
+
// Map to original time using the mapping table
|
|
7938
|
+
return map_processed_to_original_time(t0, state->vad_mapping_table);
|
|
6455
7939
|
}
|
|
6456
7940
|
|
|
7941
|
+
// Function to get the ending timestamp of a segment
|
|
6457
7942
|
int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
|
|
6458
|
-
return
|
|
7943
|
+
// If VAD wasn't used, return the original timestamp
|
|
7944
|
+
if (!state->has_vad_segments || state->vad_mapping_table.empty()) {
|
|
7945
|
+
return state->result_all[i_segment].t1;
|
|
7946
|
+
}
|
|
7947
|
+
|
|
7948
|
+
// Get the processed timestamp
|
|
7949
|
+
int64_t t1 = state->result_all[i_segment].t1;
|
|
7950
|
+
|
|
7951
|
+
// Map to original time using the mapping table
|
|
7952
|
+
int64_t orig_t1 = map_processed_to_original_time(t1, state->vad_mapping_table);
|
|
7953
|
+
|
|
7954
|
+
// Get the corresponding t0 for this segment
|
|
7955
|
+
int64_t orig_t0 = whisper_full_get_segment_t0_from_state(state, i_segment);
|
|
7956
|
+
|
|
7957
|
+
// Ensure minimum duration to prevent zero-length segments
|
|
7958
|
+
const int64_t min_duration = 10; // 10ms minimum
|
|
7959
|
+
if (orig_t1 - orig_t0 < min_duration) {
|
|
7960
|
+
orig_t1 = orig_t0 + min_duration;
|
|
7961
|
+
}
|
|
7962
|
+
|
|
7963
|
+
return orig_t1;
|
|
7964
|
+
}
|
|
7965
|
+
|
|
7966
|
+
|
|
7967
|
+
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
|
|
7968
|
+
return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
|
|
6459
7969
|
}
|
|
6460
7970
|
|
|
6461
7971
|
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
|
|
6462
|
-
return ctx->state
|
|
7972
|
+
return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
|
|
6463
7973
|
}
|
|
6464
7974
|
|
|
6465
7975
|
bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
|
|
@@ -6518,6 +8028,14 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
|
|
|
6518
8028
|
return ctx->state->result_all[i_segment].tokens[i_token].p;
|
|
6519
8029
|
}
|
|
6520
8030
|
|
|
8031
|
+
float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
|
|
8032
|
+
return ctx->state->result_all[i_segment].no_speech_prob;
|
|
8033
|
+
}
|
|
8034
|
+
|
|
8035
|
+
float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment) {
|
|
8036
|
+
return state->result_all[i_segment].no_speech_prob;
|
|
8037
|
+
}
|
|
8038
|
+
|
|
6521
8039
|
// =================================================================================================
|
|
6522
8040
|
|
|
6523
8041
|
//
|
|
@@ -6698,7 +8216,6 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
|
|
|
6698
8216
|
// c: N*N*sizeof(float)
|
|
6699
8217
|
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
|
|
6700
8218
|
std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead() + wsp_ggml_graph_overhead());
|
|
6701
|
-
std::vector<uint8_t> work;
|
|
6702
8219
|
|
|
6703
8220
|
// put a bunch of random data in the buffer
|
|
6704
8221
|
for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
|
|
@@ -6755,12 +8272,12 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
|
|
|
6755
8272
|
double tsum = 0.0;
|
|
6756
8273
|
|
|
6757
8274
|
// heat-up
|
|
6758
|
-
wsp_ggml_graph_compute_helper(gf,
|
|
8275
|
+
wsp_ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
|
|
6759
8276
|
|
|
6760
8277
|
for (int i = 0; i < n_max; ++i) {
|
|
6761
8278
|
const int64_t t0 = wsp_ggml_time_us();
|
|
6762
8279
|
|
|
6763
|
-
wsp_ggml_graph_compute_helper(gf,
|
|
8280
|
+
wsp_ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
|
|
6764
8281
|
|
|
6765
8282
|
const int64_t t1 = wsp_ggml_time_us();
|
|
6766
8283
|
|
|
@@ -6813,10 +8330,6 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
|
|
|
6813
8330
|
// token-level timestamps
|
|
6814
8331
|
//
|
|
6815
8332
|
|
|
6816
|
-
static int timestamp_to_sample(int64_t t, int n_samples) {
|
|
6817
|
-
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
|
|
6818
|
-
}
|
|
6819
|
-
|
|
6820
8333
|
static int64_t sample_to_timestamp(int i_sample) {
|
|
6821
8334
|
return (100ll*i_sample)/WHISPER_SAMPLE_RATE;
|
|
6822
8335
|
}
|
|
@@ -6866,6 +8379,18 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
|
|
|
6866
8379
|
return result;
|
|
6867
8380
|
}
|
|
6868
8381
|
|
|
8382
|
+
static int timestamp_to_sample(int64_t t, int64_t segment_t0, int n_samples) {
|
|
8383
|
+
// Convert absolute timestamp to segment-relative timestamp
|
|
8384
|
+
int64_t relative_t = t - segment_t0;
|
|
8385
|
+
int sample = (int)((relative_t * WHISPER_SAMPLE_RATE) / 100);
|
|
8386
|
+
return std::max(0, std::min(n_samples - 1, sample));
|
|
8387
|
+
}
|
|
8388
|
+
|
|
8389
|
+
static int64_t sample_to_timestamp(int i_sample, int64_t segment_t0) {
|
|
8390
|
+
int64_t relative_timestamp = (100ll * i_sample) / WHISPER_SAMPLE_RATE;
|
|
8391
|
+
return relative_timestamp + segment_t0;
|
|
8392
|
+
}
|
|
8393
|
+
|
|
6869
8394
|
static void whisper_exp_compute_token_level_timestamps(
|
|
6870
8395
|
struct whisper_context & ctx,
|
|
6871
8396
|
struct whisper_state & state,
|
|
@@ -6921,12 +8446,6 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
6921
8446
|
|
|
6922
8447
|
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
|
|
6923
8448
|
|
|
6924
|
-
tokens[j].id = token.id;
|
|
6925
|
-
tokens[j].tid = token.tid;
|
|
6926
|
-
tokens[j].p = token.p;
|
|
6927
|
-
tokens[j].pt = token.pt;
|
|
6928
|
-
tokens[j].ptsum = token.ptsum;
|
|
6929
|
-
|
|
6930
8449
|
tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
|
|
6931
8450
|
|
|
6932
8451
|
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
|
|
@@ -7012,8 +8531,8 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
7012
8531
|
continue;
|
|
7013
8532
|
}
|
|
7014
8533
|
|
|
7015
|
-
int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
|
|
7016
|
-
int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
|
|
8534
|
+
int s0 = timestamp_to_sample(tokens[j].t0, segment.t0, n_samples);
|
|
8535
|
+
int s1 = timestamp_to_sample(tokens[j].t1, segment.t0, n_samples);
|
|
7017
8536
|
|
|
7018
8537
|
const int ss0 = std::max(s0 - hw, 0);
|
|
7019
8538
|
const int ss1 = std::min(s1 + hw, n_samples);
|
|
@@ -7034,7 +8553,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
7034
8553
|
while (k > 0 && state.energy[k] > thold) {
|
|
7035
8554
|
k--;
|
|
7036
8555
|
}
|
|
7037
|
-
tokens[j].t0 = sample_to_timestamp(k);
|
|
8556
|
+
tokens[j].t0 = sample_to_timestamp(k, segment.t0);
|
|
7038
8557
|
if (tokens[j].t0 < tokens[j - 1].t1) {
|
|
7039
8558
|
tokens[j].t0 = tokens[j - 1].t1;
|
|
7040
8559
|
} else {
|
|
@@ -7045,7 +8564,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
7045
8564
|
k++;
|
|
7046
8565
|
}
|
|
7047
8566
|
s0 = k;
|
|
7048
|
-
tokens[j].t0 = sample_to_timestamp(k);
|
|
8567
|
+
tokens[j].t0 = sample_to_timestamp(k, segment.t0);
|
|
7049
8568
|
}
|
|
7050
8569
|
}
|
|
7051
8570
|
|
|
@@ -7055,7 +8574,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
7055
8574
|
while (k < n_samples - 1 && state.energy[k] > thold) {
|
|
7056
8575
|
k++;
|
|
7057
8576
|
}
|
|
7058
|
-
tokens[j].t1 = sample_to_timestamp(k);
|
|
8577
|
+
tokens[j].t1 = sample_to_timestamp(k, segment.t0);
|
|
7059
8578
|
if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
|
|
7060
8579
|
tokens[j].t1 = tokens[j + 1].t0;
|
|
7061
8580
|
} else {
|
|
@@ -7066,7 +8585,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
7066
8585
|
k--;
|
|
7067
8586
|
}
|
|
7068
8587
|
s1 = k;
|
|
7069
|
-
tokens[j].t1 = sample_to_timestamp(k);
|
|
8588
|
+
tokens[j].t1 = sample_to_timestamp(k, segment.t0);
|
|
7070
8589
|
}
|
|
7071
8590
|
}
|
|
7072
8591
|
}
|
|
@@ -7137,18 +8656,18 @@ static wsp_ggml_tensor * dtw_and_backtrace(wsp_ggml_context * ctx, wsp_ggml_tens
|
|
|
7137
8656
|
struct wsp_ggml_tensor * cost = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, N + 1, M + 1);
|
|
7138
8657
|
struct wsp_ggml_tensor * trace = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, N + 1, M + 1);
|
|
7139
8658
|
|
|
7140
|
-
cost =
|
|
7141
|
-
trace =
|
|
7142
|
-
|
|
8659
|
+
cost = whisper_set_f32(cost, INFINITY);
|
|
8660
|
+
trace = whisper_set_i32(trace, -1);
|
|
8661
|
+
whisper_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
|
|
7143
8662
|
|
|
7144
8663
|
// dtw
|
|
7145
8664
|
// supposedly can be optmized by computing diagonals in parallel ?
|
|
7146
8665
|
// Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
|
|
7147
8666
|
for (int64_t j = 1; j < M + 1; ++j) {
|
|
7148
8667
|
for (int64_t i = 1; i < N + 1; ++i) {
|
|
7149
|
-
float c0 =
|
|
7150
|
-
float c1 =
|
|
7151
|
-
float c2 =
|
|
8668
|
+
float c0 = whisper_get_f32_nd(cost, i - 1, j - 1, 0, 0);
|
|
8669
|
+
float c1 = whisper_get_f32_nd(cost, i - 1, j, 0, 0);
|
|
8670
|
+
float c2 = whisper_get_f32_nd(cost, i, j - 1, 0, 0);
|
|
7152
8671
|
|
|
7153
8672
|
float c;
|
|
7154
8673
|
int32_t t;
|
|
@@ -7163,9 +8682,9 @@ static wsp_ggml_tensor * dtw_and_backtrace(wsp_ggml_context * ctx, wsp_ggml_tens
|
|
|
7163
8682
|
t = 2;
|
|
7164
8683
|
}
|
|
7165
8684
|
|
|
7166
|
-
c =
|
|
7167
|
-
|
|
7168
|
-
|
|
8685
|
+
c = whisper_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
|
|
8686
|
+
whisper_set_f32_nd(cost, i, j, 0, 0, c);
|
|
8687
|
+
whisper_set_i32_nd(trace, i, j, 0, 0, t);
|
|
7169
8688
|
}
|
|
7170
8689
|
}
|
|
7171
8690
|
|
|
@@ -7174,19 +8693,19 @@ static wsp_ggml_tensor * dtw_and_backtrace(wsp_ggml_context * ctx, wsp_ggml_tens
|
|
|
7174
8693
|
struct wsp_ggml_tensor * bt = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, BT_MAX_ROWS, 2);
|
|
7175
8694
|
// trace[0, :] = 2;
|
|
7176
8695
|
for (int64_t i = 0; i < M + 1; ++i)
|
|
7177
|
-
|
|
8696
|
+
whisper_set_i32_nd(trace, 0, i, 0, 0, 2);
|
|
7178
8697
|
//trace[:, 0] = 1;
|
|
7179
8698
|
for (int64_t i = 0; i < N + 1; ++i)
|
|
7180
|
-
|
|
8699
|
+
whisper_set_i32_nd(trace, i, 0, 0, 0, 1);
|
|
7181
8700
|
int bt_row_idx = BT_MAX_ROWS - 1;
|
|
7182
8701
|
int64_t i = N;
|
|
7183
8702
|
int64_t j = M;
|
|
7184
8703
|
while (i > 0 || j > 0) {
|
|
7185
|
-
|
|
7186
|
-
|
|
8704
|
+
whisper_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
|
|
8705
|
+
whisper_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
|
|
7187
8706
|
--bt_row_idx;
|
|
7188
8707
|
|
|
7189
|
-
int32_t t =
|
|
8708
|
+
int32_t t = whisper_get_i32_nd(trace, i, j, 0, 0);
|
|
7190
8709
|
if (t == 0) {
|
|
7191
8710
|
--i;
|
|
7192
8711
|
--j;
|
|
@@ -7207,8 +8726,8 @@ static wsp_ggml_tensor * dtw_and_backtrace(wsp_ggml_context * ctx, wsp_ggml_tens
|
|
|
7207
8726
|
wsp_ggml_tensor * r = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, 2, result_n_cols);
|
|
7208
8727
|
for (int64_t i = 0; i < 2; ++i) {
|
|
7209
8728
|
for (int64_t j = 0; j < result_n_cols; ++j) {
|
|
7210
|
-
int32_t v =
|
|
7211
|
-
|
|
8729
|
+
int32_t v = whisper_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
|
|
8730
|
+
whisper_set_i32_nd(r, i, j, 0, 0, v);
|
|
7212
8731
|
}
|
|
7213
8732
|
}
|
|
7214
8733
|
|
|
@@ -7243,11 +8762,11 @@ static void median_filter(struct wsp_ggml_tensor * dst , const struct wsp_ggml_t
|
|
|
7243
8762
|
idx = 2*(a->ne[2] - 1) - idx;
|
|
7244
8763
|
}
|
|
7245
8764
|
|
|
7246
|
-
filter.push_back(
|
|
8765
|
+
filter.push_back(whisper_get_f32_nd(a, i, j, idx, 0));
|
|
7247
8766
|
}
|
|
7248
8767
|
std::sort(filter.begin(), filter.end());
|
|
7249
8768
|
const float v = filter[filter.size()/2];
|
|
7250
|
-
|
|
8769
|
+
whisper_set_f32_nd(dst, i, j, k, 0, v);
|
|
7251
8770
|
filter.clear();
|
|
7252
8771
|
}
|
|
7253
8772
|
}
|
|
@@ -7369,7 +8888,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
|
7369
8888
|
// Compute
|
|
7370
8889
|
struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(gctx);
|
|
7371
8890
|
wsp_ggml_build_forward_expand(gf, w);
|
|
7372
|
-
|
|
8891
|
+
|
|
8892
|
+
wsp_ggml_backend_ptr backend { wsp_ggml_backend_init_by_type(WSP_GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
|
|
8893
|
+
wsp_ggml_backend_graph_compute(backend.get(), gf);
|
|
7373
8894
|
|
|
7374
8895
|
wsp_ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
|
|
7375
8896
|
|
|
@@ -7378,9 +8899,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
|
7378
8899
|
auto seg_i = state->result_all.begin() + i_segment;
|
|
7379
8900
|
auto tok_i = seg_i->tokens.begin();
|
|
7380
8901
|
for (int i = 0; i < alignment->ne[1]; ++i) {
|
|
7381
|
-
int32_t v =
|
|
8902
|
+
int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0);
|
|
7382
8903
|
if (v != last_v) {
|
|
7383
|
-
int32_t time_index =
|
|
8904
|
+
int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
|
|
7384
8905
|
int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
|
|
7385
8906
|
last_v = v;
|
|
7386
8907
|
|
|
@@ -7418,6 +8939,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
|
7418
8939
|
// Install the logging callback and its user data in the global state and
// forward the same pair to wsp_ggml_log_set. A null `log_callback` selects
// whisper_log_callback_default.
void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
    if (log_callback) {
        g_state.log_callback = log_callback;
    } else {
        g_state.log_callback = whisper_log_callback_default;
    }
    g_state.log_callback_user_data = user_data;

    // Hand the stored callback and user data on to ggml's logger as well
    wsp_ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
}
|
|
7422
8944
|
|
|
7423
8945
|
WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
|
|
@@ -7441,6 +8963,11 @@ static void whisper_log_internal(wsp_ggml_log_level level, const char * format,
|
|
|
7441
8963
|
// Default log sink: writes `text` to stderr and flushes it.
// Unless WHISPER_DEBUG is defined at compile time, messages at
// WSP_GGML_LOG_LEVEL_DEBUG are dropped. `user_data` is ignored.
static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data) {
    (void) user_data;
#ifdef WHISPER_DEBUG
    // Debug builds print every level, so `level` is otherwise unused
    (void) level;
    const bool suppressed = false;
#else
    const bool suppressed = (level == WSP_GGML_LOG_LEVEL_DEBUG);
#endif
    if (!suppressed) {
        fputs(text, stderr);
        fflush(stderr);
    }
}
|