whisper.rn 0.5.0 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/cpp/ggml-alloc.c +264 -126
  4. package/cpp/ggml-backend-impl.h +4 -1
  5. package/cpp/ggml-backend-reg.cpp +13 -5
  6. package/cpp/ggml-backend.cpp +207 -17
  7. package/cpp/ggml-backend.h +17 -1
  8. package/cpp/ggml-cpu/amx/amx.cpp +4 -2
  9. package/cpp/ggml-cpu/arch/x86/repack.cpp +2 -2
  10. package/cpp/ggml-cpu/arch-fallback.h +0 -4
  11. package/cpp/ggml-cpu/common.h +14 -0
  12. package/cpp/ggml-cpu/ggml-cpu-impl.h +13 -6
  13. package/cpp/ggml-cpu/ggml-cpu.c +48 -41
  14. package/cpp/ggml-cpu/ggml-cpu.cpp +14 -4
  15. package/cpp/ggml-cpu/ops.cpp +518 -767
  16. package/cpp/ggml-cpu/ops.h +2 -0
  17. package/cpp/ggml-cpu/simd-mappings.h +88 -59
  18. package/cpp/ggml-cpu/vec.cpp +161 -20
  19. package/cpp/ggml-cpu/vec.h +400 -51
  20. package/cpp/ggml-cpu.h +1 -1
  21. package/cpp/ggml-impl.h +43 -10
  22. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  23. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  24. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  25. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  26. package/cpp/ggml-metal/ggml-metal-device.cpp +1376 -0
  27. package/cpp/ggml-metal/ggml-metal-device.h +226 -0
  28. package/cpp/ggml-metal/ggml-metal-device.m +1312 -0
  29. package/cpp/ggml-metal/ggml-metal-impl.h +722 -0
  30. package/cpp/ggml-metal/ggml-metal-ops.cpp +3158 -0
  31. package/cpp/ggml-metal/ggml-metal-ops.h +82 -0
  32. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  33. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  34. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  35. package/cpp/ggml-metal-impl.h +40 -40
  36. package/cpp/ggml-metal.h +1 -6
  37. package/cpp/ggml-quants.c +1 -0
  38. package/cpp/ggml.c +175 -13
  39. package/cpp/ggml.h +84 -5
  40. package/cpp/jsi/RNWhisperJSI.cpp +2 -0
  41. package/cpp/jsi/ThreadPool.h +3 -3
  42. package/cpp/whisper.cpp +85 -70
  43. package/cpp/whisper.h +1 -0
  44. package/ios/CMakeLists.txt +6 -1
  45. package/ios/RNWhisperVadContext.mm +14 -13
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  48. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  49. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  50. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  51. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  52. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +84 -5
  53. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  55. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  56. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  57. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  58. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  59. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  60. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  61. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  62. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  63. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +84 -5
  64. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  65. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  67. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  68. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  70. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  71. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  72. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  73. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  74. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  75. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +84 -5
  76. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  77. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  78. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  79. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  80. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  81. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  82. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  83. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  84. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  85. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  86. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +84 -5
  87. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  88. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  89. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  90. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  91. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  92. package/lib/commonjs/version.json +1 -1
  93. package/lib/module/version.json +1 -1
  94. package/package.json +1 -1
  95. package/src/version.json +1 -1
  96. package/whisper-rn.podspec +8 -9
  97. package/cpp/ggml-metal.m +0 -6779
  98. package/cpp/ggml-whisper-sim.metallib +0 -0
  99. package/cpp/ggml-whisper.metallib +0 -0
package/cpp/ggml.h CHANGED
@@ -244,6 +244,13 @@
244
244
  #define WSP_GGML_MROPE_SECTIONS 4
245
245
 
246
246
  #define WSP_GGML_UNUSED(x) (void)(x)
247
+ #ifdef __CUDACC__
248
+ template<typename... Args>
249
+ __host__ __device__ constexpr inline void wsp_ggml_unused_vars_impl(Args&&...) noexcept {}
250
+ #define WSP_GGML_UNUSED_VARS(...) wsp_ggml_unused_vars_impl(__VA_ARGS__)
251
+ #else
252
+ #define WSP_GGML_UNUSED_VARS(...) do { (void)sizeof((__VA_ARGS__, 0)); } while(0)
253
+ #endif // __CUDACC__
247
254
 
248
255
  #define WSP_GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
249
256
 
