whisper.rn 0.4.0-rc.9 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183):
  1. package/README.md +5 -1
  2. package/android/build.gradle +12 -3
  3. package/android/src/main/CMakeLists.txt +43 -13
  4. package/android/src/main/java/com/rnwhisper/WhisperContext.java +33 -35
  5. package/android/src/main/jni.cpp +9 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  12. package/cpp/coreml/whisper-compat.h +10 -0
  13. package/cpp/coreml/whisper-compat.m +35 -0
  14. package/cpp/coreml/whisper-decoder-impl.h +27 -15
  15. package/cpp/coreml/whisper-decoder-impl.m +36 -10
  16. package/cpp/coreml/whisper-encoder-impl.h +21 -9
  17. package/cpp/coreml/whisper-encoder-impl.m +29 -3
  18. package/cpp/ggml-alloc.c +39 -37
  19. package/cpp/ggml-alloc.h +1 -1
  20. package/cpp/ggml-backend-impl.h +55 -27
  21. package/cpp/ggml-backend-reg.cpp +591 -0
  22. package/cpp/ggml-backend.cpp +336 -955
  23. package/cpp/ggml-backend.h +70 -42
  24. package/cpp/ggml-common.h +57 -49
  25. package/cpp/ggml-cpp.h +39 -0
  26. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  27. package/cpp/ggml-cpu/amx/amx.h +8 -0
  28. package/cpp/ggml-cpu/amx/common.h +91 -0
  29. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  30. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  31. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  32. package/cpp/ggml-cpu/arch/arm/quants.c +4113 -0
  33. package/cpp/ggml-cpu/arch/arm/repack.cpp +2162 -0
  34. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  35. package/cpp/ggml-cpu/arch/x86/quants.c +4310 -0
  36. package/cpp/ggml-cpu/arch/x86/repack.cpp +3284 -0
  37. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  38. package/cpp/ggml-cpu/binary-ops.cpp +158 -0
  39. package/cpp/ggml-cpu/binary-ops.h +16 -0
  40. package/cpp/ggml-cpu/common.h +72 -0
  41. package/cpp/ggml-cpu/ggml-cpu-impl.h +511 -0
  42. package/cpp/ggml-cpu/ggml-cpu.c +3473 -0
  43. package/cpp/ggml-cpu/ggml-cpu.cpp +671 -0
  44. package/cpp/ggml-cpu/ops.cpp +9085 -0
  45. package/cpp/ggml-cpu/ops.h +111 -0
  46. package/cpp/ggml-cpu/quants.c +1157 -0
  47. package/cpp/ggml-cpu/quants.h +89 -0
  48. package/cpp/ggml-cpu/repack.cpp +1570 -0
  49. package/cpp/ggml-cpu/repack.h +98 -0
  50. package/cpp/ggml-cpu/simd-mappings.h +1006 -0
  51. package/cpp/ggml-cpu/traits.cpp +36 -0
  52. package/cpp/ggml-cpu/traits.h +38 -0
  53. package/cpp/ggml-cpu/unary-ops.cpp +186 -0
  54. package/cpp/ggml-cpu/unary-ops.h +28 -0
  55. package/cpp/ggml-cpu/vec.cpp +321 -0
  56. package/cpp/ggml-cpu/vec.h +973 -0
  57. package/cpp/ggml-cpu.h +143 -0
  58. package/cpp/ggml-impl.h +417 -23
  59. package/cpp/ggml-metal-impl.h +622 -0
  60. package/cpp/ggml-metal.h +9 -9
  61. package/cpp/ggml-metal.m +3451 -1344
  62. package/cpp/ggml-opt.cpp +1037 -0
  63. package/cpp/ggml-opt.h +237 -0
  64. package/cpp/ggml-quants.c +296 -10818
  65. package/cpp/ggml-quants.h +78 -125
  66. package/cpp/ggml-threading.cpp +12 -0
  67. package/cpp/ggml-threading.h +14 -0
  68. package/cpp/ggml-whisper-sim.metallib +0 -0
  69. package/cpp/ggml-whisper.metallib +0 -0
  70. package/cpp/ggml.c +4633 -21450
  71. package/cpp/ggml.h +320 -661
  72. package/cpp/gguf.cpp +1347 -0
  73. package/cpp/gguf.h +202 -0
  74. package/cpp/rn-whisper.cpp +4 -11
  75. package/cpp/whisper-arch.h +197 -0
  76. package/cpp/whisper.cpp +2022 -495
  77. package/cpp/whisper.h +75 -18
  78. package/ios/CMakeLists.txt +95 -0
  79. package/ios/RNWhisper.h +5 -0
  80. package/ios/RNWhisperAudioUtils.m +4 -0
  81. package/ios/RNWhisperContext.h +5 -0
  82. package/ios/RNWhisperContext.mm +4 -2
  83. package/ios/rnwhisper.xcframework/Info.plist +74 -0
  84. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  85. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  86. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  87. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  88. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  89. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  90. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  91. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  92. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  93. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  94. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  95. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  96. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  97. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  98. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  99. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  100. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  101. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  102. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  103. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  104. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  105. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  106. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  107. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  108. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  109. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  110. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  111. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  112. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  113. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  114. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  115. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  116. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  117. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  118. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  119. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  120. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  121. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  122. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  123. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  124. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  125. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  126. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  127. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  128. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  129. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  130. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  131. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  132. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  133. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  134. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  135. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  136. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  137. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  138. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  139. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  140. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  141. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  142. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  143. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  144. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  145. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  146. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  147. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  148. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  149. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  150. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  151. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  152. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  153. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  154. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  155. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  156. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  157. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  158. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  159. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  160. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  161. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  162. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  163. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  164. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  165. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  166. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  167. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  168. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  169. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  170. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  171. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  172. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  173. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  174. package/jest/mock.js +5 -0
  175. package/lib/commonjs/version.json +1 -1
  176. package/lib/module/version.json +1 -1
  177. package/package.json +10 -6
  178. package/src/version.json +1 -1
  179. package/whisper-rn.podspec +11 -18
  180. package/cpp/README.md +0 -4
  181. package/cpp/ggml-aarch64.c +0 -3209
  182. package/cpp/ggml-aarch64.h +0 -39
  183. package/cpp/ggml-cpu-impl.h +0 -614
package/cpp/whisper.cpp CHANGED
@@ -1,75 +1,52 @@
1
1
  #include "whisper.h"
2
+ #include "whisper-arch.h"
3
+
4
+ #include "ggml.h"
5
+ #include "ggml-cpp.h"
6
+ #include "ggml-alloc.h"
7
+ #include "ggml-backend.h"
2
8
 
3
9
  #ifdef WHISPER_USE_COREML
4
10
  #include "coreml/whisper-encoder.h"
5
11
  #endif
6
12
 
7
- #ifdef WSP_GGML_USE_METAL
8
- #include "ggml-metal.h"
9
- #endif
10
-
11
- #ifdef WSP_GGML_USE_CUDA
12
- #include "ggml-cuda.h"
13
- #endif
14
-
15
- #ifdef WSP_GGML_USE_SYCL
16
- #include "ggml-sycl.h"
17
- #endif
18
-
19
- #ifdef WSP_GGML_USE_VULKAN
20
- #include "ggml-vulkan.h"
21
- #endif
22
-
23
- #ifdef WSP_GGML_USE_BLAS
24
- #include "ggml-blas.h"
25
- #endif
26
-
27
13
  #ifdef WHISPER_USE_OPENVINO
28
14
  #include "openvino/whisper-openvino-encoder.h"
29
15
  #endif
30
16
 
31
- #ifdef WSP_GGML_USE_CANN
32
- #include "ggml-cann.h"
33
- #endif
34
-
35
- #include "ggml.h"
36
- #include "ggml-alloc.h"
37
- #include "ggml-backend.h"
38
-
39
17
  #include <atomic>
40
18
  #include <algorithm>
41
19
  #include <cassert>
20
+ #include <cfloat>
42
21
  #define _USE_MATH_DEFINES
43
22
  #include <cmath>
44
- #include <cstdio>
23
+ #include <climits>
24
+ #include <codecvt>
45
25
  #include <cstdarg>
26
+ #include <cstdio>
46
27
  #include <cstring>
47
28
  #include <fstream>
29
+ #include <functional>
48
30
  #include <map>
31
+ #include <mutex>
32
+ #include <random>
33
+ #include <regex>
49
34
  #include <set>
50
35
  #include <string>
51
36
  #include <thread>
52
37
  #include <vector>
53
- #include <regex>
54
- #include <random>
55
- #include <functional>
56
- #include <codecvt>
57
-
58
- #if defined(_MSC_VER)
59
- #pragma warning(disable: 4244 4267) // possible loss of data
60
- #endif
61
-
62
- #if defined(WSP_GGML_BIG_ENDIAN)
63
- #include <bit>
64
38
 
39
+ #if defined(WHISPER_BIG_ENDIAN)
65
40
  template<typename T>
66
41
  static T byteswap(T value) {
67
- return std::byteswap(value);
68
- }
69
-
70
- template<>
71
- float byteswap(float value) {
72
- return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value)));
42
+ T value_swapped;
43
+ char * source = reinterpret_cast<char *>(&value);
44
+ char * target = reinterpret_cast<char *>(&value_swapped);
45
+ int size = sizeof(T);
46
+ for (int i = 0; i < size; i++) {
47
+ target[size - 1 - i] = source[i];
48
+ }
49
+ return value_swapped;
73
50
  }
74
51
 
75
52
  template<typename T>
@@ -105,14 +82,14 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) {
105
82
  }
106
83
 
107
84
  #define BYTESWAP_VALUE(d) d = byteswap(d)
108
- #define BYTESWAP_FILTERS(f) \
85
+ #define BYTESWAP_FILTERS(f) \
109
86
  do { \
110
87
  for (auto & datum : f.data) { \
111
88
  datum = byteswap(datum); \
112
89
  } \
113
90
  } while (0)
114
- #define BYTESWAP_TENSOR(t) \
115
- do { \
91
+ #define BYTESWAP_TENSOR(t) \
92
+ do { \
116
93
  byteswap_tensor(t); \
117
94
  } while (0)
118
95
  #else
@@ -163,51 +140,118 @@ static void whisper_log_callback_default(wsp_ggml_log_level level, const char *
163
140
  #define WHISPER_MAX_DECODERS 8
164
141
  #define WHISPER_MAX_NODES 4096
165
142
 
143
+ static std::string format(const char * fmt, ...) {
144
+ va_list ap;
145
+ va_list ap2;
146
+ va_start(ap, fmt);
147
+ va_copy(ap2, ap);
148
+ int size = vsnprintf(NULL, 0, fmt, ap);
149
+ WSP_GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
150
+ std::vector<char> buf(size + 1);
151
+ int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
152
+ WSP_GGML_ASSERT(size2 == size);
153
+ va_end(ap2);
154
+ va_end(ap);
155
+ return std::string(buf.data(), size);
156
+ }
157
+
166
158
  //
167
159
  // ggml helpers
168
160
  //
169
161
 
170
162
  static bool wsp_ggml_graph_compute_helper(
171
163
  struct wsp_ggml_cgraph * graph,
172
- std::vector<uint8_t> & buf,
173
164
  int n_threads,
174
165
  wsp_ggml_abort_callback abort_callback,
175
166
  void * abort_callback_data) {
176
- struct wsp_ggml_cplan plan = wsp_ggml_graph_plan(graph, n_threads, nullptr);
167
+ wsp_ggml_backend_ptr backend { wsp_ggml_backend_init_by_type(WSP_GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
177
168
 
178
- plan.abort_callback = abort_callback;
179
- plan.abort_callback_data = abort_callback_data;
169
+ auto * reg = wsp_ggml_backend_dev_backend_reg(wsp_ggml_backend_get_device(backend.get()));
180
170
 
181
- if (plan.work_size > 0) {
182
- buf.resize(plan.work_size);
183
- plan.work_data = buf.data();
171
+ auto * set_abort_callback_fn = (wsp_ggml_backend_set_abort_callback_t) wsp_ggml_backend_reg_get_proc_address(reg, "wsp_ggml_backend_set_abort_callback");
172
+ if (set_abort_callback_fn) {
173
+ set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data);
184
174
  }
185
175
 
186
- return wsp_ggml_graph_compute(graph, &plan);
176
+ auto wsp_ggml_backend_set_n_threads_fn = (wsp_ggml_backend_set_n_threads_t) wsp_ggml_backend_reg_get_proc_address(reg, "wsp_ggml_backend_set_n_threads");
177
+ if (wsp_ggml_backend_set_n_threads_fn) {
178
+ wsp_ggml_backend_set_n_threads_fn(backend.get(), n_threads);
179
+ }
180
+
181
+ return wsp_ggml_backend_graph_compute(backend.get(), graph) == WSP_GGML_STATUS_SUCCESS;
187
182
  }
188
183
 
189
184
  static bool wsp_ggml_graph_compute_helper(
190
185
  wsp_ggml_backend_sched_t sched,
191
186
  struct wsp_ggml_cgraph * graph,
192
- int n_threads) {
193
-
187
+ int n_threads,
188
+ bool sched_reset = true) {
194
189
  for (int i = 0; i < wsp_ggml_backend_sched_get_n_backends(sched); ++i) {
195
190
  wsp_ggml_backend_t backend = wsp_ggml_backend_sched_get_backend(sched, i);
196
- if (wsp_ggml_backend_is_cpu(backend)) {
197
- wsp_ggml_backend_cpu_set_n_threads(backend, n_threads);
198
- }
199
- #ifdef WSP_GGML_USE_BLAS
200
- if (wsp_ggml_backend_is_blas(backend)) {
201
- wsp_ggml_backend_blas_set_n_threads(backend, n_threads);
191
+ wsp_ggml_backend_dev_t dev = wsp_ggml_backend_get_device(backend);
192
+ wsp_ggml_backend_reg_t reg = dev ? wsp_ggml_backend_dev_backend_reg(dev) : nullptr;
193
+
194
+ auto * fn_set_n_threads = (wsp_ggml_backend_set_n_threads_t) wsp_ggml_backend_reg_get_proc_address(reg, "wsp_ggml_backend_set_n_threads");
195
+ if (fn_set_n_threads) {
196
+ fn_set_n_threads(backend, n_threads);
202
197
  }
203
- #endif
204
198
  }
205
199
 
206
- bool t = wsp_ggml_backend_sched_graph_compute(sched, graph) == WSP_GGML_STATUS_SUCCESS;
207
- wsp_ggml_backend_sched_reset(sched);
200
+ const bool t = (wsp_ggml_backend_sched_graph_compute(sched, graph) == WSP_GGML_STATUS_SUCCESS);
201
+
202
+ if (!t || sched_reset) {
203
+ wsp_ggml_backend_sched_reset(sched);
204
+ }
205
+
206
+ return t;
207
+ }
208
+
209
+ // TODO: move these functions to ggml-base with support for ggml-backend?
210
+
211
+ static wsp_ggml_tensor * whisper_set_f32(struct wsp_ggml_tensor * t, float v) {
212
+ WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_F32);
213
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(t));
214
+ size_t nels = wsp_ggml_nelements(t);
215
+ for (size_t i = 0; i < nels; ++i) {
216
+ ((float *) t->data)[i] = v;
217
+ }
218
+ return t;
219
+ }
220
+
221
+ static wsp_ggml_tensor * whisper_set_i32(struct wsp_ggml_tensor * t, int32_t v) {
222
+ WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_I32);
223
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(t));
224
+ size_t nels = wsp_ggml_nelements(t);
225
+ for (size_t i = 0; i < nels; ++i) {
226
+ ((int32_t *) t->data)[i] = v;
227
+ }
208
228
  return t;
209
229
  }
210
230
 
231
+ static float whisper_get_f32_nd(const struct wsp_ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
232
+ WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_F32);
233
+ void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
234
+ return *(float *) data;
235
+ }
236
+
237
+ static void whisper_set_f32_nd(struct wsp_ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) {
238
+ WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_F32);
239
+ void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
240
+ *(float *) data = v;
241
+ }
242
+
243
+ static int32_t whisper_get_i32_nd(const struct wsp_ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
244
+ WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_I32);
245
+ void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
246
+ return *(int32_t *) data;
247
+ }
248
+
249
+ static void whisper_set_i32_nd(struct wsp_ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) {
250
+ WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_I32);
251
+ void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
252
+ *(int32_t *) data = v;
253
+ }
254
+
211
255
  // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
212
256
  // the idea is to represent the original matrix multiplication:
213
257
  //
@@ -451,6 +495,7 @@ struct whisper_segment {
451
495
  int64_t t1;
452
496
 
453
497
  std::string text;
498
+ float no_speech_prob;
454
499
 
455
500
  std::vector<whisper_token_data> tokens;
456
501
 
@@ -543,7 +588,7 @@ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<
543
588
  auto & sched = allocr.sched;
544
589
  auto & meta = allocr.meta;
545
590
 
546
- sched = wsp_ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
591
+ sched = wsp_ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false, true);
547
592
 
548
593
  meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead());
549
594
 