@@ -277,19 +284,19 @@
277
284
  // WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
278
285
  //
279
286
  #define WSP_GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
280
- const type prefix##0 = (pointer)->array[0]; \
287
+ const type prefix##0 = (pointer) ? (pointer)->array[0] : 0; \
281
288
  WSP_GGML_UNUSED(prefix##0);
282
289
  #define WSP_GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
283
290
  WSP_GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \
284
- const type prefix##1 = (pointer)->array[1]; \
291
+ const type prefix##1 = (pointer) ? (pointer)->array[1] : 0; \
285
292
  WSP_GGML_UNUSED(prefix##1);
286
293
  #define WSP_GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
287
294
  WSP_GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \
288
- const type prefix##2 = (pointer)->array[2]; \
295
+ const type prefix##2 = (pointer) ? (pointer)->array[2] : 0; \
289
296
  WSP_GGML_UNUSED(prefix##2);
290
297
  #define WSP_GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
291
298
  WSP_GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \
292
- const type prefix##3 = (pointer)->array[3]; \
299
+ const type prefix##3 = (pointer) ? (pointer)->array[3] : 0; \
293
300
  WSP_GGML_UNUSED(prefix##3);
294
301
 
295
302
  #define WSP_GGML_TENSOR_UNARY_OP_LOCALS \
@@ -504,7 +511,9 @@ extern "C" {
504
511
  WSP_GGML_OP_CONV_TRANSPOSE_1D,
505
512
  WSP_GGML_OP_IM2COL,
506
513
  WSP_GGML_OP_IM2COL_BACK,
514
+ WSP_GGML_OP_IM2COL_3D,
507
515
  WSP_GGML_OP_CONV_2D,
516
+ WSP_GGML_OP_CONV_3D,
508
517
  WSP_GGML_OP_CONV_2D_DW,
509
518
  WSP_GGML_OP_CONV_TRANSPOSE_2D,
510
519
  WSP_GGML_OP_POOL_1D,
@@ -1395,6 +1404,7 @@ extern "C" {
1395
1404
  struct wsp_ggml_tensor * a,
1396
1405
  struct wsp_ggml_tensor * b);
1397
1406
 
1407
+ // note: casting from f32 to i32 will discard the fractional part
1398
1408
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cast(
1399
1409
  struct wsp_ggml_context * ctx,
1400
1410
  struct wsp_ggml_tensor * a,
@@ -1519,7 +1529,11 @@ extern "C" {
1519
1529
  struct wsp_ggml_context * ctx,
1520
1530
  struct wsp_ggml_tensor * a);
1521
1531
 
1522
- // supports 3D: a->ne[2] == b->ne[1]
1532
+ // supports 4D a:
1533
+ // a [n_embd, ne1, ne2, ne3]
1534
+ // b I32 [n_rows, ne2, ne3, 1]
1535
+ //
1536
+ // return [n_embd, n_rows, ne2, ne3]
1523
1537
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows(
1524
1538
  struct wsp_ggml_context * ctx,
1525
1539
  struct wsp_ggml_tensor * a, // data
@@ -1862,6 +1876,41 @@ extern "C" {
1862
1876
  int d0, // dilation dimension 0
1863
1877
  int d1); // dilation dimension 1
1864
1878
 
1879
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col_3d(
1880
+ struct wsp_ggml_context * ctx,
1881
+ struct wsp_ggml_tensor * a,
1882
+ struct wsp_ggml_tensor * b,
1883
+ int64_t IC,
1884
+ int s0, // stride width
1885
+ int s1, // stride height
1886
+ int s2, // stride depth
1887
+ int p0, // padding width
1888
+ int p1, // padding height
1889
+ int p2, // padding depth
1890
+ int d0, // dilation width
1891
+ int d1, // dilation height
1892
+ int d2, // dilation depth
1893
+ enum wsp_ggml_type dst_type);
1894
+
1895
+ // a: [OC*IC, KD, KH, KW]
1896
+ // b: [N*IC, ID, IH, IW]
1897
+ // result: [N*OC, OD, OH, OW]
1898
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_3d(
1899
+ struct wsp_ggml_context * ctx,
1900
+ struct wsp_ggml_tensor * a,
1901
+ struct wsp_ggml_tensor * b,
1902
+ int64_t IC,
1903
+ int s0, // stride width
1904
+ int s1, // stride height
1905
+ int s2, // stride depth
1906
+ int p0, // padding width
1907
+ int p1, // padding height
1908
+ int p2, // padding depth
1909
+ int d0, // dilation width
1910
+ int d1, // dilation height
1911
+ int d2 // dilation depth
1912
+ );
1913
+
1865
1914
  // kernel size is a->ne[0] x a->ne[1]
1866
1915
  // stride is equal to kernel size
1867
1916
  // padding is zero
@@ -1933,6 +1982,23 @@ extern "C" {
1933
1982
  int d0, // dilation dimension 0
1934
1983
  int d1); // dilation dimension 1
1935
1984
 
1985
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_3d_direct(
1986
+ struct wsp_ggml_context * ctx,
1987
+ struct wsp_ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
1988
+ struct wsp_ggml_tensor * b, // input [W, H, D, C * N]
1989
+ int s0, // stride
1990
+ int s1,
1991
+ int s2,
1992
+ int p0, // padding
1993
+ int p1,
1994
+ int p2,
1995
+ int d0, // dilation
1996
+ int d1,
1997
+ int d2,
1998
+ int n_channels,
1999
+ int n_batch,
2000
+ int n_channels_out);
2001
+
1936
2002
  enum wsp_ggml_op_pool {
1937
2003
  WSP_GGML_OP_POOL_MAX,
1938
2004
  WSP_GGML_OP_POOL_AVG,
@@ -2023,6 +2089,19 @@ extern "C" {
2023
2089
  int p2,
2024
2090
  int p3);
2025
2091
 
2092
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad_ext(
2093
+ struct wsp_ggml_context * ctx,
2094
+ struct wsp_ggml_tensor * a,
2095
+ int lp0,
2096
+ int rp0,
2097
+ int lp1,
2098
+ int rp1,
2099
+ int lp2,
2100
+ int rp2,
2101
+ int lp3,
2102
+ int rp3
2103
+ );
2104
+
2026
2105
  // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
2027
2106
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad_reflect_1d(
2028
2107
  struct wsp_ggml_context * ctx,
@@ -17,6 +17,8 @@ using namespace facebook::jsi;
17
17
 
18
18
  namespace rnwhisper_jsi {
19
19
 
20
+ using namespace facebook::jsi;
21
+
20
22
  // Consolidated logging function
21
23
  enum class LogLevel { LOG_DEBUG, LOG_INFO, LOG_ERROR };
22
24
 
@@ -18,7 +18,7 @@ public:
18
18
  ThreadPool(size_t);
19
19
  template<class F, class... Args>
20
20
  auto enqueue(F&& f, Args&&... args)
21
- -> std::future<typename std::result_of<F(Args...)>::type>;
21
+ -> std::future<std::invoke_result_t<F, Args...>>;
22
22
  ~ThreadPool();
23
23
  private:
24
24
  // need to keep track of threads so we can join them
@@ -63,9 +63,9 @@ inline ThreadPool::ThreadPool(size_t threads)
63
63
  // add new work item to the pool
64
64
  template<class F, class... Args>
65
65
  auto ThreadPool::enqueue(F&& f, Args&&... args)
66
- -> std::future<typename std::result_of<F(Args...)>::type>
66
+ -> std::future<std::invoke_result_t<F, Args...>>
67
67
  {
68
- using return_type = typename std::result_of<F(Args...)>::type;
68
+ using return_type = std::invoke_result_t<F, Args...>;
69
69
 
70
70
  auto task = std::make_shared< std::packaged_task<return_type()> >(
71
71
  std::bind(std::forward<F>(f), std::forward<Args>(args)...)
package/cpp/whisper.cpp CHANGED
@@ -21,14 +21,12 @@
21
21
  #define _USE_MATH_DEFINES
22
22
  #include <cmath>
23
23
  #include <climits>
24
- #include <codecvt>
25
24
  #include <cstdarg>
26
25
  #include <cstdio>
27
26
  #include <cstring>
28
27
  #include <fstream>
29
28
  #include <functional>
30
29
  #include <map>
31
- #include <mutex>
32
30
  #include <random>
33
31
  #include <regex>
34
32
  #include <set>
@@ -36,6 +34,10 @@
36
34
  #include <thread>
37
35
  #include <vector>
38
36
 
37
+ #ifdef _MSC_VER
38
+ #include <codecvt>
39
+ #endif
40
+
39
41
  #if defined(WHISPER_BIG_ENDIAN)
40
42
  template<typename T>
41
43
  static T byteswap(T value) {
@@ -138,6 +140,10 @@ static void whisper_log_callback_default(wsp_ggml_log_level level, const char *
138
140
  } while (0)
139
141
 
140
142
  #define WHISPER_MAX_DECODERS 8
143
+
144
+ // temperature below which we condition on past text history
145
+ static constexpr float WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF = 0.5f;
146
+
141
147
  #define WHISPER_MAX_NODES 4096
142
148
 
143
149
  static std::string format(const char * fmt, ...) {
@@ -252,45 +258,6 @@ static void whisper_set_i32_nd(struct wsp_ggml_tensor * t, int64_t i0, int64_t i
252
258
  *(int32_t *) data = v;
253
259
  }
254
260
 
255
- // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
256
- // the idea is to represent the original matrix multiplication:
257
- //
258
- // Z = X @ Y
259
- //
260
- // with the sum of two matrix multiplications:
261
- //
262
- // Z = (X_0 @ Y_0) + (X_1 @ Y_1)
263
- //
264
- // here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
265
- // and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
266
- // general-purpose kernels
267
- //
268
- static struct wsp_ggml_tensor * wsp_ggml_mul_mat_pad(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * x, struct wsp_ggml_tensor * y, int pad = 32) {
269
- // use padding only if dimension 0 is at least 8 times larger than the padding
270
- // else we won't get much benefit from the optimization
271
- const int n_pad_req = 8;
272
-
273
- if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
274
- return wsp_ggml_mul_mat(ctx, x, y);
275
- }
276
-
277
- struct wsp_ggml_tensor * x_0 = wsp_ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
278
- struct wsp_ggml_tensor * x_1 = wsp_ggml_view_3d(ctx, x, x->ne[0]%pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]);
279
-
280
- struct wsp_ggml_tensor * y_0 = wsp_ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
281
- struct wsp_ggml_tensor * y_1 = wsp_ggml_view_3d(ctx, y, y->ne[0]%pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]);
282
-
283
- return wsp_ggml_add(ctx,
284
- wsp_ggml_mul_mat(ctx, x_0, y_0),
285
- wsp_ggml_mul_mat(ctx, x_1, y_1));
286
- }
287
-
288
- // TODO: check if other platforms can benefit from this optimization
289
- // TODO: CUDA is currently broken - seems wsp_ggml_mul_mat does not handle views correctly
290
- #if defined(WSP_GGML_USE_METAL)
291
- #define wsp_ggml_mul_mat wsp_ggml_mul_mat_pad
292
- #endif
293
-
294
261
  // available whisper models
295
262
  enum e_model {
296
263
  MODEL_UNKNOWN,
@@ -919,7 +886,10 @@ struct whisper_state {
919
886
  std::vector<float> logits;
920
887
 
921
888
  std::vector<whisper_segment> result_all;
922
- std::vector<whisper_token> prompt_past;
889
+
890
+ // prompt history split into static prefix (prompt_past0) and dynamic rolling context (prompt_past1)
891
+ std::vector<whisper_token> prompt_past0; // static carried initial prompt (if enabled)
892
+ std::vector<whisper_token> prompt_past1; // dynamic context from decoded output
923
893
 
924
894
  int lang_id = 0; // english by default
925
895
 
@@ -3635,7 +3605,7 @@ struct whisper_context_params whisper_context_default_params() {
3635
3605
  struct whisper_context_params result = {
3636
3606
  /*.use_gpu =*/ true,
3637
3607
  /*.use_coreml =*/ false,
3638
- /*.flash_attn =*/ false,
3608
+ /*.flash_attn =*/ true,
3639
3609
  /*.gpu_device =*/ 0,
3640
3610
 
3641
3611
  /*.dtw_token_timestamps =*/ false,
@@ -4719,6 +4689,7 @@ static bool whisper_vad_init_context(whisper_vad_context * vctx) {
4719
4689
  wsp_ggml_set_name(vctx->c_state, "c_state");
4720
4690
 
4721
4691
  vctx->buffer = wsp_ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);
4692
+ wsp_ggml_free(ctx);
4722
4693
  if (!vctx->buffer) {
4723
4694
  WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
4724
4695
  return false;
@@ -5463,6 +5434,9 @@ struct whisper_vad_segments * whisper_vad_segments_from_samples(
5463
5434
 
5464
5435
  void whisper_vad_free(whisper_vad_context * ctx) {
5465
5436
  if (ctx) {
5437
+ if (ctx->buffer) {
5438
+ wsp_ggml_backend_buffer_free(ctx->buffer);
5439
+ }
5466
5440
  for (wsp_ggml_context * context : ctx->model.ctxs) {
5467
5441
  wsp_ggml_free(context);
5468
5442
  }
@@ -5477,6 +5451,9 @@ void whisper_vad_free(whisper_vad_context * ctx) {
5477
5451
  wsp_ggml_backend_free(backend);
5478
5452
  }
5479
5453
 
5454
+ delete[] ctx->model.hparams.encoder_in_channels;
5455
+ delete[] ctx->model.hparams.encoder_out_channels;
5456
+ delete[] ctx->model.hparams.kernel_sizes;
5480
5457
 
5481
5458
  delete ctx;
5482
5459
  }
@@ -5956,9 +5933,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
5956
5933
 
5957
5934
  /* suppress_regex =*/ nullptr,
5958
5935
 
5959
- /*.initial_prompt =*/ nullptr,
5960
- /*.prompt_tokens =*/ nullptr,
5961
- /*.prompt_n_tokens =*/ 0,
5936
+ /*.initial_prompt =*/ nullptr,
5937
+ /*.carry_initial_prompt =*/ false,
5938
+ /*.prompt_tokens =*/ nullptr,
5939
+ /*.prompt_n_tokens =*/ 0,
5962
5940
 
5963
5941
  /*.language =*/ "en",
5964
5942
  /*.detect_language =*/ false,
@@ -6654,6 +6632,10 @@ static bool whisper_vad(
6654
6632
 
6655
6633
  whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
6656
6634
 
6635
+ if (!vad_segments) {
6636
+ return false;
6637
+ }
6638
+
6657
6639
  if (vad_segments->data.size() > 0) {
6658
6640
  state->has_vad_segments = true;
6659
6641
  ctx->state->vad_segments.clear();
@@ -6696,7 +6678,6 @@ static bool whisper_vad(
6696
6678
  } catch (const std::bad_alloc & /* e */) {
6697
6679
  WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
6698
6680
  whisper_vad_free_segments(vad_segments);
6699
- whisper_vad_free(vctx);
6700
6681
  return false;
6701
6682
  }
6702
6683
 
@@ -6802,6 +6783,7 @@ static bool whisper_vad(
6802
6783
  __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
6803
6784
  }
6804
6785
 
6786
+ whisper_vad_free_segments(vad_segments);
6805
6787
  return true;
6806
6788
  }
6807
6789
 
@@ -6910,17 +6892,22 @@ int whisper_full_with_state(
6910
6892
  decoder.rng = std::mt19937(j);
6911
6893
  }
6912
6894
 
6913
- // the accumulated text context so far
6914
- auto & prompt_past = state->prompt_past;
6895
+ // the accumulated text context split into static (prompt_past0) and dynamic (prompt_past1)
6896
+ auto & prompt_past0 = state->prompt_past0;
6897
+ auto & prompt_past1 = state->prompt_past1;
6915
6898
  if (params.no_context) {
6916
- prompt_past.clear();
6899
+ prompt_past0.clear();
6900
+ prompt_past1.clear();
6917
6901
  }
6918
6902
 
6903
+ // calculate the maximum context budget for prompt history
6904
+ const int max_prompt_ctx = std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2);
6905
+
6919
6906
  // prepare prompt
6920
6907
  {
6921
6908
  std::vector<whisper_token> prompt_tokens;
6922
6909
 
6923
- // initial prompt
6910
+ // tokenize the initial prompt
6924
6911
  if (!params.prompt_tokens && params.initial_prompt) {
6925
6912
  prompt_tokens.resize(1024);
6926
6913
  int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
@@ -6932,14 +6919,25 @@ int whisper_full_with_state(
6932
6919
  params.prompt_tokens = prompt_tokens.data();
6933
6920
  params.prompt_n_tokens = prompt_tokens.size();
6934
6921
  }
6935
-
6936
- // prepend the prompt tokens to the prompt_past
6937
6922
  if (params.prompt_tokens && params.prompt_n_tokens > 0) {
6938
- // parse tokens from the pointer
6939
- for (int i = 0; i < params.prompt_n_tokens; i++) {
6940
- prompt_past.push_back(params.prompt_tokens[i]);
6923
+ if (params.carry_initial_prompt) {
6924
+ if (prompt_past0.empty()) {
6925
+ const int max_tokens = std::max(1, max_prompt_ctx - 1);
6926
+
6927
+ if (params.prompt_n_tokens > max_tokens) {
6928
+ WHISPER_LOG_WARN("%s: initial prompt is too long (%d tokens), will use only the last %d tokens\n",
6929
+ __func__, params.prompt_n_tokens, max_tokens);
6930
+ }
6931
+
6932
+ const int n_tokens = std::min(params.prompt_n_tokens, max_tokens);
6933
+ prompt_past0.assign(params.prompt_tokens + (params.prompt_n_tokens - n_tokens), params.prompt_tokens + params.prompt_n_tokens);
6934
+ }
6935
+ } else {
6936
+ for (int i = 0; i < params.prompt_n_tokens; ++i) {
6937
+ prompt_past1.push_back(params.prompt_tokens[i]);
6938
+ }
6939
+ std::rotate(prompt_past1.begin(), prompt_past1.end() - params.prompt_n_tokens, prompt_past1.end());
6941
6940
  }
6942
- std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
6943
6941
  }
6944
6942
  }
6945
6943
 
@@ -7025,7 +7023,8 @@ int whisper_full_with_state(
7025
7023
  // if there is a very short audio segment left to process, we remove any past prompt since it tends
7026
7024
  // to confuse the decoder and often make it repeat or hallucinate stuff
7027
7025
  if (seek > seek_start && seek + 500 >= seek_end) {
7028
- prompt_past.clear();
7026
+ prompt_past0.clear();
7027
+ prompt_past1.clear();
7029
7028
  }
7030
7029
 
7031
7030
  int best_decoder_id = 0;
@@ -7086,12 +7085,25 @@ int whisper_full_with_state(
7086
7085
  {
7087
7086
  prompt.clear();
7088
7087
 
7089
- // if we have already generated some text, use it as a prompt to condition the next generation
7090
- if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
7091
- int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
7088
+ if (params.n_max_text_ctx > 0 && t_cur < WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF) {
7089
+ const bool can_take0 = params.carry_initial_prompt && !prompt_past0.empty();
7090
+ const bool can_take1 = !prompt_past1.empty();
7092
7091
 
7093
- prompt = { whisper_token_prev(ctx) };
7094
- prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
7092
+ if (max_prompt_ctx > 0 && (can_take0 || can_take1)) {
7093
+ // Always start with previous token marker to connect continuity
7094
+ prompt.push_back(whisper_token_prev(ctx));
7095
+
7096
+ // Take static tokens (initial prompt) first
7097
+ int n_take0 = 0;
7098
+ if (can_take0) {
7099
+ n_take0 = prompt_past0.size();
7100
+ prompt.insert(prompt.end(), prompt_past0.end() - n_take0, prompt_past0.end());
7101
+ }
7102
+
7103
+ // Fill remaining budget with dynamic tokens (rolling context)
7104
+ const int n_take1 = std::min<int>(max_prompt_ctx - n_take0 - 1, prompt_past1.size());
7105
+ prompt.insert(prompt.end(), prompt_past1.end() - n_take1, prompt_past1.end());
7106
+ }
7095
7107
  }
7096
7108
 
7097
7109
  // init new transcription with sot, language (opt) and task tokens
@@ -7573,14 +7585,17 @@ int whisper_full_with_state(
7573
7585
 
7574
7586
  //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
7575
7587
 
7576
- // update prompt_past
7577
- prompt_past.clear();
7578
- if (prompt.front() == whisper_token_prev(ctx)) {
7579
- prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
7588
+ // update prompt_past1
7589
+ prompt_past1.clear();
7590
+ if (!params.carry_initial_prompt && !prompt.empty() && prompt.front() == whisper_token_prev(ctx)) {
7591
+ prompt_past1.insert(prompt_past1.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
7580
7592
  }
7581
7593
 
7582
- for (int i = 0; i < result_len && !is_no_speech; ++i) {
7583
- prompt_past.push_back(tokens_cur[i].id);
7594
+ // Add newly decoded tokens to the rolling context
7595
+ if (!is_no_speech) {
7596
+ for (int i = 0; i < result_len; ++i) {
7597
+ prompt_past1.push_back(tokens_cur[i].id);
7598
+ }
7584
7599
  }
7585
7600
 
7586
7601
  if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
@@ -8952,7 +8967,7 @@ void whisper_log_set(wsp_ggml_log_callback log_callback, void * user_data) {
8952
8967
  }
8953
8968
 
8954
8969
  const char * whisper_version(void) {
8955
- return "1.7.6";
8970
+ return "1.8.0";
8956
8971
  }
8957
8972
 
8958
8973
  WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
package/cpp/whisper.h CHANGED
@@ -526,6 +526,7 @@ extern "C" {
526
526
  // use whisper_tokenize() to convert text to tokens
527
527
  // maximum of whisper_n_text_ctx()/2 tokens are used (typically 224)
528
528
  const char * initial_prompt;
529
+ bool carry_initial_prompt; // if true, always prepend initial_prompt to every decode window (may reduce conditioning on previous text)
529
530
  const whisper_token * prompt_tokens;
530
531
  int prompt_n_tokens;
531
532
 
@@ -55,7 +55,12 @@ add_library(rnwhisper SHARED
55
55
  ${SOURCE_DIR}/ggml-cpu/binary-ops.cpp
56
56
  ${SOURCE_DIR}/ggml-cpu/vec.cpp
57
57
  ${SOURCE_DIR}/ggml-cpu/ops.cpp
58
- ${SOURCE_DIR}/ggml-metal.m
58
+ ${SOURCE_DIR}/ggml-metal/ggml-metal.cpp
59
+ ${SOURCE_DIR}/ggml-metal/ggml-metal-common.cpp
60
+ ${SOURCE_DIR}/ggml-metal/ggml-metal-device.cpp
61
+ ${SOURCE_DIR}/ggml-metal/ggml-metal-context.m
62
+ ${SOURCE_DIR}/ggml-metal/ggml-metal-device.m
63
+ ${SOURCE_DIR}/ggml-metal/ggml-metal-ops.cpp
59
64
  ${SOURCE_DIR}/ggml-opt.cpp
60
65
  ${SOURCE_DIR}/ggml-threading.cpp
61
66
  ${SOURCE_DIR}/ggml-quants.c
@@ -20,27 +20,28 @@
20
20
 
21
21
  #ifdef WSP_GGML_USE_METAL
22
22
  if (ctx_params.use_gpu) {
23
- ctx_params.gpu_device = 0;
23
+ // TODO: GPU VAD is forced disabled until the performance is improved (ref: whisper.cpp/whisper_vad_init_context)
24
+ ctx_params.use_gpu = false;
25
+ // ctx_params.gpu_device = 0;
24
26
 
25
- id<MTLDevice> device = MTLCreateSystemDefaultDevice();
27
+ // id<MTLDevice> device = MTLCreateSystemDefaultDevice();
26
28
 
27
- // Check ggml-metal availability
28
- BOOL supportsGgmlMetal = [device supportsFamily:MTLGPUFamilyApple7];
29
- if (@available(iOS 16.0, tvOS 16.0, *)) {
30
- supportsGgmlMetal = supportsGgmlMetal && [device supportsFamily:MTLGPUFamilyMetal3];
31
- }
32
- if (!supportsGgmlMetal) {
33
- ctx_params.use_gpu = false;
34
- reasonNoMetal = @"Metal is not supported in this device";
35
- }
29
+ // // Check ggml-metal availability
30
+ // BOOL supportsGgmlMetal = [device supportsFamily:MTLGPUFamilyApple7];
31
+ // if (@available(iOS 16.0, tvOS 16.0, *)) {
32
+ // supportsGgmlMetal = supportsGgmlMetal && [device supportsFamily:MTLGPUFamilyMetal3];
33
+ // }
34
+ // if (!supportsGgmlMetal) {
35
+ // ctx_params.use_gpu = false;
36
+ // reasonNoMetal = @"Metal is not supported in this device";
37
+ // }
38
+ // device = nil;
36
39
 
37
40
  #if TARGET_OS_SIMULATOR
38
41
  // Use the backend, but no layers because not supported fully on simulator
39
42
  ctx_params.use_gpu = false;
40
43
  reasonNoMetal = @"Metal is not supported in simulator";
41
44
  #endif
42
-
43
- device = nil;
44
45
  }
45
46
  #endif // WSP_GGML_USE_METAL
46
47
 
@@ -8,7 +8,7 @@
8
8
  extern "C" {
9
9
  #endif
10
10
 
11
- #define WSP_GGML_BACKEND_API_VERSION 1
11
+ #define WSP_GGML_BACKEND_API_VERSION 2
12
12
 
13
13
  //
14
14
  // Backend buffer type
@@ -114,6 +114,9 @@ extern "C" {
114
114
  void (*event_record)(wsp_ggml_backend_t backend, wsp_ggml_backend_event_t event);
115
115
  // wait for an event on on a different stream
116
116
  void (*event_wait) (wsp_ggml_backend_t backend, wsp_ggml_backend_event_t event);
117
+
118
+ // (optional) sort/optimize the nodes in the graph
119
+ void (*graph_optimize) (wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph);
117
120
  };
118
121
 
119
122
  struct wsp_ggml_backend {
@@ -132,6 +132,8 @@ extern "C" {
132
132
  WSP_GGML_BACKEND_DEVICE_TYPE_CPU,
133
133
  // GPU device using dedicated memory
134
134
  WSP_GGML_BACKEND_DEVICE_TYPE_GPU,
135
+ // integrated GPU device using host memory
136
+ WSP_GGML_BACKEND_DEVICE_TYPE_IGPU,
135
137
  // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
136
138
  WSP_GGML_BACKEND_DEVICE_TYPE_ACCEL
137
139
  };
@@ -150,11 +152,21 @@ extern "C" {
150
152
 
151
153
  // all the device properties
152
154
  struct wsp_ggml_backend_dev_props {
155
+ // device name
153
156
  const char * name;
157
+ // device description
154
158
  const char * description;
159
+ // device free memory in bytes
155
160
  size_t memory_free;
161
+ // device total memory in bytes
156
162
  size_t memory_total;
163
+ // device type
157
164
  enum wsp_ggml_backend_dev_type type;
165
+ // device id
166
+ // for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0")
167
+ // if the id is unknown, this should be NULL
168
+ const char * device_id;
169
+ // device capabilities
158
170
  struct wsp_ggml_backend_dev_caps caps;
159
171
  };
160
172
 
@@ -302,11 +314,15 @@ extern "C" {
302
314
  WSP_GGML_API int wsp_ggml_backend_sched_get_n_splits(wsp_ggml_backend_sched_t sched);
303
315
  WSP_GGML_API int wsp_ggml_backend_sched_get_n_copies(wsp_ggml_backend_sched_t sched);
304
316
 
305
- WSP_GGML_API size_t wsp_ggml_backend_sched_get_buffer_size(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend);
317
+ WSP_GGML_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_sched_get_buffer_type(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend);
318
+ WSP_GGML_API size_t wsp_ggml_backend_sched_get_buffer_size(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend);
306
319
 
307
320
  WSP_GGML_API void wsp_ggml_backend_sched_set_tensor_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node, wsp_ggml_backend_t backend);
308
321
  WSP_GGML_API wsp_ggml_backend_t wsp_ggml_backend_sched_get_tensor_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node);
309
322
 
323
+ // Split graph without allocating it
324
+ WSP_GGML_API void wsp_ggml_backend_sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph);
325
+
310
326
  // Allocate and compute graph on the backend scheduler
311
327
  WSP_GGML_API bool wsp_ggml_backend_sched_alloc_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph); // returns success
312
328
  WSP_GGML_API enum wsp_ggml_status wsp_ggml_backend_sched_graph_compute(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph);
@@ -101,7 +101,6 @@ extern "C" {
101
101
  WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_riscv_v (void);
102
102
  WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_vsx (void);
103
103
  WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_vxe (void);
104
- WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_nnpa (void);
105
104
  WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_wasm_simd (void);
106
105
  WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_llamafile (void);
107
106
 
@@ -135,6 +134,7 @@ extern "C" {
135
134
  WSP_GGML_BACKEND_API wsp_ggml_backend_reg_t wsp_ggml_backend_cpu_reg(void);
136
135
 
137
136
  WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
137
+ WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
138
138
  WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_fp16(const float *, wsp_ggml_fp16_t *, int64_t);
139
139
  WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp16_to_fp32(const wsp_ggml_fp16_t *, float *, int64_t);
140
140
  WSP_GGML_BACKEND_API void wsp_ggml_cpu_fp32_to_bf16(const float *, wsp_ggml_bf16_t *, int64_t);