@@ -739,10 +784,10 @@ struct whisper_model {
739
784
  std::vector<whisper_layer_decoder> layers_decoder;
740
785
 
741
786
  // ggml context that contains all the meta information about the model tensors
742
- struct wsp_ggml_context * ctx = nullptr;
787
+ std::vector<wsp_ggml_context *> ctxs;
743
788
 
744
789
  // the model backend data is read-only and can be shared between processors
745
- wsp_ggml_backend_buffer_t buffer = nullptr;
790
+ std::vector<wsp_ggml_backend_buffer_t> buffers;
746
791
 
747
792
  // tensors
748
793
  int n_loaded;
@@ -814,6 +859,11 @@ struct whisper_aheads_masks {
814
859
  wsp_ggml_backend_buffer_t buffer = nullptr;
815
860
  };
816
861
 
862
+ struct vad_time_mapping {
863
+ int64_t processed_time; // Time in processed (VAD) audio
864
+ int64_t original_time; // Corresponding time in original audio
865
+ };
866
+
817
867
  struct whisper_state {
818
868
  int64_t t_sample_us = 0;
819
869
  int64_t t_encode_us = 0;
@@ -890,6 +940,7 @@ struct whisper_state {
890
940
  whisper_token tid_last;
891
941
 
892
942
  std::vector<float> energy; // PCM signal energy
943
+ float no_speech_prob = 0.0f;
893
944
 
894
945
  // [EXPERIMENTAL] Token-level timestamps with DTW
895
946
  whisper_aheads_masks aheads_masks;
@@ -898,6 +949,19 @@ struct whisper_state {
898
949
 
899
950
  // [EXPERIMENTAL] speed-up techniques
900
951
  int32_t exp_n_audio_ctx = 0; // 0 - use default
952
+
953
+ whisper_vad_context * vad_context = nullptr;
954
+
955
+ struct vad_segment_info {
956
+ int64_t orig_start;
957
+ int64_t orig_end;
958
+ int64_t vad_start;
959
+ int64_t vad_end;
960
+ };
961
+ std::vector<vad_segment_info> vad_segments;
962
+ bool has_vad_segments = false;
963
+
964
+ std::vector<vad_time_mapping> vad_mapping_table;
901
965
  };
902
966
 
903
967
  struct whisper_context {
@@ -1254,65 +1318,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
1254
1318
  }
1255
1319
 
1256
1320
  static wsp_ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
1257
- wsp_ggml_backend_t result = NULL;
1258
-
1259
1321
  wsp_ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
1260
1322
 
1261
- #ifdef WSP_GGML_USE_CUDA
1262
- if (params.use_gpu) {
1263
- WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
1264
- result = wsp_ggml_backend_cuda_init(params.gpu_device);
1265
- if (!result) {
1266
- WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cuda_init() failed\n", __func__);
1267
- }
1268
- }
1269
- #endif
1323
+ wsp_ggml_backend_dev_t dev = nullptr;
1270
1324
 
1271
- #ifdef WSP_GGML_USE_METAL
1325
+ int cnt = 0;
1272
1326
  if (params.use_gpu) {
1273
- WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
1274
- result = wsp_ggml_backend_metal_init();
1275
- if (!result) {
1276
- WHISPER_LOG_ERROR("%s: wsp_ggml_backend_metal_init() failed\n", __func__);
1277
- } else if (!wsp_ggml_backend_metal_supports_family(result, 7)) {
1278
- WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
1279
- wsp_ggml_backend_free(result);
1280
- result = NULL;
1281
- }
1282
- }
1283
- #endif
1327
+ for (size_t i = 0; i < wsp_ggml_backend_dev_count(); ++i) {
1328
+ wsp_ggml_backend_dev_t dev_cur = wsp_ggml_backend_dev_get(i);
1329
+ if (wsp_ggml_backend_dev_type(dev_cur) == WSP_GGML_BACKEND_DEVICE_TYPE_GPU) {
1330
+ if (cnt == 0 || cnt == params.gpu_device) {
1331
+ dev = dev_cur;
1332
+ }
1284
1333
 
1285
- #ifdef WSP_GGML_USE_SYCL
1286
- if (params.use_gpu) {
1287
- WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
1288
- result = wsp_ggml_backend_sycl_init(params.gpu_device);
1289
- if (!result) {
1290
- WHISPER_LOG_ERROR("%s: wsp_ggml_backend_sycl_init() failed\n", __func__);
1334
+ if (++cnt > params.gpu_device) {
1335
+ break;
1336
+ }
1337
+ }
1291
1338
  }
1292
1339
  }
1293
- #endif
1294
1340
 
1295
- #ifdef WSP_GGML_USE_VULKAN
1296
- if (params.use_gpu) {
1297
- WHISPER_LOG_INFO("%s: using Vulkan backend\n", __func__);
1298
- result = wsp_ggml_backend_vk_init(params.gpu_device);
1299
- if (!result) {
1300
- WHISPER_LOG_ERROR("%s: wsp_ggml_backend_vk_init() failed\n", __func__);
1301
- }
1341
+ if (dev == nullptr) {
1342
+ WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
1343
+ return nullptr;
1302
1344
  }
1303
- #endif
1304
1345
 
1305
- #ifdef WSP_GGML_USE_CANN
1306
- if (params.use_gpu) {
1307
- WHISPER_LOG_INFO("%s: using CANN backend\n", __func__);
1308
- result = wsp_ggml_backend_cann_init(params.gpu_device);
1309
- if (!result) {
1310
- WHISPER_LOG_ERROR("%s: wsp_ggml_backend_cann_init() failed\n", __func__);
1311
- }
1346
+ WHISPER_LOG_INFO("%s: using %s backend\n", __func__, wsp_ggml_backend_dev_name(dev));
1347
+ wsp_ggml_backend_t result = wsp_ggml_backend_dev_init(dev, nullptr);
1348
+ if (!result) {
1349
+ WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, wsp_ggml_backend_dev_name(dev));
1312
1350
  }
1313
- #endif
1314
-
1315
- WSP_GGML_UNUSED(params);
1316
1351
 
1317
1352
  return result;
1318
1353
  }
@@ -1326,53 +1361,132 @@ static std::vector<wsp_ggml_backend_t> whisper_backend_init(const whisper_contex
1326
1361
  result.push_back(backend_gpu);
1327
1362
  }
1328
1363
 
1329
- #ifdef WSP_GGML_USE_BLAS
1330
- {
1331
- WHISPER_LOG_INFO("%s: using BLAS backend\n", __func__);
1332
- wsp_ggml_backend_t backend_blas = wsp_ggml_backend_blas_init();
1333
- if (!backend_blas) {
1334
- WHISPER_LOG_ERROR("%s: wsp_ggml_backend_blas_init() failed\n", __func__);
1335
- } else {
1336
- result.push_back(backend_blas);
1364
+ // ACCEL backends
1365
+ for (size_t i = 0; i < wsp_ggml_backend_dev_count(); ++i) {
1366
+ wsp_ggml_backend_dev_t dev = wsp_ggml_backend_dev_get(i);
1367
+ if (wsp_ggml_backend_dev_type(dev) == WSP_GGML_BACKEND_DEVICE_TYPE_ACCEL) {
1368
+ WHISPER_LOG_INFO("%s: using %s backend\n", __func__, wsp_ggml_backend_dev_name(dev));
1369
+ wsp_ggml_backend_t backend = wsp_ggml_backend_dev_init(dev, nullptr);
1370
+ if (!backend) {
1371
+ WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, wsp_ggml_backend_dev_name(dev));
1372
+ continue;
1373
+ }
1374
+ result.push_back(backend);
1337
1375
  }
1338
1376
  }
1339
- #endif
1340
1377
 
1341
- WSP_GGML_UNUSED(params);
1342
-
1343
- result.push_back(wsp_ggml_backend_cpu_init());
1378
+ wsp_ggml_backend_t backend_cpu = wsp_ggml_backend_init_by_type(WSP_GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
1379
+ if (backend_cpu == nullptr) {
1380
+ throw std::runtime_error("failed to initialize CPU backend");
1381
+ }
1382
+ result.push_back(backend_cpu);
1344
1383
 
1345
1384
  return result;
1346
1385
  }
1347
1386
 
1348
- static wsp_ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
1349
- wsp_ggml_backend_buffer_type_t result = nullptr;
1387
+ using buft_list_t = std::vector<std::pair<wsp_ggml_backend_dev_t, wsp_ggml_backend_buffer_type_t>>;
1350
1388
 
1351
- params.use_gpu || (result = wsp_ggml_backend_cpu_buffer_type());
1389
+ static buft_list_t make_buft_list(whisper_context_params & params) {
1390
+ // Prio order: GPU -> CPU Extra -> CPU
1391
+ buft_list_t buft_list;
1352
1392
 
1353
- #ifdef WSP_GGML_USE_CUDA
1354
- result || (result = wsp_ggml_backend_cuda_buffer_type(params.gpu_device));
1355
- #endif
1393
+ // GPU
1394
+ if (params.use_gpu) {
1395
+ int cnt = 0;
1396
+ for (size_t i = 0; i < wsp_ggml_backend_dev_count(); ++i) {
1397
+ wsp_ggml_backend_dev_t dev = wsp_ggml_backend_dev_get(i);
1398
+ if (wsp_ggml_backend_dev_type(dev) == WSP_GGML_BACKEND_DEVICE_TYPE_GPU) {
1399
+ if (cnt == 0 || cnt == params.gpu_device) {
1400
+ auto * buft = wsp_ggml_backend_dev_buffer_type(dev);
1401
+ if (buft) {
1402
+ buft_list.emplace_back(dev, buft);
1403
+ }
1404
+ }
1356
1405
 
1357
- #ifdef WSP_GGML_USE_METAL
1358
- result || (result = wsp_ggml_backend_metal_buffer_type());
1359
- #endif
1406
+ if (++cnt > params.gpu_device) {
1407
+ break;
1408
+ }
1409
+ }
1410
+ }
1411
+ }
1360
1412
 
1361
- #ifdef WSP_GGML_USE_SYCL
1362
- result || (result = wsp_ggml_backend_sycl_buffer_type(params.gpu_device));
1363
- #endif
1413
+ // CPU Extra
1414
+ auto * cpu_dev = wsp_ggml_backend_dev_by_type(WSP_GGML_BACKEND_DEVICE_TYPE_CPU);
1415
+ auto * cpu_reg = wsp_ggml_backend_dev_backend_reg(cpu_dev);
1416
+ auto get_extra_bufts_fn = (wsp_ggml_backend_dev_get_extra_bufts_t)
1417
+ wsp_ggml_backend_reg_get_proc_address(cpu_reg, "wsp_ggml_backend_dev_get_extra_bufts");
1418
+ if (get_extra_bufts_fn) {
1419
+ wsp_ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
1420
+ while (extra_bufts && *extra_bufts) {
1421
+ buft_list.emplace_back(cpu_dev, *extra_bufts);
1422
+ ++extra_bufts;
1423
+ }
1424
+ }
1364
1425
 
1365
- #ifdef WSP_GGML_USE_VULKAN
1366
- result || (result = wsp_ggml_backend_vk_buffer_type(params.gpu_device));
1367
- #endif
1426
+ // CPU
1427
+ buft_list.emplace_back(cpu_dev, wsp_ggml_backend_cpu_buffer_type());
1368
1428
 
1369
- #ifdef WSP_GGML_USE_CANN
1370
- result || (result == wsp_ggml_backend_cann_buffer_type(params.gpu_device));
1371
- #endif
1429
+ return buft_list;
1430
+ }
1372
1431
 
1373
- result || (result = wsp_ggml_backend_cpu_buffer_type());
1432
+ static bool weight_buft_supported(const whisper_hparams & hparams, wsp_ggml_tensor * w, wsp_ggml_op op, wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_dev_t dev) {
1433
+ bool op_supported = true;
1374
1434
 
1375
- return result;
1435
+ if (wsp_ggml_backend_dev_type(dev) == WSP_GGML_BACKEND_DEVICE_TYPE_GPU ||
1436
+ (wsp_ggml_backend_dev_type(dev) == WSP_GGML_BACKEND_DEVICE_TYPE_CPU && buft == wsp_ggml_backend_cpu_buffer_type())) {
1437
+ // GPU and default CPU backend support all operators
1438
+ op_supported = true;
1439
+ } else {
1440
+ switch (op) {
1441
+ // The current extra_buffer_type implementations only support WSP_GGML_OP_MUL_MAT
1442
+ case WSP_GGML_OP_MUL_MAT: {
1443
+ wsp_ggml_init_params params = {
1444
+ /*.mem_size =*/ 2 * wsp_ggml_tensor_overhead(),
1445
+ /*.mem_buffer =*/ nullptr,
1446
+ /*.no_alloc =*/ true,
1447
+ };
1448
+
1449
+ wsp_ggml_context_ptr ctx_ptr { wsp_ggml_init(params) };
1450
+ if (!ctx_ptr) {
1451
+ throw std::runtime_error("failed to create ggml context");
1452
+ }
1453
+ wsp_ggml_context * ctx = ctx_ptr.get();
1454
+
1455
+ wsp_ggml_tensor * op_tensor = nullptr;
1456
+
1457
+ int64_t n_ctx = hparams.n_audio_ctx;
1458
+ wsp_ggml_tensor * b = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
1459
+ op_tensor = wsp_ggml_mul_mat(ctx, w, b);
1460
+
1461
+ // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
1462
+ WSP_GGML_ASSERT(w->buffer == nullptr);
1463
+ w->buffer = wsp_ggml_backend_buft_alloc_buffer(buft, 0);
1464
+ op_supported = wsp_ggml_backend_dev_supports_op(dev, op_tensor);
1465
+ wsp_ggml_backend_buffer_free(w->buffer);
1466
+ w->buffer = nullptr;
1467
+ break;
1468
+ }
1469
+ default: {
1470
+ op_supported = false;
1471
+ break;
1472
+ }
1473
+ };
1474
+ }
1475
+
1476
+ return op_supported;
1477
+ }
1478
+
1479
+ static wsp_ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, wsp_ggml_tensor * w, wsp_ggml_op op, buft_list_t buft_list) {
1480
+ WSP_GGML_ASSERT(!buft_list.empty());
1481
+ for (const auto & p : buft_list) {
1482
+ wsp_ggml_backend_dev_t dev = p.first;
1483
+ wsp_ggml_backend_buffer_type_t buft = p.second;
1484
+ if (weight_buft_supported(hparams, w, op, buft, dev)) {
1485
+ return buft;
1486
+ }
1487
+ }
1488
+
1489
+ return nullptr;
1376
1490
  }
1377
1491
 
1378
1492
  // load the model from a ggml file
@@ -1581,31 +1695,65 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1581
1695
  const wsp_ggml_type wtype = wctx.wtype;
1582
1696
  const wsp_ggml_type vtype = wctx.wtype == WSP_GGML_TYPE_F32 ? WSP_GGML_TYPE_F32 : WSP_GGML_TYPE_F16; // conv type
1583
1697
 
1584
- // create the ggml context
1585
- {
1586
- const auto & hparams = model.hparams;
1698
+ const auto & hparams = model.hparams;
1587
1699
 
1588
- const int n_audio_layer = hparams.n_audio_layer;
1589
- const int n_text_layer = hparams.n_text_layer;
1700
+ const int n_audio_layer = hparams.n_audio_layer;
1701
+ const int n_text_layer = hparams.n_text_layer;
1590
1702
 
1591
- const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
1703
+ const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
1592
1704
 
1593
- struct wsp_ggml_init_params params = {
1594
- /*.mem_size =*/ n_tensors*wsp_ggml_tensor_overhead(),
1595
- /*.mem_buffer =*/ nullptr,
1596
- /*.no_alloc =*/ true,
1597
- };
1705
+ std::map<wsp_ggml_backend_buffer_type_t, wsp_ggml_context *> ctx_map;
1706
+ auto get_ctx = [&](wsp_ggml_backend_buffer_type_t buft) -> wsp_ggml_context * {
1707
+ auto it = ctx_map.find(buft);
1708
+ if (it == ctx_map.end()) {
1709
+ wsp_ggml_init_params params = {
1710
+ /*.mem_size =*/ n_tensors * wsp_ggml_tensor_overhead(),
1711
+ /*.mem_buffer =*/ nullptr,
1712
+ /*.no_alloc =*/ true,
1713
+ };
1598
1714
 
1599
- model.ctx = wsp_ggml_init(params);
1600
- if (!model.ctx) {
1601
- WHISPER_LOG_ERROR("%s: wsp_ggml_init() failed\n", __func__);
1602
- return false;
1715
+ wsp_ggml_context * ctx = wsp_ggml_init(params);
1716
+ if (!ctx) {
1717
+ throw std::runtime_error("failed to create ggml context");
1718
+ }
1719
+
1720
+ ctx_map[buft] = ctx;
1721
+ model.ctxs.emplace_back(ctx);
1722
+
1723
+ return ctx;
1603
1724
  }
1604
- }
1725
+
1726
+ return it->second;
1727
+ };
1728
+
1729
+ // Create a list of available bufts, in priority order
1730
+ buft_list_t buft_list = make_buft_list(wctx.params);
1731
+
1732
+ auto create_tensor = [&](asr_tensor type, asr_system system, wsp_ggml_tensor * meta, int layer = 0) -> wsp_ggml_tensor * {
1733
+ wsp_ggml_op op = ASR_TENSOR_INFO.at(type);
1734
+ wsp_ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
1735
+ if (!buft) {
1736
+ throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type)));
1737
+ }
1738
+
1739
+ wsp_ggml_context * ctx = get_ctx(buft);
1740
+ wsp_ggml_tensor * tensor = wsp_ggml_dup_tensor(ctx, meta);
1741
+
1742
+ model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
1743
+
1744
+ return tensor;
1745
+ };
1746
+
1605
1747
 
1606
1748
  // prepare tensors for the weights
1607
1749
  {
1608
- auto & ctx = model.ctx;
1750
+ wsp_ggml_init_params params = {
1751
+ /*.mem_size =*/ n_tensors * wsp_ggml_tensor_overhead(),
1752
+ /*.mem_buffer =*/ nullptr,
1753
+ /*.no_alloc =*/ true,
1754
+ };
1755
+
1756
+ wsp_ggml_context * ctx = wsp_ggml_init(params);
1609
1757
 
1610
1758
  const auto & hparams = model.hparams;
1611
1759
 
@@ -1625,189 +1773,108 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1625
1773
  model.layers_decoder.resize(n_text_layer);
1626
1774
 
1627
1775
  // encoder
1628
- {
1629
- model.e_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx);
1630
-
1631
- model.e_conv_1_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
1632
- model.e_conv_1_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
1633
-
1634
- model.e_conv_2_w = wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
1635
- model.e_conv_2_b = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state);
1776
+ model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_audio_state, n_audio_ctx));
1636
1777
 
1637
- model.e_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
1638
- model.e_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
1778
+ model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state));
1779
+ model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state));
1639
1780
 
1640
- // map by name
1641
- model.tensors["encoder.positional_embedding"] = model.e_pe;
1781
+ model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state));
1782
+ model.e_conv_2_b = create_tensor(ASR_TENSOR_CONV2_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 1, n_audio_state));
1642
1783
 
1643
- model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
1644
- model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
1784
+ model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state));
1785
+ model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state));
1645
1786
 
1646
- model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
1647
- model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
1787
+ for (int i = 0; i < n_audio_layer; ++i) {
1788
+ auto & layer = model.layers_encoder[i];
1648
1789
 
1649
- model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
1650
- model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
1790
+ layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
1791
+ layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
1651
1792
 
1652
- for (int i = 0; i < n_audio_layer; ++i) {
1653
- auto & layer = model.layers_encoder[i];
1793
+ layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
1794
+ layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 4*n_audio_state), i);
1654
1795
 
1655
- layer.mlp_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
1656
- layer.mlp_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
1796
+ layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
1797
+ layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
1657
1798
 
1658
- layer.mlp_0_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
1659
- layer.mlp_0_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 4*n_audio_state);
1799
+ layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
1800
+ layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
1660
1801
 
1661
- layer.mlp_1_w = wsp_ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
1662
- layer.mlp_1_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
1802
+ layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1803
+ layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
1663
1804
 
1664
- layer.attn_ln_0_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
1665
- layer.attn_ln_0_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
1805
+ layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1666
1806
 
1667
- layer.attn_q_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1668
- layer.attn_q_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
1807
+ layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1808
+ layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
1669
1809
 
1670
- layer.attn_k_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1671
-
1672
- layer.attn_v_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1673
- layer.attn_v_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
1674
-
1675
- layer.attn_ln_1_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1676
- layer.attn_ln_1_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state);
1677
-
1678
- // map by name
1679
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
1680
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
1681
-
1682
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
1683
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
1684
-
1685
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
1686
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
1687
-
1688
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
1689
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
1690
-
1691
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
1692
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
1693
-
1694
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
1695
-
1696
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
1697
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
1698
-
1699
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
1700
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
1701
- }
1810
+ layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1811
+ layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_ENCODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_audio_state), i);
1702
1812
  }
1703
1813
 
1704
1814
  // decoder
1705
- {
1706
- model.d_pe = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_text_state, n_text_ctx);
1707
-
1708
- model.d_te = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
1709
-
1710
- model.d_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1711
- model.d_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1712
-
1713
- // map by name
1714
- model.tensors["decoder.positional_embedding"] = model.d_pe;
1715
-
1716
- model.tensors["decoder.token_embedding.weight"] = model.d_te;
1815
+ model.d_pe = create_tensor(ASR_TENSOR_DEC_POS_EMBD, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, n_text_state, n_text_ctx));
1717
1816
 
1718
- model.tensors["decoder.ln.weight"] = model.d_ln_w;
1719
- model.tensors["decoder.ln.bias"] = model.d_ln_b;
1817
+ model.d_te = create_tensor(ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab));
1720
1818
 
1721
- for (int i = 0; i < n_text_layer; ++i) {
1722
- auto & layer = model.layers_decoder[i];
1819
+ model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state));
1820
+ model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state));
1723
1821
 
1724
- layer.mlp_ln_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1725
- layer.mlp_ln_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1822
+ for (int i = 0; i < n_text_layer; ++i) {
1823
+ auto & layer = model.layers_decoder[i];
1726
1824
 
1727
- layer.mlp_0_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
1728
- layer.mlp_0_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 4*n_text_state);
1825
+ layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1826
+ layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1729
1827
 
1730
- layer.mlp_1_w = wsp_ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
1731
- layer.mlp_1_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1828
+ layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i);
1829
+ layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 4*n_text_state), i);
1732
1830
 
1733
- layer.attn_ln_0_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1734
- layer.attn_ln_0_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1831
+ layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i);
1832
+ layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1735
1833
 
1736
- layer.attn_q_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1737
- layer.attn_q_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1834
+ layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1835
+ layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1738
1836
 
1739
- layer.attn_k_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1837
+ layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1838
+ layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1740
1839
 
1741
- layer.attn_v_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1742
- layer.attn_v_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1840
+ layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1743
1841
 
1744
- layer.attn_ln_1_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1745
- layer.attn_ln_1_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1842
+ layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1843
+ layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1746
1844
 
1747
- layer.cross_attn_ln_0_w = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1748
- layer.cross_attn_ln_0_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1845
+ layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1846
+ layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_DECODER, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1749
1847
 
1750
- layer.cross_attn_q_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1751
- layer.cross_attn_q_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1848
+ layer.cross_attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1849
+ layer.cross_attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1752
1850
 
1753
- layer.cross_attn_k_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1851
+ layer.cross_attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1852
+ layer.cross_attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1754
1853
 
1755
- layer.cross_attn_v_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1756
- layer.cross_attn_v_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1854
+ layer.cross_attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1757
1855
 
1758
- layer.cross_attn_ln_1_w = wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1759
- layer.cross_attn_ln_1_b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state);
1856
+ layer.cross_attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1857
+ layer.cross_attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1760
1858
 
1761
- // map by name
1762
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
1763
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
1764
-
1765
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
1766
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
1767
-
1768
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
1769
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
1770
-
1771
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
1772
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
1773
-
1774
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
1775
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
1776
-
1777
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
1778
-
1779
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
1780
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
1781
-
1782
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
1783
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
1784
-
1785
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
1786
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
1787
-
1788
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
1789
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
1790
-
1791
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
1792
-
1793
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
1794
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
1795
-
1796
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
1797
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
1798
- }
1859
+ layer.cross_attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1860
+ layer.cross_attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_CROSS, wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, n_text_state), i);
1799
1861
  }
1862
+
1863
+ wsp_ggml_free(ctx);
1800
1864
  }
1801
1865
 
1802
1866
  // allocate tensors in the backend buffers
1803
- model.buffer = wsp_ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
1804
- if (!model.buffer) {
1805
- WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
1806
- return false;
1807
- }
1867
+ for (auto & p : ctx_map) {
1868
+ wsp_ggml_backend_buffer_type_t buft = p.first;
1869
+ wsp_ggml_context * ctx = p.second;
1870
+ wsp_ggml_backend_buffer_t buf = wsp_ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
1871
+ if (buf) {
1872
+ model.buffers.emplace_back(buf);
1808
1873
 
1809
- size_t size_main = wsp_ggml_backend_buffer_get_size(model.buffer);
1810
- WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, wsp_ggml_backend_buffer_name(model.buffer), size_main / 1e6);
1874
+ size_t size_main = wsp_ggml_backend_buffer_get_size(buf);
1875
+ WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, wsp_ggml_backend_buffer_name(buf), size_main / 1e6);
1876
+ }
1877
+ }
1811
1878
 
1812
1879
  // load weights
1813
1880
  {
@@ -1870,11 +1937,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1870
1937
  return false;
1871
1938
  }
1872
1939
 
1873
- //wsp_ggml_backend_t backend = wctx.backend;
1874
-
1875
- //printf("%s: [%5.5s] %s\n", __func__, wsp_ggml_backend_name(backend), name.c_str());
1876
-
1877
- if (wsp_ggml_backend_buffer_is_host(model.buffer)) {
1940
+ if (wsp_ggml_backend_buffer_is_host(tensor->buffer)) {
1878
1941
  // for the CPU and Metal backend, we can read directly into the tensor
1879
1942
  loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
1880
1943
  BYTESWAP_TENSOR(tensor);
@@ -1887,7 +1950,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1887
1950
  wsp_ggml_backend_tensor_set(tensor, read_buf.data(), 0, wsp_ggml_nbytes(tensor));
1888
1951
  }
1889
1952
 
1890
- //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], wsp_ggml_type_name((wsp_ggml_type) ttype), wsp_ggml_nbytes(tensor)/1e6);
1891
1953
  total_size += wsp_ggml_nbytes(tensor);
1892
1954
  model.n_loaded++;
1893
1955
  }
@@ -1902,7 +1964,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1902
1964
  }
1903
1965
  }
1904
1966
 
1905
- wsp_ggml_backend_buffer_set_usage(model.buffer, WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
1967
+ for (auto & buf : model.buffers) {
1968
+ wsp_ggml_backend_buffer_set_usage(buf, WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
1969
+ }
1906
1970
 
1907
1971
  wctx.t_load_us = wsp_ggml_time_us() - t_start_us;
1908
1972
 
@@ -3164,7 +3228,7 @@ static bool log_mel_spectrogram(
3164
3228
  std::vector<std::thread> workers(n_threads - 1);
3165
3229
  for (int iw = 0; iw < n_threads - 1; ++iw) {
3166
3230
  workers[iw] = std::thread(
3167
- log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
3231
+ log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
3168
3232
  n_samples + stage_2_pad, frame_size, frame_step, n_threads,
3169
3233
  std::cref(filters), std::ref(mel));
3170
3234
  }
@@ -3670,8 +3734,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
3670
3734
  WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
3671
3735
  WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
3672
3736
  WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
3673
-
3674
- // TODO: temporary call to force backend registry initialization
3737
+ WHISPER_LOG_INFO("%s: devices = %zu\n", __func__, wsp_ggml_backend_dev_count());
3675
3738
  WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, wsp_ggml_backend_reg_count());
3676
3739
 
3677
3740
  whisper_context * ctx = new whisper_context;
@@ -3792,15 +3855,24 @@ void whisper_free_state(struct whisper_state * state) {
3792
3855
  // [EXPERIMENTAL] Token-level timestamps with DTW
3793
3856
  aheads_masks_free(state->aheads_masks);
3794
3857
 
3858
+ if (state->vad_context != nullptr) {
3859
+ whisper_vad_free(state->vad_context);
3860
+ state->vad_context = nullptr;
3861
+ }
3862
+
3795
3863
  delete state;
3796
3864
  }
3797
3865
  }
3798
3866
 
3799
3867
  void whisper_free(struct whisper_context * ctx) {
3800
3868
  if (ctx) {
3801
- wsp_ggml_free(ctx->model.ctx);
3869
+ for (wsp_ggml_context * context : ctx->model.ctxs) {
3870
+ wsp_ggml_free(context);
3871
+ }
3802
3872
 
3803
- wsp_ggml_backend_buffer_free(ctx->model.buffer);
3873
+ for (wsp_ggml_backend_buffer_t buf : ctx->model.buffers) {
3874
+ wsp_ggml_backend_buffer_free(buf);
3875
+ }
3804
3876
 
3805
3877
  whisper_free_state(ctx->state);
3806
3878
 
@@ -4194,47 +4266,37 @@ struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
4194
4266
  if (ctx->state == nullptr) {
4195
4267
  return nullptr;
4196
4268
  }
4197
- return new whisper_timings {
4198
- .load_us = ctx->t_load_us,
4199
- .t_start_us = ctx->t_start_us,
4200
- .fail_p = ctx->state->n_fail_p,
4201
- .fail_h = ctx->state->n_fail_h,
4202
- .t_mel_us = ctx->state->t_mel_us,
4203
- .n_sample = ctx->state->n_sample,
4204
- .n_encode = ctx->state->n_encode,
4205
- .n_decode = ctx->state->n_decode,
4206
- .n_batchd = ctx->state->n_batchd,
4207
- .n_prompt = ctx->state->n_prompt,
4208
- .t_sample_us = ctx->state->t_sample_us,
4209
- .t_encode_us = ctx->state->t_encode_us,
4210
- .t_decode_us = ctx->state->t_decode_us,
4211
- .t_batchd_us = ctx->state->t_batchd_us,
4212
- .t_prompt_us = ctx->state->t_prompt_us,
4213
- };
4269
+ whisper_timings * timings = new whisper_timings;
4270
+ timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample);
4271
+ timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode);
4272
+ timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode);
4273
+ timings->batchd_ms = 1e-3f * ctx->state->t_batchd_us / std::max(1, ctx->state->n_batchd);
4274
+ timings->prompt_ms = 1e-3f * ctx->state->t_prompt_us / std::max(1, ctx->state->n_prompt);
4275
+ return timings;
4214
4276
  }
4215
4277
 
4216
4278
  void whisper_print_timings(struct whisper_context * ctx) {
4217
4279
  const int64_t t_end_us = wsp_ggml_time_us();
4218
- const struct whisper_timings * timings = whisper_get_timings(ctx);
4219
4280
 
4220
4281
  WHISPER_LOG_INFO("\n");
4221
- WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, timings->load_us / 1000.0f);
4282
+ WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
4222
4283
  if (ctx->state != nullptr) {
4284
+
4223
4285
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
4224
4286
  const int32_t n_encode = std::max(1, ctx->state->n_encode);
4225
4287
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
4226
4288
  const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
4227
4289
  const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
4228
4290
 
4229
- WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, timings->fail_p, timings->fail_h);
4230
- WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, timings->t_mel_us/1000.0f);
4231
- WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_sample_us, n_sample, 1e-3f * timings->t_sample_us / n_sample);
4232
- WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_encode_us, n_encode, 1e-3f * timings->t_encode_us / n_encode);
4233
- WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_decode_us, n_decode, 1e-3f * timings->t_decode_us / n_decode);
4234
- WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_batchd_us, n_batchd, 1e-3f * timings->t_batchd_us / n_batchd);
4235
- WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_prompt_us, n_prompt, 1e-3f * timings->t_prompt_us / n_prompt);
4291
+ WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
4292
+ WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
4293
+ WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
4294
+ WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
4295
+ WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
4296
+ WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
4297
+ WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
4236
4298
  }
4237
- WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - timings->t_start_us)/1000.0f);
4299
+ WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
4238
4300
  }
4239
4301
 
4240
4302
  void whisper_reset_timings(struct whisper_context * ctx) {
@@ -4274,64 +4336,1186 @@ const char * whisper_print_system_info(void) {
4274
4336
  static std::string s;
4275
4337
 
4276
4338
  s = "";
4277
- s += "AVX = " + std::to_string(wsp_ggml_cpu_has_avx()) + " | ";
4278
- s += "AVX2 = " + std::to_string(wsp_ggml_cpu_has_avx2()) + " | ";
4279
- s += "AVX512 = " + std::to_string(wsp_ggml_cpu_has_avx512()) + " | ";
4280
- s += "FMA = " + std::to_string(wsp_ggml_cpu_has_fma()) + " | ";
4281
- s += "NEON = " + std::to_string(wsp_ggml_cpu_has_neon()) + " | ";
4282
- s += "ARM_FMA = " + std::to_string(wsp_ggml_cpu_has_arm_fma()) + " | ";
4283
- s += "METAL = " + std::to_string(wsp_ggml_cpu_has_metal()) + " | ";
4284
- s += "F16C = " + std::to_string(wsp_ggml_cpu_has_f16c()) + " | ";
4285
- s += "FP16_VA = " + std::to_string(wsp_ggml_cpu_has_fp16_va()) + " | ";
4286
- s += "WASM_SIMD = " + std::to_string(wsp_ggml_cpu_has_wasm_simd()) + " | ";
4287
- s += "BLAS = " + std::to_string(wsp_ggml_cpu_has_blas()) + " | ";
4288
- s += "SSE3 = " + std::to_string(wsp_ggml_cpu_has_sse3()) + " | ";
4289
- s += "SSSE3 = " + std::to_string(wsp_ggml_cpu_has_ssse3()) + " | ";
4290
- s += "VSX = " + std::to_string(wsp_ggml_cpu_has_vsx()) + " | ";
4291
- s += "CUDA = " + std::to_string(wsp_ggml_cpu_has_cuda()) + " | ";
4339
+ s += "WHISPER : ";
4292
4340
  s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
4293
4341
  s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
4294
- s += "CANN = " + std::to_string(wsp_ggml_cpu_has_cann()) ;
4342
+
4343
+ for (size_t i = 0; i < wsp_ggml_backend_reg_count(); i++) {
4344
+ auto * reg = wsp_ggml_backend_reg_get(i);
4345
+ auto * get_features_fn = (wsp_ggml_backend_get_features_t) wsp_ggml_backend_reg_get_proc_address(reg, "wsp_ggml_backend_get_features");
4346
+ if (get_features_fn) {
4347
+ wsp_ggml_backend_feature * features = get_features_fn(reg);
4348
+ s += wsp_ggml_backend_reg_name(reg);
4349
+ s += " : ";
4350
+ for (; features->name; features++) {
4351
+ s += features->name;
4352
+ s += " = ";
4353
+ s += features->value;
4354
+ s += " | ";
4355
+ }
4356
+ }
4357
+ }
4295
4358
  return s.c_str();
4296
4359
  }
4297
4360
 
4298
4361
  //////////////////////////////////
4299
- // Grammar - ported from llama.cpp
4362
+ // Voice Activity Detection (VAD)
4300
4363
  //////////////////////////////////
4301
4364
 
4302
- // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
4303
- // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
4304
- static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
4305
- const char * src,
4306
- whisper_partial_utf8 partial_start) {
4307
- static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
4308
- const char * pos = src;
4309
- std::vector<uint32_t> code_points;
4310
- uint32_t value = partial_start.value;
4311
- int n_remain = partial_start.n_remain;
4365
+ struct whisper_vad_hparams {
4366
+ int32_t n_encoder_layers;
4367
+ int32_t * encoder_in_channels;
4368
+ int32_t * encoder_out_channels;
4369
+ int32_t * kernel_sizes;
4370
+ int32_t lstm_input_size;
4371
+ int32_t lstm_hidden_size;
4372
+ int32_t final_conv_in;
4373
+ int32_t final_conv_out;
4374
+ };
4312
4375
 
4313
- // continue previous decode, if applicable
4314
- while (*pos != 0 && n_remain > 0) {
4315
- uint8_t next_byte = static_cast<uint8_t>(*pos);
4316
- if ((next_byte >> 6) != 2) {
4317
- // invalid sequence, abort
4318
- code_points.push_back(0);
4319
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
4320
- }
4321
- value = (value << 6) + (next_byte & 0x3F);
4322
- ++pos;
4323
- --n_remain;
4324
- }
4376
+ struct whisper_vad_model {
4377
+ std::string type;
4378
+ std::string version;
4379
+ whisper_vad_hparams hparams;
4325
4380
 
4326
- if (partial_start.n_remain > 0 && n_remain == 0) {
4327
- code_points.push_back(value);
4328
- }
4381
+ struct wsp_ggml_tensor * stft_forward_basis; // [256, 1, 258]
4329
4382
 
4330
- // decode any subsequent utf-8 sequences, which may end in an incomplete one
4331
- while (*pos != 0) {
4332
- uint8_t first_byte = static_cast<uint8_t>(*pos);
4333
- uint8_t highbits = first_byte >> 4;
4334
- n_remain = lookup[highbits] - 1;
4383
+ // Encoder tensors - 4 convolutional layers
4384
+ struct wsp_ggml_tensor * encoder_0_weight; // [3, 129, 128]
4385
+ struct wsp_ggml_tensor * encoder_0_bias; // [128]
4386
+
4387
+ // Second encoder layer
4388
+ struct wsp_ggml_tensor * encoder_1_weight; // [3, 128, 64]
4389
+ struct wsp_ggml_tensor * encoder_1_bias; // [64]
4390
+
4391
+ // Third encoder layer
4392
+ struct wsp_ggml_tensor * encoder_2_weight; // [3, 64, 64]
4393
+ struct wsp_ggml_tensor * encoder_2_bias; // [64]
4394
+
4395
+ // Fourth encoder layer
4396
+ struct wsp_ggml_tensor * encoder_3_weight; // [3, 64, 128]
4397
+ struct wsp_ggml_tensor * encoder_3_bias; // [128]
4398
+
4399
+ // LSTM decoder tensors
4400
+ struct wsp_ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden
4401
+ struct wsp_ggml_tensor * lstm_ih_bias; // [512]
4402
+ struct wsp_ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden
4403
+ struct wsp_ggml_tensor * lstm_hh_bias; // [512]
4404
+
4405
+ // Final conv layer
4406
+ struct wsp_ggml_tensor * final_conv_weight; // [128]
4407
+ struct wsp_ggml_tensor * final_conv_bias; // [1]
4408
+
4409
+ // ggml contexts
4410
+ std::vector<wsp_ggml_context *> ctxs;
4411
+
4412
+ // buffer for the model tensors
4413
+ std::vector<wsp_ggml_backend_buffer_t> buffers;
4414
+
4415
+ // tensors
4416
+ int n_loaded;
4417
+ std::map<std::string, struct wsp_ggml_tensor *> tensors;
4418
+ };
4419
+
4420
+ struct whisper_vad_segment {
4421
+ int64_t start;
4422
+ int64_t end;
4423
+ };
4424
+
4425
+ struct whisper_vad_segments {
4426
+ std::vector<whisper_vad_segment> data;
4427
+ };
4428
+
4429
+ struct whisper_vad_context {
4430
+ int64_t t_vad_us = 0;
4431
+
4432
+ int n_window;
4433
+ int n_context;
4434
+ int n_threads;
4435
+
4436
+ std::vector<wsp_ggml_backend_t> backends;
4437
+ wsp_ggml_backend_buffer_t buffer = nullptr;
4438
+ whisper_context_params params;
4439
+ std::vector<uint8_t> ctx_buf;
4440
+ whisper_sched sched;
4441
+
4442
+ whisper_vad_model model;
4443
+ std::string path_model;
4444
+ struct wsp_ggml_tensor * h_state;
4445
+ struct wsp_ggml_tensor * c_state;
4446
+ std::vector<float> probs;
4447
+ };
4448
+
4449
+ struct whisper_vad_context_params whisper_vad_default_context_params(void) {
4450
+ whisper_vad_context_params result = {
4451
+ /*.n_thread = */ 4,
4452
+ /*.use_gpu = */ false,
4453
+ /*.gpu_device = */ 0,
4454
+ };
4455
+ return result;
4456
+ }
4457
+
4458
+ struct whisper_vad_params whisper_vad_default_params(void) {
4459
+ whisper_vad_params result = {
4460
+ /* threshold = */ 0.5f,
4461
+ /* min_speech_duration_ms = */ 250,
4462
+ /* min_silence_duration_ms = */ 100,
4463
+ /* max_speech_duration_s = */ FLT_MAX,
4464
+ /* speech_pad_ms = */ 30,
4465
+ /* samples_overlap = */ 0.1,
4466
+ };
4467
+ return result;
4468
+ }
4469
+
4470
+ // Time conversion utility functions for whisper VAD
4471
+ static int cs_to_samples(int64_t cs) {
4472
+ return (int)((cs / 100.0) * WHISPER_SAMPLE_RATE + 0.5);
4473
+ }
4474
+
4475
+ static int64_t samples_to_cs(int samples) {
4476
+ return (int64_t)((samples / (double)WHISPER_SAMPLE_RATE) * 100.0 + 0.5);
4477
+ }
4478
+
4479
+ static bool weight_buft_supported(const whisper_vad_hparams & hparams, wsp_ggml_tensor * w, wsp_ggml_op op, wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_dev_t dev) {
4480
+ bool op_supported = true;
4481
+
4482
+ if (wsp_ggml_backend_dev_type(dev) == WSP_GGML_BACKEND_DEVICE_TYPE_GPU ||
4483
+ (wsp_ggml_backend_dev_type(dev) == WSP_GGML_BACKEND_DEVICE_TYPE_CPU && buft == wsp_ggml_backend_cpu_buffer_type())) {
4484
+ // GPU and default CPU backend support all operators
4485
+ op_supported = true;
4486
+ } else {
4487
+ switch (op) {
4488
+ // The current extra_buffer_type implementations only support WSP_GGML_OP_MUL_MAT
4489
+ case WSP_GGML_OP_MUL_MAT: {
4490
+ wsp_ggml_init_params params = {
4491
+ /*.mem_size =*/ 2 * wsp_ggml_tensor_overhead(),
4492
+ /*.mem_buffer =*/ nullptr,
4493
+ /*.no_alloc =*/ true,
4494
+ };
4495
+
4496
+ wsp_ggml_context_ptr ctx_ptr { wsp_ggml_init(params) };
4497
+ if (!ctx_ptr) {
4498
+ throw std::runtime_error("failed to create ggml context");
4499
+ }
4500
+ wsp_ggml_context * ctx = ctx_ptr.get();
4501
+
4502
+ wsp_ggml_tensor * op_tensor = nullptr;
4503
+
4504
+ int64_t n_ctx = hparams.lstm_hidden_size;
4505
+ wsp_ggml_tensor * b = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
4506
+ op_tensor = wsp_ggml_mul_mat(ctx, w, b);
4507
+
4508
+ // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
4509
+ WSP_GGML_ASSERT(w->buffer == nullptr);
4510
+ w->buffer = wsp_ggml_backend_buft_alloc_buffer(buft, 0);
4511
+ op_supported = wsp_ggml_backend_dev_supports_op(dev, op_tensor);
4512
+ wsp_ggml_backend_buffer_free(w->buffer);
4513
+ w->buffer = nullptr;
4514
+ break;
4515
+ }
4516
+ default: {
4517
+ op_supported = false;
4518
+ break;
4519
+ }
4520
+ };
4521
+ }
4522
+ return op_supported;
4523
+ }
4524
+
4525
+ static wsp_ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, wsp_ggml_tensor * w, wsp_ggml_op op, buft_list_t buft_list) {
4526
+ WSP_GGML_ASSERT(!buft_list.empty());
4527
+ for (const auto & p : buft_list) {
4528
+ wsp_ggml_backend_dev_t dev = p.first;
4529
+ wsp_ggml_backend_buffer_type_t buft = p.second;
4530
+ if (weight_buft_supported(hparams, w, op, buft, dev)) {
4531
+ return buft;
4532
+ }
4533
+ }
4534
+
4535
+ return nullptr;
4536
+ }
4537
+
4538
+ static wsp_ggml_tensor * whisper_vad_build_stft_layer(wsp_ggml_context * ctx0,
4539
+ const whisper_vad_model & model, wsp_ggml_tensor * cur) {
4540
+ // Apply reflective padding to the input tensor
4541
+ wsp_ggml_tensor * padded = wsp_ggml_pad_reflect_1d(ctx0, cur, 64, 64);
4542
+
4543
+ struct wsp_ggml_tensor * stft = wsp_ggml_conv_1d(ctx0, model.stft_forward_basis, padded, model.hparams.lstm_input_size, 0, 1);
4544
+
4545
+ // Calculate cutoff for real/imaginary parts
4546
+ int cutoff = model.stft_forward_basis->ne[2] / 2;
4547
+
4548
+ // Extract real part (first half of the STFT output).
4549
+ struct wsp_ggml_tensor * real_part = wsp_ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], 0);
4550
+ // Extract imaginary part (second half of the STFT output).
4551
+ struct wsp_ggml_tensor * img_part = wsp_ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], cutoff * stft->nb[1]);
4552
+
4553
+ // Calculate magnitude: sqrt(real^2 + imag^2)
4554
+ struct wsp_ggml_tensor * real_squared = wsp_ggml_mul(ctx0, real_part, real_part);
4555
+ struct wsp_ggml_tensor * img_squared = wsp_ggml_mul(ctx0, img_part, img_part);
4556
+ struct wsp_ggml_tensor * sum_squares = wsp_ggml_add(ctx0, real_squared, img_squared);
4557
+ struct wsp_ggml_tensor * magnitude = wsp_ggml_sqrt(ctx0, sum_squares);
4558
+ return magnitude;
4559
+ }
4560
+
4561
+ static wsp_ggml_tensor * whisper_vad_build_encoder_layer(wsp_ggml_context * ctx0,
4562
+ const whisper_vad_model & model, wsp_ggml_tensor * cur) {
4563
+ // First Conv1D: expands to 128 channels.
4564
+ cur = wsp_ggml_conv_1d(ctx0, model.encoder_0_weight, cur, 1, 1, 1);
4565
+ cur = wsp_ggml_add(ctx0, cur, wsp_ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1));
4566
+ cur = wsp_ggml_relu(ctx0, cur);
4567
+
4568
+ // Second Conv1D: reduces to 64 channels.
4569
+ cur = wsp_ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 2, 1, 1);
4570
+ cur = wsp_ggml_add(ctx0, cur, wsp_ggml_reshape_3d(ctx0, model.encoder_1_bias, 1, 64, 1));
4571
+ cur = wsp_ggml_relu(ctx0, cur);
4572
+
4573
+ // Third Conv1D: maintains 64 channels
4574
+ cur = wsp_ggml_conv_1d(ctx0, model.encoder_2_weight, cur, 2, 1, 1);
4575
+ cur = wsp_ggml_add(ctx0, cur, wsp_ggml_reshape_3d(ctx0, model.encoder_2_bias, 1, 64, 1));
4576
+ cur = wsp_ggml_relu(ctx0, cur);
4577
+
4578
+ // Fourth Conv1D: expands to 128 channels
4579
+ cur = wsp_ggml_conv_1d(ctx0, model.encoder_3_weight, cur, 1, 1, 1);
4580
+ cur = wsp_ggml_add(ctx0, cur, wsp_ggml_reshape_3d(ctx0, model.encoder_3_bias, 1, 128, 1));
4581
+ cur = wsp_ggml_relu(ctx0, cur);
4582
+
4583
+ return cur;
4584
+ }
4585
+
4586
+ static wsp_ggml_tensor * whisper_vad_build_lstm_layer(wsp_ggml_context * ctx0,
4587
+ const whisper_vad_context & vctx, wsp_ggml_tensor * cur, wsp_ggml_cgraph * gf) {
4588
+ const whisper_vad_model & model = vctx.model;
4589
+ const int hdim = model.hparams.lstm_hidden_size;
4590
+
4591
+ struct wsp_ggml_tensor * x_t = wsp_ggml_transpose(ctx0, cur);
4592
+
4593
+ // Create operations using the input-to-hidden weights.
4594
+ struct wsp_ggml_tensor * inp_gate = wsp_ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
4595
+ inp_gate = wsp_ggml_add(ctx0, inp_gate, model.lstm_ih_bias);
4596
+
4597
+ // Create operations using the hidden-to-hidden weights.
4598
+ struct wsp_ggml_tensor * hid_gate = wsp_ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
4599
+ hid_gate = wsp_ggml_add(ctx0, hid_gate, model.lstm_hh_bias);
4600
+
4601
+ // Create add operation to get preactivations for all gates.
4602
+ struct wsp_ggml_tensor * out_gate = wsp_ggml_add(ctx0, inp_gate, hid_gate);
4603
+
4604
+ const size_t hdim_size = wsp_ggml_row_size(out_gate->type, hdim);
4605
+
4606
+ // Create sigmoid for input gate (using the first 128 bytes from the preactivations).
4607
+ struct wsp_ggml_tensor * i_t = wsp_ggml_sigmoid(ctx0, wsp_ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size));
4608
+
4609
+ // Create sigmoid for the forget gate (using the second 128 bytes from the preactivations).
4610
+ struct wsp_ggml_tensor * f_t = wsp_ggml_sigmoid(ctx0, wsp_ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size));
4611
+
4612
+ // Create sigmoid for the cell gate (using the third 128 bytes from the preactivations).
4613
+ struct wsp_ggml_tensor * g_t = wsp_ggml_tanh(ctx0, wsp_ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size));
4614
+
4615
+ // Create sigmoid for the output gate (using the fourth 128 bytes from the preactivations).
4616
+ struct wsp_ggml_tensor * o_t = wsp_ggml_sigmoid(ctx0, wsp_ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size));
4617
+
4618
+ // Update cell state
4619
+ struct wsp_ggml_tensor * c_out = wsp_ggml_add(ctx0,
4620
+ wsp_ggml_mul(ctx0, f_t, vctx.c_state),
4621
+ wsp_ggml_mul(ctx0, i_t, g_t));
4622
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, c_out, vctx.c_state));
4623
+
4624
+ // Update hidden state
4625
+ struct wsp_ggml_tensor * out = wsp_ggml_mul(ctx0, o_t, wsp_ggml_tanh(ctx0, c_out));
4626
+ wsp_ggml_build_forward_expand(gf, wsp_ggml_cpy(ctx0, out, vctx.h_state));
4627
+
4628
+ return out;
4629
+ }
4630
+
4631
+ static struct wsp_ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
4632
+ const auto & model = vctx.model;
4633
+
4634
+ struct wsp_ggml_init_params params = {
4635
+ /*.mem_size =*/ vctx.sched.meta.size(),
4636
+ /*.mem_buffer =*/ vctx.sched.meta.data(),
4637
+ /*.no_alloc =*/ true,
4638
+ };
4639
+
4640
+ struct wsp_ggml_context * ctx0 = wsp_ggml_init(params);
4641
+
4642
+ wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0);
4643
+
4644
+ struct wsp_ggml_tensor * frame = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, vctx.n_window, 1);
4645
+ wsp_ggml_set_name(frame, "frame");
4646
+ wsp_ggml_set_input(frame);
4647
+
4648
+ struct wsp_ggml_tensor * cur = nullptr;
4649
+ {
4650
+ cur = whisper_vad_build_stft_layer(ctx0, model, frame);
4651
+
4652
+ cur = whisper_vad_build_encoder_layer(ctx0, model, cur);
4653
+
4654
+ // Extract the first element of the first dimension
4655
+ // (equivalent to pytorch's [:, :, 0])
4656
+ cur = wsp_ggml_view_2d(ctx0, cur, 1, 128, cur->nb[1], 0);
4657
+
4658
+ cur = whisper_vad_build_lstm_layer(ctx0, vctx, cur, gf);
4659
+ cur = wsp_ggml_relu(ctx0, cur);
4660
+ cur = wsp_ggml_conv_1d(ctx0, model.final_conv_weight, cur, 1, 0, 1);
4661
+ cur = wsp_ggml_add(ctx0, cur, model.final_conv_bias);
4662
+ cur = wsp_ggml_sigmoid(ctx0, cur);
4663
+ wsp_ggml_set_name(cur, "prob");
4664
+ wsp_ggml_set_output(cur);
4665
+ }
4666
+
4667
+ wsp_ggml_build_forward_expand(gf, cur);
4668
+
4669
+ wsp_ggml_free(ctx0);
4670
+
4671
+ return gf;
4672
+ }
4673
+
4674
+ static bool whisper_vad_init_context(whisper_vad_context * vctx) {
4675
+
4676
+ auto whisper_context_params = whisper_context_default_params();
4677
+ // TODO: GPU VAD is forced disabled until the performance is improved
4678
+ //whisper_context_params.use_gpu = vctx->params.use_gpu;
4679
+ whisper_context_params.use_gpu = false;
4680
+ whisper_context_params.gpu_device = vctx->params.gpu_device;
4681
+
4682
+ vctx->backends = whisper_backend_init(whisper_context_params);
4683
+ if (vctx->backends.empty()) {
4684
+ WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
4685
+ return false;
4686
+ }
4687
+
4688
+ const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;
4689
+
4690
+ vctx->ctx_buf.resize(2u*wsp_ggml_tensor_overhead());
4691
+
4692
+ struct wsp_ggml_init_params params = {
4693
+ /*.mem_size =*/ vctx->ctx_buf.size(),
4694
+ /*.mem_buffer =*/ vctx->ctx_buf.data(),
4695
+ /*.no_alloc =*/ true,
4696
+ };
4697
+
4698
+ wsp_ggml_context * ctx = wsp_ggml_init(params);
4699
+ if (!ctx) {
4700
+ WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
4701
+ return false;
4702
+ }
4703
+
4704
+ // LSTM Hidden state
4705
+ vctx->h_state = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, lstm_hidden_size);
4706
+ wsp_ggml_set_name(vctx->h_state, "h_state");
4707
+
4708
+ // LSTM Cell state
4709
+ vctx->c_state = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, lstm_hidden_size);
4710
+ wsp_ggml_set_name(vctx->c_state, "c_state");
4711
+
4712
+ vctx->buffer = wsp_ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);
4713
+ if (!vctx->buffer) {
4714
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
4715
+ return false;
4716
+ }
4717
+
4718
+ {
4719
+ bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
4720
+ [&]() {
4721
+ return whisper_vad_build_graph(*vctx);
4722
+ });
4723
+
4724
+ if (!ok) {
4725
+ WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
4726
+ return false;
4727
+ }
4728
+
4729
+ WHISPER_LOG_INFO("%s: compute buffer (VAD) = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
4730
+ }
4731
+
4732
+ return true;
4733
+ }
4734
+
4735
+ struct whisper_vad_context * whisper_vad_init_from_file_with_params(
4736
+ const char * path_model,
4737
+ struct whisper_vad_context_params params) {
4738
+ WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
4739
+ #ifdef _MSC_VER
4740
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
4741
+ std::wstring path_model_wide = converter.from_bytes(path_model);
4742
+ auto fin = std::ifstream(path_model_wide, std::ios::binary);
4743
+ #else
4744
+ auto fin = std::ifstream(path_model, std::ios::binary);
4745
+ #endif
4746
+ if (!fin) {
4747
+ WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
4748
+ return nullptr;
4749
+ }
4750
+
4751
+ whisper_model_loader loader = {};
4752
+ loader.context = &fin;
4753
+
4754
+ loader.read = [](void * ctx, void * output, size_t read_size) {
4755
+ std::ifstream * fin = (std::ifstream*)ctx;
4756
+ fin->read((char *)output, read_size);
4757
+ return read_size;
4758
+ };
4759
+
4760
+ loader.eof = [](void * ctx) {
4761
+ std::ifstream * fin = (std::ifstream*)ctx;
4762
+ return fin->eof();
4763
+ };
4764
+
4765
+ loader.close = [](void * ctx) {
4766
+ std::ifstream * fin = (std::ifstream*)ctx;
4767
+ fin->close();
4768
+ };
4769
+
4770
+ auto ctx = whisper_vad_init_with_params(&loader, params);
4771
+ if (!ctx) {
4772
+ whisper_vad_free(ctx);
4773
+ return nullptr;
4774
+ }
4775
+ ctx->path_model = path_model;
4776
+ return ctx;
4777
+ }
4778
+
4779
+ struct whisper_vad_context * whisper_vad_init_with_params(
4780
+ struct whisper_model_loader * loader,
4781
+ struct whisper_vad_context_params params) {
4782
+ // Read the VAD model
4783
+ {
4784
+ uint32_t magic;
4785
+ read_safe(loader, magic);
4786
+ if (magic != WSP_GGML_FILE_MAGIC) {
4787
+ WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
4788
+ return nullptr;
4789
+ }
4790
+ }
4791
+
4792
+ whisper_vad_context * vctx = new whisper_vad_context;
4793
+ vctx->n_threads = params.n_threads;
4794
+ vctx->params.use_gpu = params.use_gpu;
4795
+ vctx->params.gpu_device = params.gpu_device;
4796
+
4797
+ auto & model = vctx->model;
4798
+ auto & hparams = model.hparams;
4799
+
4800
+ // load model context params.
4801
+ {
4802
+ int32_t str_len;
4803
+ read_safe(loader, str_len);
4804
+ std::vector<char> buffer(str_len + 1, 0);
4805
+ loader->read(loader->context, buffer.data(), str_len);
4806
+ std::string model_type(buffer.data(), str_len);
4807
+ model.type = model_type;
4808
+ WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());
4809
+
4810
+ int32_t major, minor, patch;
4811
+ read_safe(loader, major);
4812
+ read_safe(loader, minor);
4813
+ read_safe(loader, patch);
4814
+ std::string version_str = std::to_string(major) + "." +
4815
+ std::to_string(minor) + "." +
4816
+ std::to_string(patch);
4817
+ model.version = version_str;
4818
+ WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());
4819
+
4820
+ read_safe(loader, vctx->n_window);
4821
+ read_safe(loader, vctx->n_context);
4822
+ }
4823
+
4824
+ // load model hyper params (hparams).
4825
+ {
4826
+ read_safe(loader, hparams.n_encoder_layers);
4827
+
4828
+ hparams.encoder_in_channels = new int32_t[hparams.n_encoder_layers];
4829
+ hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
4830
+ hparams.kernel_sizes = new int32_t[hparams.n_encoder_layers];
4831
+
4832
+ for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
4833
+ read_safe(loader, hparams.encoder_in_channels[i]);
4834
+ read_safe(loader, hparams.encoder_out_channels[i]);
4835
+ read_safe(loader, hparams.kernel_sizes[i]);
4836
+ }
4837
+
4838
+ read_safe(loader, hparams.lstm_input_size);
4839
+ read_safe(loader, hparams.lstm_hidden_size);
4840
+ read_safe(loader, hparams.final_conv_in);
4841
+ read_safe(loader, hparams.final_conv_out);
4842
+
4843
+ WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
4844
+ for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
4845
+ WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
4846
+ }
4847
+ for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
4848
+ WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
4849
+ }
4850
+ WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size);
4851
+ WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
4852
+ WHISPER_LOG_INFO("%s: final_conv_in = %d\n", __func__, hparams.final_conv_in);
4853
+ WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
4854
+ }
4855
+
4856
+ // 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
4857
+ const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
4858
+
4859
+ std::map<wsp_ggml_backend_buffer_type_t, wsp_ggml_context *> ctx_map;
4860
+ auto get_ctx = [&](wsp_ggml_backend_buffer_type_t buft) -> wsp_ggml_context * {
4861
+ auto it = ctx_map.find(buft);
4862
+ if (it == ctx_map.end()) {
4863
+ wsp_ggml_init_params params = {
4864
+ /*.mem_size =*/ n_tensors * wsp_ggml_tensor_overhead(),
4865
+ /*.mem_buffer =*/ nullptr,
4866
+ /*.no_alloc =*/ true,
4867
+ };
4868
+
4869
+ wsp_ggml_context * ctx = wsp_ggml_init(params);
4870
+ if (!ctx) {
4871
+ throw std::runtime_error("failed to create ggml context");
4872
+ }
4873
+
4874
+ ctx_map[buft] = ctx;
4875
+ model.ctxs.emplace_back(ctx);
4876
+
4877
+ return ctx;
4878
+ }
4879
+
4880
+ return it->second;
4881
+ };
4882
+
4883
+ whisper_context_params wparams = whisper_context_default_params();
4884
+ wparams.use_gpu = params.use_gpu;
4885
+ wparams.gpu_device = params.gpu_device;
4886
+ buft_list_t buft_list = make_buft_list(wparams);
4887
+
4888
+ auto create_tensor = [&](vad_tensor type, wsp_ggml_tensor * meta) -> wsp_ggml_tensor * {
4889
+ wsp_ggml_op op = VAD_TENSOR_OPS.at(type);
4890
+ wsp_ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
4891
+ if (!buft) {
4892
+ throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
4893
+ }
4894
+ wsp_ggml_context * ctx = get_ctx(buft);
4895
+ wsp_ggml_tensor * tensor = wsp_ggml_dup_tensor(ctx, meta);
4896
+ model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;
4897
+
4898
+ return tensor;
4899
+ };
4900
+
4901
+ // create tensors
4902
+ {
4903
+ wsp_ggml_init_params params = {
4904
+ /*.mem_size =*/ n_tensors * wsp_ggml_tensor_overhead(),
4905
+ /*.mem_buffer =*/ nullptr,
4906
+ /*.no_alloc =*/ true,
4907
+ };
4908
+
4909
+ wsp_ggml_context * ctx = wsp_ggml_init(params);
4910
+ const auto & hparams = model.hparams;
4911
+
4912
+ // SFTF precomputed basis matrix
4913
+ model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
4914
+ wsp_ggml_new_tensor_3d(ctx, WSP_GGML_TYPE_F16, 256, 1, 258));
4915
+
4916
+ model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
4917
+ wsp_ggml_new_tensor_3d(
4918
+ ctx,
4919
+ WSP_GGML_TYPE_F16,
4920
+ hparams.kernel_sizes[0],
4921
+ hparams.encoder_in_channels[0],
4922
+ hparams.encoder_out_channels[0]
4923
+ ));
4924
+ model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS,
4925
+ wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, hparams.encoder_out_channels[0]));
4926
+
4927
+ model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT,
4928
+ wsp_ggml_new_tensor_3d(
4929
+ ctx,
4930
+ WSP_GGML_TYPE_F16,
4931
+ hparams.kernel_sizes[1],
4932
+ hparams.encoder_in_channels[1],
4933
+ hparams.encoder_out_channels[1]
4934
+ ));
4935
+ model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS,
4936
+ wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, hparams.encoder_out_channels[1]));
4937
+
4938
+ model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT,
4939
+ wsp_ggml_new_tensor_3d(
4940
+ ctx,
4941
+ WSP_GGML_TYPE_F16,
4942
+ hparams.kernel_sizes[2],
4943
+ hparams.encoder_in_channels[2],
4944
+ hparams.encoder_out_channels[2]
4945
+ ));
4946
+ model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS,
4947
+ wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, hparams.encoder_out_channels[2]));
4948
+
4949
+ model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT,
4950
+ wsp_ggml_new_tensor_3d(
4951
+ ctx,
4952
+ WSP_GGML_TYPE_F16,
4953
+ hparams.kernel_sizes[3],
4954
+ hparams.encoder_in_channels[3],
4955
+ hparams.encoder_out_channels[3]
4956
+ ));
4957
+ model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
4958
+ wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, hparams.encoder_out_channels[3]));
4959
+
4960
+ // Hidden State dimension (input gate, forget gate, cell gate, output gate)
4961
+ const int hstate_dim = hparams.lstm_hidden_size * 4;
4962
+
4963
+ // LSTM weights - input to hidden
4964
+ model.lstm_ih_weight = create_tensor(
4965
+ VAD_TENSOR_LSTM_WEIGHT_IH,
4966
+ wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
4967
+ );
4968
+ model.lstm_ih_bias = create_tensor(
4969
+ VAD_TENSOR_LSTM_BIAS_IH,
4970
+ wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, hstate_dim)
4971
+ );
4972
+
4973
+ // LSTM weights - hidden to hidden
4974
+ model.lstm_hh_weight = create_tensor(
4975
+ VAD_TENSOR_LSTM_WEIGHT_HH,
4976
+ wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
4977
+ );
4978
+ model.lstm_hh_bias = create_tensor(
4979
+ VAD_TENSOR_LSTM_BIAS_HH,
4980
+ wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, hstate_dim)
4981
+ );
4982
+
4983
+ // Final conv layer weight
4984
+ model.final_conv_weight = create_tensor(
4985
+ VAD_TENSOR_FINAL_CONV_WEIGHT,
4986
+ wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F16, hparams.final_conv_in, 1)
4987
+ );
4988
+ model.final_conv_bias = create_tensor(
4989
+ VAD_TENSOR_FINAL_CONV_BIAS,
4990
+ wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 1)
4991
+ );
4992
+
4993
+ wsp_ggml_free(ctx);
4994
+ }
4995
+
4996
+ // allocate tensors in the backend buffers
4997
+ for (auto & p : ctx_map) {
4998
+ wsp_ggml_backend_buffer_type_t buft = p.first;
4999
+ wsp_ggml_context * ctx = p.second;
5000
+ wsp_ggml_backend_buffer_t buf = wsp_ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
5001
+ if (buf) {
5002
+ model.buffers.emplace_back(buf);
5003
+
5004
+ size_t size_main = wsp_ggml_backend_buffer_get_size(buf);
5005
+ WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, wsp_ggml_backend_buffer_name(buf), size_main / 1e6);
5006
+ }
5007
+ }
5008
+
5009
+ // load weights
5010
+ {
5011
+ size_t total_size = 0;
5012
+ model.n_loaded = 0;
5013
+ std::vector<char> read_buf;
5014
+
5015
+ while (true) {
5016
+ int32_t n_dims;
5017
+ int32_t length;
5018
+ int32_t ttype;
5019
+
5020
+ read_safe(loader, n_dims);
5021
+ read_safe(loader, length);
5022
+ read_safe(loader, ttype);
5023
+
5024
+ if (loader->eof(loader->context)) {
5025
+ break;
5026
+ }
5027
+
5028
+ int32_t nelements = 1;
5029
+ int32_t ne[4] = { 1, 1, 1, 1 };
5030
+ for (int i = 0; i < n_dims; ++i) {
5031
+ read_safe(loader, ne[i]);
5032
+ nelements *= ne[i];
5033
+ }
5034
+
5035
+ std::string name;
5036
+ std::vector<char> tmp(length);
5037
+ loader->read(loader->context, &tmp[0], tmp.size());
5038
+ name.assign(&tmp[0], tmp.size());
5039
+
5040
+ if (model.tensors.find(name) == model.tensors.end()) {
5041
+ WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
5042
+ return nullptr;
5043
+ }
5044
+
5045
+ auto tensor = model.tensors[name.data()];
5046
+
5047
+ if (wsp_ggml_nelements(tensor) != nelements) {
5048
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
5049
+ WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
5050
+ __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
5051
+ return nullptr;
5052
+ }
5053
+
5054
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
5055
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
5056
+ __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
5057
+ return nullptr;
5058
+ }
5059
+
5060
+ const size_t bpe = wsp_ggml_type_size(wsp_ggml_type(ttype));
5061
+
5062
+ if ((nelements*bpe)/wsp_ggml_blck_size(tensor->type) != wsp_ggml_nbytes(tensor)) {
5063
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
5064
+ __func__, name.data(), wsp_ggml_nbytes(tensor), nelements*bpe);
5065
+ return nullptr;
5066
+ }
5067
+
5068
+ if (wsp_ggml_backend_buffer_is_host(tensor->buffer)) {
5069
+ // for the CPU and Metal backend, we can read directly into the tensor
5070
+ loader->read(loader->context, tensor->data, wsp_ggml_nbytes(tensor));
5071
+ BYTESWAP_TENSOR(tensor);
5072
+ } else {
5073
+ // read into a temporary buffer first, then copy to device memory
5074
+ read_buf.resize(wsp_ggml_nbytes(tensor));
5075
+
5076
+ loader->read(loader->context, read_buf.data(), read_buf.size());
5077
+
5078
+ wsp_ggml_backend_tensor_set(tensor, read_buf.data(), 0, wsp_ggml_nbytes(tensor));
5079
+ }
5080
+
5081
+ total_size += wsp_ggml_nbytes(tensor);
5082
+ model.n_loaded++;
5083
+ }
5084
+
5085
+ WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
5086
+
5087
+ if (model.n_loaded == 0) {
5088
+ WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
5089
+ } else if (model.n_loaded != (int) model.tensors.size()) {
5090
+ WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
5091
+ return nullptr;
5092
+ }
5093
+
5094
+ }
5095
+
5096
+ if (!whisper_vad_init_context(vctx)) {
5097
+ whisper_vad_free(vctx);
5098
+ return nullptr;
5099
+ }
5100
+
5101
+ return vctx;
5102
+ }
5103
+
5104
+ bool whisper_vad_detect_speech(
5105
+ struct whisper_vad_context * vctx,
5106
+ const float * samples,
5107
+ int n_samples) {
5108
+ int n_chunks = n_samples / vctx->n_window;
5109
+ if (n_samples % vctx->n_window != 0) {
5110
+ n_chunks += 1; // Add one more chunk for remaining samples.
5111
+ }
5112
+
5113
+ WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
5114
+ WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
5115
+
5116
+ // Reset LSTM hidden/cell states
5117
+ wsp_ggml_backend_buffer_clear(vctx->buffer, 0);
5118
+
5119
+ vctx->probs.resize(n_chunks);
5120
+ WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
5121
+
5122
+ std::vector<float> window(vctx->n_window, 0.0f);
5123
+
5124
+ auto & sched = vctx->sched.sched;
5125
+
5126
+ wsp_ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
5127
+
5128
+ if (!wsp_ggml_backend_sched_alloc_graph(sched, gf)) {
5129
+ WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
5130
+ return false;
5131
+ }
5132
+
5133
+ struct wsp_ggml_tensor * frame = wsp_ggml_graph_get_tensor(gf, "frame");
5134
+ struct wsp_ggml_tensor * prob = wsp_ggml_graph_get_tensor(gf, "prob");
5135
+
5136
+ // we are going to reuse the graph multiple times for each chunk
5137
+ const int64_t t_start_vad_us = wsp_ggml_time_us();
5138
+
5139
+ for (int i = 0; i < n_chunks; i++) {
5140
+ const int idx_start = i * vctx->n_window;
5141
+ const int idx_end = std::min(idx_start + vctx->n_window, n_samples);
5142
+
5143
+ const int chunk_len = idx_end - idx_start;
5144
+
5145
+ if (chunk_len < vctx->n_window) {
5146
+ WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
5147
+ std::vector<float> partial_chunk(vctx->n_window, 0.0f);
5148
+ std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin());
5149
+
5150
+ // Copy the zero-padded chunk to the window.
5151
+ const int samples_to_copy_max = vctx->n_window;
5152
+ const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size());
5153
+ std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin());
5154
+ if (samples_to_copy_cur < samples_to_copy_max) {
5155
+ std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f);
5156
+ }
5157
+ } else {
5158
+ // Copy current frame samples to the window.
5159
+ const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
5160
+ std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
5161
+ }
5162
+
5163
+ // Set the frame tensor data with the samples.
5164
+ wsp_ggml_backend_tensor_set(frame, window.data(), 0, wsp_ggml_nelements(frame) * sizeof(float));
5165
+
5166
+ // do not reset the scheduler - we will reuse the graph in the next chunk
5167
+ if (!wsp_ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
5168
+ WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
5169
+ break;
5170
+ }
5171
+
5172
+ // Get the probability for this chunk.
5173
+ wsp_ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
5174
+
5175
+ //WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
5176
+ }
5177
+
5178
+ vctx->t_vad_us += wsp_ggml_time_us() - t_start_vad_us;
5179
+ WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
5180
+
5181
+ wsp_ggml_backend_sched_reset(sched);
5182
+
5183
+ return true;
5184
+ }
5185
+
5186
+ int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
5187
+ return segments->data.size();
5188
+ }
5189
+
5190
+ float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
5191
+ return segments->data[i_segment].start;
5192
+ }
5193
+
5194
+ float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
5195
+ return segments->data[i_segment].end;
5196
+ }
5197
+
5198
+ int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
5199
+ return vctx->probs.size();
5200
+ }
5201
+
5202
+ float * whisper_vad_probs(struct whisper_vad_context * vctx) {
5203
+ return vctx->probs.data();
5204
+ }
5205
+
5206
+ struct whisper_vad_segments * whisper_vad_segments_from_probs(
5207
+ struct whisper_vad_context * vctx,
5208
+ whisper_vad_params params) {
5209
+ WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));
5210
+
5211
+ int n_probs = whisper_vad_n_probs(vctx);
5212
+ float * probs = whisper_vad_probs(vctx);
5213
+ float threshold = params.threshold;
5214
+ int min_speech_duration_ms = params.min_speech_duration_ms;
5215
+ int min_silence_duration_ms = params.min_silence_duration_ms;
5216
+ float max_speech_duration_s = params.max_speech_duration_s;
5217
+ int speech_pad_ms = params.speech_pad_ms;
5218
+ int n_window = vctx->n_window;
5219
+ int sample_rate = WHISPER_SAMPLE_RATE;
5220
+ int min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
5221
+ int audio_length_samples = n_probs * n_window;
5222
+
5223
+ // Min number of samples to be considered valid speech.
5224
+ int min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
5225
+ int speech_pad_samples = sample_rate * speech_pad_ms / 1000;
5226
+
5227
+ // Max number of samples that a speech segment can contain before it is
5228
+ // split into multiple segments.
5229
+ int max_speech_samples;
5230
+ if (max_speech_duration_s > 100000.0f) {
5231
+ max_speech_samples = INT_MAX / 2;
5232
+ } else {
5233
+ int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples;
5234
+ max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp;
5235
+ if (max_speech_samples < 0) {
5236
+ max_speech_samples = INT_MAX / 2;
5237
+ }
5238
+ }
5239
+ // Detect silence period that exceeds this value, then that location (sample)
5240
+ // is marked as a potential place where the segment could be split if
5241
+ // max_speech_samples is reached. The value 98 was taken from the original
5242
+ // silaro-vad python implementation:
5243
+ //https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
5244
+ int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
5245
+
5246
+ // Calculate lower threshold for detecting end of speech segments.
5247
+ float neg_threshold = threshold - 0.15f;
5248
+ if (neg_threshold < 0.01f) {
5249
+ neg_threshold = 0.01f;
5250
+ }
5251
+
5252
+ struct speech_segment_t {
5253
+ int start;
5254
+ int end;
5255
+ };
5256
+
5257
+ std::vector<speech_segment_t> speeches;
5258
+ speeches.reserve(256);
5259
+
5260
+ bool is_speech_segment = false;
5261
+ int temp_end = 0;
5262
+ int prev_end = 0;
5263
+ int next_start = 0;
5264
+ int curr_speech_start = 0;
5265
+ bool has_curr_speech = false;
5266
+
5267
+ for (int i = 0; i < n_probs; i++) {
5268
+ float curr_prob = probs[i];
5269
+ int curr_sample = n_window * i;
5270
+
5271
+ // Reset temp_end when we get back to speech
5272
+ if ((curr_prob >= threshold) && temp_end) {
5273
+ temp_end = 0;
5274
+ if (next_start < prev_end) {
5275
+ next_start = curr_sample;
5276
+ }
5277
+ }
5278
+
5279
+ // Start a new speech segment when probability exceeds threshold and not already in speech
5280
+ if ((curr_prob >= threshold) && !is_speech_segment) {
5281
+ is_speech_segment = true;
5282
+ curr_speech_start = curr_sample;
5283
+ has_curr_speech = true;
5284
+ continue;
5285
+ }
5286
+
5287
+ // Handle maximum speech duration
5288
+ if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) {
5289
+ if (prev_end) {
5290
+ speeches.push_back({ curr_speech_start, prev_end });
5291
+ has_curr_speech = true;
5292
+
5293
+ if (next_start < prev_end) { // Previously reached silence and is still not speech
5294
+ is_speech_segment = false;
5295
+ has_curr_speech = false;
5296
+ } else {
5297
+ curr_speech_start = next_start;
5298
+ }
5299
+ prev_end = next_start = temp_end = 0;
5300
+ } else {
5301
+ speeches.push_back({ curr_speech_start, curr_sample });
5302
+
5303
+ prev_end = next_start = temp_end = 0;
5304
+ is_speech_segment = false;
5305
+ has_curr_speech = false;
5306
+ continue;
5307
+ }
5308
+ }
5309
+
5310
+ // Handle silence after speech
5311
+ if ((curr_prob < neg_threshold) && is_speech_segment) {
5312
+ if (!temp_end) {
5313
+ temp_end = curr_sample;
5314
+ }
5315
+
5316
+ // Track potential segment ends for max_speech handling
5317
+ if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) {
5318
+ prev_end = temp_end;
5319
+ }
5320
+
5321
+ // Check if silence is long enough to end the segment
5322
+ if ((curr_sample - temp_end) < min_silence_samples) {
5323
+ continue;
5324
+ } else {
5325
+ // End the segment if it's long enough
5326
+ if ((temp_end - curr_speech_start) > min_speech_samples) {
5327
+ speeches.push_back({ curr_speech_start, temp_end });
5328
+ }
5329
+
5330
+ prev_end = next_start = temp_end = 0;
5331
+ is_speech_segment = false;
5332
+ has_curr_speech = false;
5333
+ continue;
5334
+ }
5335
+ }
5336
+ }
5337
+
5338
+ // Handle the case if we're still in a speech segment at the end
5339
+ if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) {
5340
+ speeches.push_back({ curr_speech_start, audio_length_samples });
5341
+ }
5342
+
5343
+ // Merge adjacent segments with small gaps in between (post-processing)
5344
+ if (speeches.size() > 1) {
5345
+ int merged_count = 0;
5346
+ for (int i = 0; i < (int) speeches.size() - 1; i++) {
5347
+ // Define maximum gap allowed for merging (e.g., 200ms converted to samples)
5348
+ int max_merge_gap_samples = sample_rate * 200 / 1000;
5349
+
5350
+ // If the gap between this segment and the next is small enough
5351
+ if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) {
5352
+ // Merge by extending current segment to the end of next segment
5353
+ speeches[i].end = speeches[i+1].end;
5354
+ speeches.erase(speeches.begin() + i + 1);
5355
+
5356
+ i--;
5357
+ merged_count++;
5358
+ }
5359
+ }
5360
+ WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n",
5361
+ __func__, merged_count, (int) speeches.size());
5362
+ }
5363
+
5364
+ // Double-check for minimum speech duration
5365
+ for (int i = 0; i < (int) speeches.size(); i++) {
5366
+ if (speeches[i].end - speeches[i].start < min_speech_samples) {
5367
+ WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n",
5368
+ __func__, i, speeches[i].end - speeches[i].start);
5369
+
5370
+ speeches.erase(speeches.begin() + i);
5371
+ i--;
5372
+ }
5373
+ }
5374
+
5375
+ WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size());
5376
+
5377
+ // Allocate final segments
5378
+ std::vector<whisper_vad_segment> segments;
5379
+ if (speeches.size() > 0) {
5380
+ try {
5381
+ segments.resize(speeches.size());
5382
+ } catch (const std::bad_alloc &) {
5383
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__);
5384
+ return nullptr;
5385
+ }
5386
+ }
5387
+
5388
+ // Apply padding to segments and copy to final segments
5389
+ for (int i = 0; i < (int) speeches.size(); i++) {
5390
+ // Apply padding to the start of the first segment
5391
+ if (i == 0) {
5392
+ speeches[i].start =
5393
+ (speeches[i].start > speech_pad_samples) ?
5394
+ (speeches[i].start - speech_pad_samples) : 0;
5395
+ }
5396
+
5397
+ // Handle spacing between segments
5398
+ if (i < (int) speeches.size() - 1) {
5399
+ int silence_duration = speeches[i+1].start - speeches[i].end;
5400
+
5401
+ if (silence_duration < 2 * speech_pad_samples) {
5402
+ // If segments are close, split the difference
5403
+ speeches[i].end += silence_duration / 2;
5404
+ speeches[i+1].start =
5405
+ (speeches[i+1].start > silence_duration / 2) ?
5406
+ (speeches[i+1].start - silence_duration / 2) : 0;
5407
+ } else {
5408
+ // Otherwise, apply full padding to both
5409
+ speeches[i].end =
5410
+ (speeches[i].end + speech_pad_samples < audio_length_samples) ?
5411
+ (speeches[i].end + speech_pad_samples) : audio_length_samples;
5412
+ speeches[i+1].start =
5413
+ (speeches[i+1].start > speech_pad_samples) ?
5414
+ (speeches[i+1].start - speech_pad_samples) : 0;
5415
+ }
5416
+ } else {
5417
+ // Apply padding to the end of the last segment
5418
+ speeches[i].end =
5419
+ (speeches[i].end + speech_pad_samples < audio_length_samples) ?
5420
+ (speeches[i].end + speech_pad_samples) : audio_length_samples;
5421
+ }
5422
+
5423
+ // Convert from samples to centiseconds
5424
+ segments[i].start = samples_to_cs(speeches[i].start);
5425
+ segments[i].end = samples_to_cs(speeches[i].end);
5426
+
5427
+ WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n",
5428
+ __func__, i, segments[i].start/100.0, segments[i].end/100.0, (segments[i].end - segments[i].start)/100.0);
5429
+ }
5430
+
5431
+ whisper_vad_segments * vad_segments = new whisper_vad_segments;
5432
+ if (vad_segments == NULL) {
5433
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__);
5434
+ return nullptr;
5435
+ }
5436
+
5437
+ vad_segments->data = std::move(segments);
5438
+
5439
+ return vad_segments;
5440
+ }
5441
+
5442
+ struct whisper_vad_segments * whisper_vad_segments_from_samples(
5443
+ whisper_vad_context * vctx,
5444
+ whisper_vad_params params,
5445
+ const float * samples,
5446
+ int n_samples) {
5447
+ WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples);
5448
+ if (!whisper_vad_detect_speech(vctx, samples, n_samples)) {
5449
+ WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__);
5450
+ return nullptr;
5451
+ }
5452
+ return whisper_vad_segments_from_probs(vctx, params);
5453
+ }
5454
+
5455
+ void whisper_vad_free(whisper_vad_context * ctx) {
5456
+ if (ctx) {
5457
+ for (wsp_ggml_context * context : ctx->model.ctxs) {
5458
+ wsp_ggml_free(context);
5459
+ }
5460
+
5461
+ for (wsp_ggml_backend_buffer_t buf : ctx->model.buffers) {
5462
+ wsp_ggml_backend_buffer_free(buf);
5463
+ }
5464
+
5465
+ wsp_ggml_backend_sched_free(ctx->sched.sched);
5466
+
5467
+ for (auto & backend : ctx->backends) {
5468
+ wsp_ggml_backend_free(backend);
5469
+ }
5470
+
5471
+
5472
+ delete ctx;
5473
+ }
5474
+ }
5475
+
5476
+ void whisper_vad_free_segments(whisper_vad_segments * segments) {
5477
+ if (segments) {
5478
+ delete segments;
5479
+ }
5480
+ }
5481
+
5482
+ //////////////////////////////////
5483
+ // Grammar - ported from llama.cpp
5484
+ //////////////////////////////////
5485
+
5486
+ // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
5487
+ // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
5488
+ static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
5489
+ const char * src,
5490
+ whisper_partial_utf8 partial_start) {
5491
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
5492
+ const char * pos = src;
5493
+ std::vector<uint32_t> code_points;
5494
+ uint32_t value = partial_start.value;
5495
+ int n_remain = partial_start.n_remain;
5496
+
5497
+ // continue previous decode, if applicable
5498
+ while (*pos != 0 && n_remain > 0) {
5499
+ uint8_t next_byte = static_cast<uint8_t>(*pos);
5500
+ if ((next_byte >> 6) != 2) {
5501
+ // invalid sequence, abort
5502
+ code_points.push_back(0);
5503
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
5504
+ }
5505
+ value = (value << 6) + (next_byte & 0x3F);
5506
+ ++pos;
5507
+ --n_remain;
5508
+ }
5509
+
5510
+ if (partial_start.n_remain > 0 && n_remain == 0) {
5511
+ code_points.push_back(value);
5512
+ }
5513
+
5514
+ // decode any subsequent utf-8 sequences, which may end in an incomplete one
5515
+ while (*pos != 0) {
5516
+ uint8_t first_byte = static_cast<uint8_t>(*pos);
5517
+ uint8_t highbits = first_byte >> 4;
5518
+ n_remain = lookup[highbits] - 1;
4335
5519
 
4336
5520
  if (n_remain < 0) {
4337
5521
  // invalid sequence, abort
@@ -4450,7 +5634,7 @@ static void whisper_grammar_advance_stack(
4450
5634
  std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
4451
5635
 
4452
5636
  if (stack.empty()) {
4453
- new_stacks.push_back(stack);
5637
+ new_stacks.emplace_back();
4454
5638
  return;
4455
5639
  }
4456
5640
 
@@ -4771,7 +5955,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4771
5955
  /*.detect_language =*/ false,
4772
5956
 
4773
5957
  /*.suppress_blank =*/ true,
4774
- /*.suppress_non_speech_tokens =*/ false,
5958
+ /*.suppress_nst =*/ false,
4775
5959
 
4776
5960
  /*.temperature =*/ 0.0f,
4777
5961
  /*.max_initial_ts =*/ 1.0f,
@@ -4811,6 +5995,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4811
5995
  /*.n_grammar_rules =*/ 0,
4812
5996
  /*.i_start_rule =*/ 0,
4813
5997
  /*.grammar_penalty =*/ 100.0f,
5998
+
5999
+ /*.vad =*/ false,
6000
+ /*.vad_model_path =*/ nullptr,
6001
+
6002
+ /* vad_params =*/ whisper_vad_default_params(),
4814
6003
  };
4815
6004
 
4816
6005
  switch (strategy) {
@@ -4921,6 +6110,42 @@ static const std::vector<std::string> non_speech_tokens = {
4921
6110
  "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
4922
6111
  };
4923
6112
 
6113
// Computes the log-softmax of the first `n_logits` entries of `logits` into
// `logprobs` (which must hold at least n_logits elements).
// Entries equal to -INFINITY (suppressed tokens) are excluded from the
// normaliser and receive a logprob of -INFINITY.
// Entries of `logits` beyond n_logits are ignored. Nothing is written when
// n_logits <= 0.
static void whisper_compute_logprobs(
        const std::vector<float> & logits,
        const int n_logits,
        std::vector<float> & logprobs) {
    if (n_logits <= 0) {
        // also avoids dereferencing max_element() of an empty range
        return;
    }

    // subtract the maximum before exponentiating for numerical stability;
    // restrict the search to the same range the loops below use
    const float logit_max = *std::max_element(logits.begin(), logits.begin() + n_logits);

    float logsumexp = 0.0f;
    for (int i = 0; i < n_logits; ++i) {
        if (logits[i] > -INFINITY) {
            logsumexp += expf(logits[i] - logit_max);
        }
    }
    logsumexp = logf(logsumexp) + logit_max;

    for (int i = 0; i < n_logits; ++i) {
        if (logits[i] > -INFINITY) {
            logprobs[i] = logits[i] - logsumexp;
        } else {
            logprobs[i] = -INFINITY;
        }
    }
}
6134
+
6135
// Converts log-probabilities back to probabilities for the first `n_logits`
// tokens: probs[k] = exp(logprobs[k]), except tokens whose logit is exactly
// -INFINITY (suppressed), which get probability 0.
// `probs` must hold at least n_logits elements.
static void whisper_compute_probs(
        const std::vector<float> & logits,
        const int n_logits,
        const std::vector<float> & logprobs,
        std::vector<float> & probs) {
    for (int k = 0; k < n_logits; ++k) {
        const bool suppressed = (logits[k] == -INFINITY);
        probs[k] = suppressed ? 0.0f : expf(logprobs[k]);
    }
}
6148
+
4924
6149
  // process the logits for the selected decoder
4925
6150
  // - applies logit filters
4926
6151
  // - computes logprobs and probs
@@ -4982,7 +6207,7 @@ static void whisper_process_logits(
4982
6207
 
4983
6208
  // suppress sot and nosp tokens
4984
6209
  logits[vocab.token_sot] = -INFINITY;
4985
- logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
6210
+ logits[vocab.token_nosp] = -INFINITY;
4986
6211
 
4987
6212
  // [TDRZ] when tinydiarize is disabled, suppress solm token
4988
6213
  if (params.tdrz_enable == false) {
@@ -5019,7 +6244,7 @@ static void whisper_process_logits(
5019
6244
 
5020
6245
  // suppress non-speech tokens
5021
6246
  // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
5022
- if (params.suppress_non_speech_tokens) {
6247
+ if (params.suppress_nst) {
5023
6248
  for (const std::string & token : non_speech_tokens) {
5024
6249
  const std::string suppress_tokens[] = {token, " " + token};
5025
6250
  for (const std::string & suppress_token : suppress_tokens) {
@@ -5081,24 +6306,7 @@ static void whisper_process_logits(
5081
6306
  }
5082
6307
 
5083
6308
  // populate the logprobs array (log_softmax)
5084
- {
5085
- const float logit_max = *std::max_element(logits.begin(), logits.end());
5086
- float logsumexp = 0.0f;
5087
- for (int i = 0; i < n_logits; ++i) {
5088
- if (logits[i] > -INFINITY) {
5089
- logsumexp += expf(logits[i] - logit_max);
5090
- }
5091
- }
5092
- logsumexp = logf(logsumexp) + logit_max;
5093
-
5094
- for (int i = 0; i < n_logits; ++i) {
5095
- if (logits[i] > -INFINITY) {
5096
- logprobs[i] = logits[i] - logsumexp;
5097
- } else {
5098
- logprobs[i] = -INFINITY;
5099
- }
5100
- }
5101
- }
6309
+ whisper_compute_logprobs(logits, n_logits, logprobs);
5102
6310
 
5103
6311
  // if sum of probability over timestamps is above any other token, sample timestamp
5104
6312
  // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
@@ -5156,15 +6364,7 @@ static void whisper_process_logits(
5156
6364
  }
5157
6365
 
5158
6366
  // compute probs
5159
- {
5160
- for (int i = 0; i < n_logits; ++i) {
5161
- if (logits[i] == -INFINITY) {
5162
- probs[i] = 0.0f;
5163
- } else {
5164
- probs[i] = expf(logprobs[i]);
5165
- }
5166
- }
5167
- }
6367
+ whisper_compute_probs(logits, n_logits, logprobs, probs);
5168
6368
 
5169
6369
  #if 0
5170
6370
  // print first 100 logits - token string : logit
@@ -5416,6 +6616,186 @@ static void whisper_sequence_score(
5416
6616
  }
5417
6617
  }
5418
6618
 
6619
+ static bool whisper_vad(
6620
+ struct whisper_context * ctx,
6621
+ struct whisper_state * state,
6622
+ struct whisper_full_params params,
6623
+ const float * samples,
6624
+ int n_samples,
6625
+ std::vector<float> & filtered_samples) {
6626
+ WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
6627
+ int filtered_n_samples = 0;
6628
+
6629
+ // Clear any existing mapping table
6630
+ state->vad_mapping_table.clear();
6631
+ state->has_vad_segments = false;
6632
+
6633
+ if (state->vad_context == nullptr) {
6634
+ struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
6635
+ struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
6636
+ if (vctx == nullptr) {
6637
+ WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
6638
+ return false;
6639
+ }
6640
+ state->vad_context = vctx;
6641
+ }
6642
+ auto vctx = state->vad_context;
6643
+
6644
+ const whisper_vad_params & vad_params = params.vad_params;
6645
+
6646
+ whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
6647
+
6648
+ if (vad_segments->data.size() > 0) {
6649
+ state->has_vad_segments = true;
6650
+ ctx->state->vad_segments.clear();
6651
+ ctx->state->vad_segments.reserve(vad_segments->data.size());
6652
+
6653
+ // Initialize the time mapping table
6654
+ state->vad_mapping_table.clear();
6655
+ state->vad_mapping_table.reserve(vad_segments->data.size() * 4);
6656
+
6657
+ WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
6658
+ float overlap_seconds = vad_params.samples_overlap;
6659
+ int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
6660
+
6661
+ for (int i = 0; i < (int)vad_segments->data.size(); i++) {
6662
+ int segment_start_samples = cs_to_samples(vad_segments->data[i].start);
6663
+ int segment_end_samples = cs_to_samples(vad_segments->data[i].end);
6664
+
6665
+ if (i < (int)vad_segments->data.size() - 1) {
6666
+ segment_end_samples += overlap_samples;
6667
+ }
6668
+ segment_end_samples = std::min(segment_end_samples, n_samples - 1);
6669
+ filtered_n_samples += (segment_end_samples - segment_start_samples);
6670
+
6671
+ WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
6672
+ __func__, i, vad_segments->data[i].start/100.0,
6673
+ (vad_segments->data[i].end/100.0 + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0)),
6674
+ (vad_segments->data[i].end - vad_segments->data[i].start)/100.0 +
6675
+ (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
6676
+ }
6677
+
6678
+ int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
6679
+ int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
6680
+ int total_samples_needed = filtered_n_samples + total_silence_samples;
6681
+
6682
+ WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
6683
+ __func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
6684
+
6685
+ try {
6686
+ filtered_samples.resize(total_samples_needed);
6687
+ } catch (const std::bad_alloc & /* e */) {
6688
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
6689
+ whisper_vad_free_segments(vad_segments);
6690
+ whisper_vad_free(vctx);
6691
+ return false;
6692
+ }
6693
+
6694
+ int offset = 0;
6695
+ for (int i = 0; i < (int)vad_segments->data.size(); i++) {
6696
+ int segment_start_samples = cs_to_samples(vad_segments->data[i].start);
6697
+ int segment_end_samples = cs_to_samples(vad_segments->data[i].end);
6698
+
6699
+ if (i < (int)vad_segments->data.size() - 1) {
6700
+ segment_end_samples += overlap_samples;
6701
+ }
6702
+
6703
+ segment_start_samples = std::min(segment_start_samples, n_samples - 1);
6704
+ segment_end_samples = std::min(segment_end_samples, n_samples);
6705
+ int segment_length = segment_end_samples - segment_start_samples;
6706
+ if (segment_length > 0) {
6707
+ whisper_state::vad_segment_info segment;
6708
+
6709
+ segment.orig_start = vad_segments->data[i].start;
6710
+ segment.orig_end = vad_segments->data[i].end;
6711
+
6712
+ segment.vad_start = samples_to_cs(offset);
6713
+ segment.vad_end = samples_to_cs(offset + segment_length);
6714
+
6715
+ // Add segment boundaries to mapping table
6716
+ vad_time_mapping start_mapping = {segment.vad_start, segment.orig_start};
6717
+ vad_time_mapping end_mapping = {segment.vad_end, segment.orig_end};
6718
+
6719
+ state->vad_mapping_table.push_back(start_mapping);
6720
+ state->vad_mapping_table.push_back(end_mapping);
6721
+
6722
+ // Add intermediate points for longer segments to improve interpolation accuracy
6723
+ const int64_t min_segment_length = 100; // 1 second
6724
+ const int64_t point_interval = 20; // Add a point every 200ms
6725
+
6726
+ if (segment.vad_end - segment.vad_start > min_segment_length) {
6727
+ int64_t segment_duration = segment.vad_end - segment.vad_start;
6728
+ int num_points = (int)(segment_duration / point_interval) - 1;
6729
+
6730
+ for (int j = 1; j <= num_points; j++) {
6731
+ int64_t vad_time = segment.vad_start + j * point_interval;
6732
+
6733
+ if (vad_time >= segment.vad_end) continue;
6734
+
6735
+ int64_t vad_elapsed = vad_time - segment.vad_start;
6736
+ int64_t vad_total = segment.vad_end - segment.vad_start;
6737
+ int64_t orig_total = segment.orig_end - segment.orig_start;
6738
+ int64_t orig_time = segment.orig_start + (vad_elapsed * orig_total) / vad_total;
6739
+
6740
+ vad_time_mapping intermediate_mapping = {vad_time, orig_time};
6741
+ state->vad_mapping_table.push_back(intermediate_mapping);
6742
+ }
6743
+ }
6744
+
6745
+ WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
6746
+ __func__, segment.orig_start/100.0, segment.orig_end/100.0, segment.vad_start/100.0, segment.vad_end/100.0);
6747
+ ctx->state->vad_segments.push_back(segment);
6748
+
6749
+ // Copy this speech segment
6750
+ memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
6751
+ offset += segment_length;
6752
+
6753
+ // Add silence after this segment (except after the last segment)
6754
+ if (i < (int)vad_segments->data.size() - 1) {
6755
+ // Calculate the start and end time of the silence gap in processed audio
6756
+ int64_t silence_start_vad = samples_to_cs(offset);
6757
+ int64_t silence_end_vad = samples_to_cs(offset + silence_samples);
6758
+ // Calculate the corresponding original times
6759
+ int64_t orig_silence_start = segment.orig_end;
6760
+ int64_t orig_silence_end = vad_segments->data[i+1].start;
6761
+
6762
+ // Add mapping points for silence boundaries
6763
+ state->vad_mapping_table.push_back({silence_start_vad, orig_silence_start});
6764
+ state->vad_mapping_table.push_back({silence_end_vad, orig_silence_end});
6765
+
6766
+ // Fill with zeros (silence)
6767
+ memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
6768
+ offset += silence_samples;
6769
+ }
6770
+ }
6771
+ }
6772
+
6773
+ // Sort the mapping table by processed time
6774
+ std::sort(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
6775
+ [](const vad_time_mapping& a, const vad_time_mapping& b) {
6776
+ return a.processed_time < b.processed_time;
6777
+ });
6778
+
6779
+ // Remove any duplicate processed times to ensure monotonicity which is
6780
+ // needed for binary search and interpolation later.
6781
+ if (!state->vad_mapping_table.empty()) {
6782
+ auto last = std::unique(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
6783
+ [](const vad_time_mapping& a, const vad_time_mapping& b) {
6784
+ return a.processed_time == b.processed_time;
6785
+ });
6786
+ state->vad_mapping_table.erase(last, state->vad_mapping_table.end());
6787
+ }
6788
+
6789
+ WHISPER_LOG_INFO("%s: Created time mapping table with %d points\n", __func__, (int)state->vad_mapping_table.size());
6790
+
6791
+ filtered_n_samples = offset;
6792
+ WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
6793
+ __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
6794
+ }
6795
+
6796
+ return true;
6797
+ }
6798
+
5419
6799
  int whisper_full_with_state(
5420
6800
  struct whisper_context * ctx,
5421
6801
  struct whisper_state * state,
@@ -5465,11 +6845,13 @@ int whisper_full_with_state(
5465
6845
  const int seek_start = params.offset_ms/10;
5466
6846
  const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
5467
6847
 
5468
- // if length of spectrogram is less than 1.0s (100 frames), then return
5469
- // basically don't process anything that is less than 1.0s
5470
- // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
5471
- if (seek_end < seek_start + 100) {
5472
- WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
6848
+ // if length of spectrogram is less than 100ms (10 frames), then return
6849
+ // basically don't process anything that is less than 100ms
6850
+ // ref: https://github.com/ggml-org/whisper.cpp/issues/2065
6851
+ const int delta_min = 10;
6852
+
6853
+ if (seek_end < seek_start + delta_min) {
6854
+ WHISPER_LOG_WARN("%s: input is too short - %d ms < 100 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
5473
6855
  return 0;
5474
6856
  }
5475
6857
 
@@ -5516,7 +6898,7 @@ int whisper_full_with_state(
5516
6898
  decoder.logprobs.resize(ctx->vocab.n_vocab);
5517
6899
  decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
5518
6900
 
5519
- decoder.rng = std::mt19937(0);
6901
+ decoder.rng = std::mt19937(j);
5520
6902
  }
5521
6903
 
5522
6904
  // the accumulated text context so far
@@ -5613,8 +6995,8 @@ int whisper_full_with_state(
5613
6995
  ctx, state, progress_cur, params.progress_callback_user_data);
5614
6996
  }
5615
6997
 
5616
- // if only 1 second left, then stop
5617
- if (seek + 100 >= seek_end) {
6998
+ // if only 100ms left, then stop
6999
+ if (seek + delta_min >= seek_end) {
5618
7000
  break;
5619
7001
  }
5620
7002
 
@@ -5743,6 +7125,18 @@ int whisper_full_with_state(
5743
7125
  return -8;
5744
7126
  }
5745
7127
 
7128
+ // Calculate no_speech probability after first decode.
7129
+ // This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
7130
+ {
7131
+ const int n_logits = ctx->vocab.id_to_token.size();
7132
+ std::vector<float> logprobs(n_logits);
7133
+ std::vector<float> probs(n_logits);
7134
+
7135
+ whisper_compute_logprobs(state->logits, n_logits, logprobs);
7136
+ whisper_compute_probs(state->logits, n_logits, logprobs, probs);
7137
+ state->no_speech_prob = probs[whisper_token_nosp(ctx)];
7138
+ }
7139
+
5746
7140
  {
5747
7141
  const int64_t t_start_sample_us = wsp_ggml_time_us();
5748
7142
 
@@ -5949,10 +7343,10 @@ int whisper_full_with_state(
5949
7343
  // end of segment
5950
7344
  if (token.id == whisper_token_eot(ctx) || // end of text token
5951
7345
  (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
5952
- (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
7346
+ (has_ts && seek + seek_delta + delta_min >= seek_end) // end of audio reached (100ms)
5953
7347
  ) {
5954
7348
  if (result_len == 0 && !params.no_timestamps) {
5955
- if (seek + seek_delta + 100 >= seek_end) {
7349
+ if (seek + seek_delta + delta_min >= seek_end) {
5956
7350
  result_len = i + 1;
5957
7351
  } else {
5958
7352
  WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
@@ -6134,8 +7528,9 @@ int whisper_full_with_state(
6134
7528
  if (it != (int) temperatures.size() - 1) {
6135
7529
  const auto & decoder = state->decoders[best_decoder_id];
6136
7530
 
6137
- if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
6138
- WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
7531
+ if (decoder.failed ||
7532
+ (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
7533
+ WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
6139
7534
  success = false;
6140
7535
  state->n_fail_p++;
6141
7536
  }
@@ -6156,7 +7551,7 @@ int whisper_full_with_state(
6156
7551
  {
6157
7552
  const auto & best_decoder = state->decoders[best_decoder_id];
6158
7553
 
6159
- const auto seek_delta = best_decoder.seek_delta;
7554
+ auto seek_delta = best_decoder.seek_delta;
6160
7555
  const auto result_len = best_decoder.sequence.result_len;
6161
7556
 
6162
7557
  const auto & tokens_cur = best_decoder.sequence.tokens;
@@ -6164,6 +7559,9 @@ int whisper_full_with_state(
6164
7559
  // [EXPERIMENTAL] Token-level timestamps with DTW
6165
7560
  const auto n_segments_before = state->result_all.size();
6166
7561
 
7562
+ const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
7563
+ best_decoder.sequence.avg_logprobs < params.logprob_thold);
7564
+
6167
7565
  //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
6168
7566
 
6169
7567
  // update prompt_past
@@ -6172,11 +7570,11 @@ int whisper_full_with_state(
6172
7570
  prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
6173
7571
  }
6174
7572
 
6175
- for (int i = 0; i < result_len; ++i) {
7573
+ for (int i = 0; i < result_len && !is_no_speech; ++i) {
6176
7574
  prompt_past.push_back(tokens_cur[i].id);
6177
7575
  }
6178
7576
 
6179
- if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
7577
+ if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
6180
7578
  int i0 = 0;
6181
7579
  auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
6182
7580
 
@@ -6215,7 +7613,7 @@ int whisper_full_with_state(
6215
7613
 
6216
7614
  //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
6217
7615
 
6218
- result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
7616
+ result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
6219
7617
  for (int j = i0; j <= i; j++) {
6220
7618
  result_all.back().tokens.push_back(tokens_cur[j]);
6221
7619
  }
@@ -6260,7 +7658,7 @@ int whisper_full_with_state(
6260
7658
  }
6261
7659
  }
6262
7660
 
6263
- result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next });
7661
+ result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
6264
7662
  for (int j = i0; j < (int) tokens_cur.size(); j++) {
6265
7663
  result_all.back().tokens.push_back(tokens_cur[j]);
6266
7664
  }
@@ -6297,6 +7695,15 @@ int whisper_full_with_state(
6297
7695
  }
6298
7696
  }
6299
7697
 
7698
+ // ref: https://github.com/ggml-org/whisper.cpp/pull/2629
7699
+ const bool single_timestamp_ending = tokens_cur.size() > 1 &&
7700
+ tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
7701
+ tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
7702
+ if (single_timestamp_ending) {
7703
+ WHISPER_LOG_DEBUG("single timestamp ending - skip entire chunk\n");
7704
+ seek_delta = std::min(seek_end - seek, WHISPER_CHUNK_SIZE * 100);
7705
+ }
7706
+
6300
7707
  // update audio window
6301
7708
  seek += seek_delta;
6302
7709
 
@@ -6312,6 +7719,21 @@ int whisper_full(
6312
7719
  struct whisper_full_params params,
6313
7720
  const float * samples,
6314
7721
  int n_samples) {
7722
+
7723
+ std::vector<float> vad_samples;
7724
+ if (params.vad) {
7725
+ WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
7726
+ if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
7727
+ WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
7728
+ return -1;
7729
+ }
7730
+ if (vad_samples.empty()) {
7731
+ ctx->state->result_all.clear();
7732
+ return 0;
7733
+ }
7734
+ samples = vad_samples.data();
7735
+ n_samples = vad_samples.size();
7736
+ }
6315
7737
  return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
6316
7738
  }
6317
7739
 
@@ -6321,9 +7743,24 @@ int whisper_full_parallel(
6321
7743
  const float * samples,
6322
7744
  int n_samples,
6323
7745
  int n_processors) {
7746
+
6324
7747
  if (n_processors == 1) {
6325
7748
  return whisper_full(ctx, params, samples, n_samples);
6326
7749
  }
7750
+
7751
+ std::vector<float> vad_samples;
7752
+ if (params.vad) {
7753
+ WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
7754
+ if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
7755
+ WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
7756
+ return -1;
7757
+ }
7758
+ if (vad_samples.empty()) {
7759
+ return 0;
7760
+ }
7761
+ samples = vad_samples.data();
7762
+ n_samples = vad_samples.size();
7763
+ }
6327
7764
  int ret = 0;
6328
7765
 
6329
7766
  // prepare separate states for each thread
@@ -6446,20 +7883,93 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
6446
7883
  return ctx->state->lang_id;
6447
7884
  }
6448
7885
 
6449
- int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
6450
- return state->result_all[i_segment].t0;
7886
+ static int64_t map_processed_to_original_time(int64_t processed_time, const std::vector<vad_time_mapping> & mapping_table) {
7887
+ if (mapping_table.empty()) {
7888
+ return processed_time;
7889
+ }
7890
+
7891
+ if (processed_time <= mapping_table.front().processed_time) {
7892
+ return mapping_table.front().original_time; // Before first mapping point
7893
+ }
7894
+
7895
+ if (processed_time >= mapping_table.back().processed_time) {
7896
+ return mapping_table.back().original_time; // After last mapping point
7897
+ }
7898
+
7899
+ // Binary search over the time map that finds the first entry that has a
7900
+ // processed time greater than or equal to the current processed time.
7901
+ auto upper = std::lower_bound(mapping_table.begin(), mapping_table.end(), processed_time,
7902
+ [](const vad_time_mapping & entry, int64_t time) {
7903
+ return entry.processed_time < time;
7904
+ }
7905
+ );
7906
+
7907
+ // If exact match found
7908
+ if (upper->processed_time == processed_time) {
7909
+ return upper->original_time;
7910
+ }
7911
+
7912
+ // Need to interpolate between two points
7913
+ auto lower = upper - 1;
7914
+
7915
+ int64_t processed_diff = upper->processed_time - lower->processed_time;
7916
+ int64_t original_diff = upper->original_time - lower->original_time;
7917
+ int64_t offset = processed_time - lower->processed_time;
7918
+
7919
+ if (processed_diff == 0) {
7920
+ return lower->original_time;
7921
+ }
7922
+
7923
+ // Perform linear interpolation
7924
+ return lower->original_time + (offset * original_diff) / processed_diff;
6451
7925
  }
6452
7926
 
6453
- int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
6454
- return ctx->state->result_all[i_segment].t0;
7927
+ // Function to get the starting timestamp of a segment
7928
+ int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
7929
+ // If VAD wasn't used, return the original timestamp
7930
+ if (!state->has_vad_segments || state->vad_mapping_table.empty()) {
7931
+ return state->result_all[i_segment].t0;
7932
+ }
7933
+
7934
+ // Get the processed timestamp
7935
+ int64_t t0 = state->result_all[i_segment].t0;
7936
+
7937
+ // Map to original time using the mapping table
7938
+ return map_processed_to_original_time(t0, state->vad_mapping_table);
6455
7939
  }
6456
7940
 
7941
+ // Function to get the ending timestamp of a segment
6457
7942
  int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
6458
- return state->result_all[i_segment].t1;
7943
+ // If VAD wasn't used, return the original timestamp
7944
+ if (!state->has_vad_segments || state->vad_mapping_table.empty()) {
7945
+ return state->result_all[i_segment].t1;
7946
+ }
7947
+
7948
+ // Get the processed timestamp
7949
+ int64_t t1 = state->result_all[i_segment].t1;
7950
+
7951
+ // Map to original time using the mapping table
7952
+ int64_t orig_t1 = map_processed_to_original_time(t1, state->vad_mapping_table);
7953
+
7954
+ // Get the corresponding t0 for this segment
7955
+ int64_t orig_t0 = whisper_full_get_segment_t0_from_state(state, i_segment);
7956
+
7957
+ // Ensure minimum duration to prevent zero-length segments
7958
+ const int64_t min_duration = 10; // 10ms minimum
7959
+ if (orig_t1 - orig_t0 < min_duration) {
7960
+ orig_t1 = orig_t0 + min_duration;
7961
+ }
7962
+
7963
+ return orig_t1;
7964
+ }
7965
+
7966
+
7967
+ int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
7968
+ return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
6459
7969
  }
6460
7970
 
6461
7971
  int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
6462
- return ctx->state->result_all[i_segment].t1;
7972
+ return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
6463
7973
  }
6464
7974
 
6465
7975
  bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
@@ -6518,6 +8028,14 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
6518
8028
  return ctx->state->result_all[i_segment].tokens[i_token].p;
6519
8029
  }
6520
8030
 
8031
+ float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
8032
+ return ctx->state->result_all[i_segment].no_speech_prob;
8033
+ }
8034
+
8035
+ float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment) {
8036
+ return state->result_all[i_segment].no_speech_prob;
8037
+ }
8038
+
6521
8039
  // =================================================================================================
6522
8040
 
6523
8041
  //
@@ -6698,7 +8216,6 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
6698
8216
  // c: N*N*sizeof(float)
6699
8217
  // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
6700
8218
  std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead() + wsp_ggml_graph_overhead());
6701
- std::vector<uint8_t> work;
6702
8219
 
6703
8220
  // put a bunch of random data in the buffer
6704
8221
  for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
@@ -6755,12 +8272,12 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
6755
8272
  double tsum = 0.0;
6756
8273
 
6757
8274
  // heat-up
6758
- wsp_ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
8275
+ wsp_ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
6759
8276
 
6760
8277
  for (int i = 0; i < n_max; ++i) {
6761
8278
  const int64_t t0 = wsp_ggml_time_us();
6762
8279
 
6763
- wsp_ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
8280
+ wsp_ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
6764
8281
 
6765
8282
  const int64_t t1 = wsp_ggml_time_us();
6766
8283
 
@@ -6813,10 +8330,6 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) {
6813
8330
  // token-level timestamps
6814
8331
  //
6815
8332
 
6816
- static int timestamp_to_sample(int64_t t, int n_samples) {
6817
- return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
6818
- }
6819
-
6820
8333
  static int64_t sample_to_timestamp(int i_sample) {
6821
8334
  return (100ll*i_sample)/WHISPER_SAMPLE_RATE;
6822
8335
  }
@@ -6866,6 +8379,18 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
6866
8379
  return result;
6867
8380
  }
6868
8381
 
8382
+ static int timestamp_to_sample(int64_t t, int64_t segment_t0, int n_samples) {
8383
+ // Convert absolute timestamp to segment-relative timestamp
8384
+ int64_t relative_t = t - segment_t0;
8385
+ int sample = (int)((relative_t * WHISPER_SAMPLE_RATE) / 100);
8386
+ return std::max(0, std::min(n_samples - 1, sample));
8387
+ }
8388
+
8389
+ static int64_t sample_to_timestamp(int i_sample, int64_t segment_t0) {
8390
+ int64_t relative_timestamp = (100ll * i_sample) / WHISPER_SAMPLE_RATE;
8391
+ return relative_timestamp + segment_t0;
8392
+ }
8393
+
6869
8394
  static void whisper_exp_compute_token_level_timestamps(
6870
8395
  struct whisper_context & ctx,
6871
8396
  struct whisper_state & state,
@@ -6921,12 +8446,6 @@ static void whisper_exp_compute_token_level_timestamps(
6921
8446
 
6922
8447
  const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
6923
8448
 
6924
- tokens[j].id = token.id;
6925
- tokens[j].tid = token.tid;
6926
- tokens[j].p = token.p;
6927
- tokens[j].pt = token.pt;
6928
- tokens[j].ptsum = token.ptsum;
6929
-
6930
8449
  tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
6931
8450
 
6932
8451
  if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
@@ -7012,8 +8531,8 @@ static void whisper_exp_compute_token_level_timestamps(
7012
8531
  continue;
7013
8532
  }
7014
8533
 
7015
- int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
7016
- int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
8534
+ int s0 = timestamp_to_sample(tokens[j].t0, segment.t0, n_samples);
8535
+ int s1 = timestamp_to_sample(tokens[j].t1, segment.t0, n_samples);
7017
8536
 
7018
8537
  const int ss0 = std::max(s0 - hw, 0);
7019
8538
  const int ss1 = std::min(s1 + hw, n_samples);
@@ -7034,7 +8553,7 @@ static void whisper_exp_compute_token_level_timestamps(
7034
8553
  while (k > 0 && state.energy[k] > thold) {
7035
8554
  k--;
7036
8555
  }
7037
- tokens[j].t0 = sample_to_timestamp(k);
8556
+ tokens[j].t0 = sample_to_timestamp(k, segment.t0);
7038
8557
  if (tokens[j].t0 < tokens[j - 1].t1) {
7039
8558
  tokens[j].t0 = tokens[j - 1].t1;
7040
8559
  } else {
@@ -7045,7 +8564,7 @@ static void whisper_exp_compute_token_level_timestamps(
7045
8564
  k++;
7046
8565
  }
7047
8566
  s0 = k;
7048
- tokens[j].t0 = sample_to_timestamp(k);
8567
+ tokens[j].t0 = sample_to_timestamp(k, segment.t0);
7049
8568
  }
7050
8569
  }
7051
8570
 
@@ -7055,7 +8574,7 @@ static void whisper_exp_compute_token_level_timestamps(
7055
8574
  while (k < n_samples - 1 && state.energy[k] > thold) {
7056
8575
  k++;
7057
8576
  }
7058
- tokens[j].t1 = sample_to_timestamp(k);
8577
+ tokens[j].t1 = sample_to_timestamp(k, segment.t0);
7059
8578
  if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
7060
8579
  tokens[j].t1 = tokens[j + 1].t0;
7061
8580
  } else {
@@ -7066,7 +8585,7 @@ static void whisper_exp_compute_token_level_timestamps(
7066
8585
  k--;
7067
8586
  }
7068
8587
  s1 = k;
7069
- tokens[j].t1 = sample_to_timestamp(k);
8588
+ tokens[j].t1 = sample_to_timestamp(k, segment.t0);
7070
8589
  }
7071
8590
  }
7072
8591
  }
@@ -7137,18 +8656,18 @@ static wsp_ggml_tensor * dtw_and_backtrace(wsp_ggml_context * ctx, wsp_ggml_tens
7137
8656
  struct wsp_ggml_tensor * cost = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, N + 1, M + 1);
7138
8657
  struct wsp_ggml_tensor * trace = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, N + 1, M + 1);
7139
8658
 
7140
- cost = wsp_ggml_set_f32(cost, INFINITY);
7141
- trace = wsp_ggml_set_f32(trace, -1);
7142
- wsp_ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
8659
+ cost = whisper_set_f32(cost, INFINITY);
8660
+ trace = whisper_set_i32(trace, -1);
8661
+ whisper_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
7143
8662
 
7144
8663
  // dtw
7145
8664
  // supposedly can be optmized by computing diagonals in parallel ?
7146
8665
  // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
7147
8666
  for (int64_t j = 1; j < M + 1; ++j) {
7148
8667
  for (int64_t i = 1; i < N + 1; ++i) {
7149
- float c0 = wsp_ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0);
7150
- float c1 = wsp_ggml_get_f32_nd(cost, i - 1, j, 0, 0);
7151
- float c2 = wsp_ggml_get_f32_nd(cost, i, j - 1, 0, 0);
8668
+ float c0 = whisper_get_f32_nd(cost, i - 1, j - 1, 0, 0);
8669
+ float c1 = whisper_get_f32_nd(cost, i - 1, j, 0, 0);
8670
+ float c2 = whisper_get_f32_nd(cost, i, j - 1, 0, 0);
7152
8671
 
7153
8672
  float c;
7154
8673
  int32_t t;
@@ -7163,9 +8682,9 @@ static wsp_ggml_tensor * dtw_and_backtrace(wsp_ggml_context * ctx, wsp_ggml_tens
7163
8682
  t = 2;
7164
8683
  }
7165
8684
 
7166
- c = wsp_ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
7167
- wsp_ggml_set_f32_nd(cost, i, j, 0, 0, c);
7168
- wsp_ggml_set_i32_nd(trace, i, j, 0, 0, t);
8685
+ c = whisper_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
8686
+ whisper_set_f32_nd(cost, i, j, 0, 0, c);
8687
+ whisper_set_i32_nd(trace, i, j, 0, 0, t);
7169
8688
  }
7170
8689
  }
7171
8690
 
@@ -7174,19 +8693,19 @@ static wsp_ggml_tensor * dtw_and_backtrace(wsp_ggml_context * ctx, wsp_ggml_tens
7174
8693
  struct wsp_ggml_tensor * bt = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, BT_MAX_ROWS, 2);
7175
8694
  // trace[0, :] = 2;
7176
8695
  for (int64_t i = 0; i < M + 1; ++i)
7177
- wsp_ggml_set_i32_nd(trace, 0, i, 0, 0, 2);
8696
+ whisper_set_i32_nd(trace, 0, i, 0, 0, 2);
7178
8697
  //trace[:, 0] = 1;
7179
8698
  for (int64_t i = 0; i < N + 1; ++i)
7180
- wsp_ggml_set_i32_nd(trace, i, 0, 0, 0, 1);
8699
+ whisper_set_i32_nd(trace, i, 0, 0, 0, 1);
7181
8700
  int bt_row_idx = BT_MAX_ROWS - 1;
7182
8701
  int64_t i = N;
7183
8702
  int64_t j = M;
7184
8703
  while (i > 0 || j > 0) {
7185
- wsp_ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
7186
- wsp_ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
8704
+ whisper_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
8705
+ whisper_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
7187
8706
  --bt_row_idx;
7188
8707
 
7189
- int32_t t = wsp_ggml_get_i32_nd(trace, i, j, 0, 0);
8708
+ int32_t t = whisper_get_i32_nd(trace, i, j, 0, 0);
7190
8709
  if (t == 0) {
7191
8710
  --i;
7192
8711
  --j;
@@ -7207,8 +8726,8 @@ static wsp_ggml_tensor * dtw_and_backtrace(wsp_ggml_context * ctx, wsp_ggml_tens
7207
8726
  wsp_ggml_tensor * r = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_I32, 2, result_n_cols);
7208
8727
  for (int64_t i = 0; i < 2; ++i) {
7209
8728
  for (int64_t j = 0; j < result_n_cols; ++j) {
7210
- int32_t v = wsp_ggml_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
7211
- wsp_ggml_set_i32_nd(r, i, j, 0, 0, v);
8729
+ int32_t v = whisper_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
8730
+ whisper_set_i32_nd(r, i, j, 0, 0, v);
7212
8731
  }
7213
8732
  }
7214
8733
 
@@ -7243,11 +8762,11 @@ static void median_filter(struct wsp_ggml_tensor * dst , const struct wsp_ggml_t
7243
8762
  idx = 2*(a->ne[2] - 1) - idx;
7244
8763
  }
7245
8764
 
7246
- filter.push_back(wsp_ggml_get_f32_nd(a, i, j, idx, 0));
8765
+ filter.push_back(whisper_get_f32_nd(a, i, j, idx, 0));
7247
8766
  }
7248
8767
  std::sort(filter.begin(), filter.end());
7249
8768
  const float v = filter[filter.size()/2];
7250
- wsp_ggml_set_f32_nd(dst, i, j, k, 0, v);
8769
+ whisper_set_f32_nd(dst, i, j, k, 0, v);
7251
8770
  filter.clear();
7252
8771
  }
7253
8772
  }
@@ -7369,7 +8888,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
7369
8888
  // Compute
7370
8889
  struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(gctx);
7371
8890
  wsp_ggml_build_forward_expand(gf, w);
7372
- wsp_ggml_graph_compute_with_ctx(gctx, gf, n_threads);
8891
+
8892
+ wsp_ggml_backend_ptr backend { wsp_ggml_backend_init_by_type(WSP_GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
8893
+ wsp_ggml_backend_graph_compute(backend.get(), gf);
7373
8894
 
7374
8895
  wsp_ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
7375
8896
 
@@ -7378,9 +8899,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
7378
8899
  auto seg_i = state->result_all.begin() + i_segment;
7379
8900
  auto tok_i = seg_i->tokens.begin();
7380
8901
  for (int i = 0; i < alignment->ne[1]; ++i) {
7381
- int32_t v = wsp_ggml_get_i32_nd(alignment, 0, i, 0, 0);
8902
+ int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0);
7382
8903
  if (v != last_v) {
7383
- int32_t time_index = wsp_ggml_get_i32_nd(alignment, 1, i, 0, 0);
8904
+ int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
7384
8905
  int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
7385
8906
  last_v = v;
7386
8907
 
@@ -7418,6 +8939,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
7418
8939
  void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
7419
8940
  g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
7420
8941
  g_state.log_callback_user_data = user_data;
8942
+ wsp_ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
7421
8943
  }
7422
8944
 
7423
8945
  WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
@@ -7441,6 +8963,11 @@ static void whisper_log_internal(wsp_ggml_log_level level, const char * format,
7441
8963
  static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data) {
7442
8964
  (void) level;
7443
8965
  (void) user_data;
8966
+ #ifndef WHISPER_DEBUG
8967
+ if (level == WSP_GGML_LOG_LEVEL_DEBUG) {
8968
+ return;
8969
+ }
8970
+ #endif
7444
8971
  fputs(text, stderr);
7445
8972
  fflush(stderr);
7446
8973
  }