whisper.rn 0.4.0-rc.4 → 0.4.0-rc.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +5 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/WhisperContext.java +57 -134
  6. package/android/src/main/jni-utils.h +76 -0
  7. package/android/src/main/jni.cpp +188 -112
  8. package/cpp/README.md +1 -1
  9. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  10. package/cpp/coreml/whisper-encoder.h +4 -0
  11. package/cpp/coreml/whisper-encoder.mm +4 -2
  12. package/cpp/ggml-alloc.c +55 -19
  13. package/cpp/ggml-alloc.h +8 -1
  14. package/cpp/ggml-backend-impl.h +46 -21
  15. package/cpp/ggml-backend.c +563 -156
  16. package/cpp/ggml-backend.h +62 -17
  17. package/cpp/ggml-impl.h +1 -1
  18. package/cpp/ggml-metal-whisper.metal +2444 -359
  19. package/cpp/ggml-metal.h +7 -1
  20. package/cpp/ggml-metal.m +1105 -197
  21. package/cpp/ggml-quants.c +66 -61
  22. package/cpp/ggml-quants.h +40 -40
  23. package/cpp/ggml.c +1040 -1590
  24. package/cpp/ggml.h +109 -30
  25. package/cpp/rn-audioutils.cpp +68 -0
  26. package/cpp/rn-audioutils.h +14 -0
  27. package/cpp/rn-whisper-log.h +11 -0
  28. package/cpp/rn-whisper.cpp +143 -59
  29. package/cpp/rn-whisper.h +48 -15
  30. package/cpp/whisper.cpp +1635 -928
  31. package/cpp/whisper.h +55 -10
  32. package/ios/RNWhisper.mm +7 -7
  33. package/ios/RNWhisperAudioUtils.h +0 -2
  34. package/ios/RNWhisperAudioUtils.m +0 -56
  35. package/ios/RNWhisperContext.h +3 -11
  36. package/ios/RNWhisperContext.mm +68 -137
  37. package/lib/commonjs/index.js.map +1 -1
  38. package/lib/commonjs/version.json +1 -1
  39. package/lib/module/index.js.map +1 -1
  40. package/lib/module/version.json +1 -1
  41. package/lib/typescript/index.d.ts +5 -0
  42. package/lib/typescript/index.d.ts.map +1 -1
  43. package/package.json +6 -5
  44. package/src/index.ts +5 -0
  45. package/src/version.json +1 -1
  46. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -4
  47. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +0 -8
  48. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  49. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +0 -19
package/cpp/ggml.h CHANGED
@@ -215,9 +215,9 @@
  #define WSP_GGML_QNT_VERSION_FACTOR 1000 // do not change this

  #define WSP_GGML_MAX_DIMS 4
- #define WSP_GGML_MAX_PARAMS 1024
+ #define WSP_GGML_MAX_PARAMS 2048
  #define WSP_GGML_MAX_CONTEXTS 64
- #define WSP_GGML_MAX_SRC 6
+ #define WSP_GGML_MAX_SRC 10
  #define WSP_GGML_MAX_NAME 64
  #define WSP_GGML_MAX_OP_PARAMS 64
  #define WSP_GGML_DEFAULT_N_THREADS 4
@@ -244,11 +244,10 @@
  #define WSP_GGML_ASSERT(x) \
  do { \
  if (!(x)) { \
- fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
- fflush(stderr); \
  fflush(stdout); \
+ fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
  wsp_ggml_print_backtrace(); \
- exit(1); \
+ abort(); \
  } \
  } while (0)

@@ -284,6 +283,20 @@
  const type prefix##3 = (pointer)->array[3]; \
  WSP_GGML_UNUSED(prefix##3);

+ #define WSP_GGML_TENSOR_UNARY_OP_LOCALS \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+ #define WSP_GGML_TENSOR_BINARY_OP_LOCALS \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
  #ifdef __cplusplus
  extern "C" {
  #endif
@@ -382,6 +395,7 @@ extern "C" {
  WSP_GGML_OP_GROUP_NORM,

  WSP_GGML_OP_MUL_MAT,
+ WSP_GGML_OP_MUL_MAT_ID,
  WSP_GGML_OP_OUT_PROD,

  WSP_GGML_OP_SCALE,
@@ -403,18 +417,15 @@ extern "C" {
  WSP_GGML_OP_ROPE_BACK,
  WSP_GGML_OP_ALIBI,
  WSP_GGML_OP_CLAMP,
- WSP_GGML_OP_CONV_1D,
- WSP_GGML_OP_CONV_1D_STAGE_0, // internal
- WSP_GGML_OP_CONV_1D_STAGE_1, // internal
  WSP_GGML_OP_CONV_TRANSPOSE_1D,
- WSP_GGML_OP_CONV_2D,
- WSP_GGML_OP_CONV_2D_STAGE_0, // internal
- WSP_GGML_OP_CONV_2D_STAGE_1, // internal
+ WSP_GGML_OP_IM2COL,
  WSP_GGML_OP_CONV_TRANSPOSE_2D,
  WSP_GGML_OP_POOL_1D,
  WSP_GGML_OP_POOL_2D,
-
  WSP_GGML_OP_UPSCALE, // nearest interpolate
+ WSP_GGML_OP_PAD,
+ WSP_GGML_OP_ARGSORT,
+ WSP_GGML_OP_LEAKY_RELU,

  WSP_GGML_OP_FLASH_ATTN,
  WSP_GGML_OP_FLASH_FF,
@@ -454,7 +465,8 @@ extern "C" {
  WSP_GGML_UNARY_OP_GELU,
  WSP_GGML_UNARY_OP_GELU_QUICK,
  WSP_GGML_UNARY_OP_SILU,
- WSP_GGML_UNARY_OP_LEAKY
+
+ WSP_GGML_UNARY_OP_COUNT,
  };

  enum wsp_ggml_object_type {
@@ -637,6 +649,9 @@
  WSP_GGML_API const char * wsp_ggml_op_name (enum wsp_ggml_op op);
  WSP_GGML_API const char * wsp_ggml_op_symbol(enum wsp_ggml_op op);

+ WSP_GGML_API const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op);
+ WSP_GGML_API const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t); // unary or op name
+
  WSP_GGML_API size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor);

  WSP_GGML_API bool wsp_ggml_is_quantized(enum wsp_ggml_type type);
@@ -779,6 +794,9 @@
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ // dst = a
+ // view(dst, nb1, nb2, nb3, offset) += b
+ // return dst
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_acc(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -943,15 +961,14 @@
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_leaky(
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_leaky_relu(
  struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * a);
+ struct wsp_ggml_tensor * a, float negative_slope, bool inplace);

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_relu_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

- // TODO: double-check this computation is correct
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);
@@ -1033,6 +1050,16 @@
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ // indirect matrix multiplication
+ // wsp_ggml_mul_mat_id(ctx, as, ids, id, b) ~= wsp_ggml_mul_mat(as[ids[id]], b)
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * const as[],
+ int n_as,
+ struct wsp_ggml_tensor * ids,
+ int id,
+ struct wsp_ggml_tensor * b);
+
  // A: m columns, n rows,
  // B: p columns, n rows,
  // result is m columns, p rows
@@ -1240,6 +1267,7 @@
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ // supports 3D: a->ne[2] == b->ne[1]
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1288,6 +1316,14 @@
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ // fused soft_max(a*scale + mask)
+ // mask is optional
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * mask,
+ float scale);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1376,8 +1412,13 @@
  int n_dims,
  int mode,
  int n_ctx,
+ int n_orig_ctx,
  float freq_base,
  float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow,
  float xpos_base,
  bool xpos_down);

@@ -1398,6 +1439,18 @@
  float min,
  float max);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * b,
+ int s0,
+ int s1,
+ int p0,
+ int p1,
+ int d0,
+ int d1,
+ bool is_2D);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1501,6 +1554,32 @@
  struct wsp_ggml_tensor * a,
  int scale_factor);

+ // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int p0,
+ int p1,
+ int p2,
+ int p3);
+
+ // sort rows
+ enum wsp_ggml_sort_order {
+ WSP_GGML_SORT_ASC,
+ WSP_GGML_SORT_DESC,
+ };
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_argsort(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ enum wsp_ggml_sort_order order);
+
+ // top k elements per row
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_top_k(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int k);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * q,
@@ -1562,7 +1641,6 @@
  int kh);

  // used in sam
-
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_rel_pos(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1737,7 +1815,7 @@
  WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
  WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom (struct wsp_ggml_context * ctx, size_t size, bool grads);
  WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_view (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int i0, int i1);
+ WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_graph_view (struct wsp_ggml_cgraph * cgraph, int i0, int i1);
  WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
  WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // zero grads
  WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph);
@@ -1955,20 +2033,20 @@
  // quantization
  //

- // TODO: these would probably get removed in favor of the more general wsp_ggml_quantize_chunk
- WSP_GGML_API size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
+ // TODO: these would probably get removed in favor of the more general wsp_ggml_wsp_quantize_chunk
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);

- WSP_GGML_API size_t wsp_ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);

- WSP_GGML_API size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);

  //
  // gguf
@@ -2033,6 +2111,7 @@
  WSP_GGML_API double wsp_gguf_get_val_f64 (const struct wsp_gguf_context * ctx, int key_id);
  WSP_GGML_API bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id);
  WSP_GGML_API const char * wsp_gguf_get_val_str (const struct wsp_gguf_context * ctx, int key_id);
+ WSP_GGML_API const void * wsp_gguf_get_val_data(const struct wsp_gguf_context * ctx, int key_id);
  WSP_GGML_API int wsp_gguf_get_arr_n (const struct wsp_gguf_context * ctx, int key_id);
  WSP_GGML_API const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id);
  WSP_GGML_API const char * wsp_gguf_get_arr_str (const struct wsp_gguf_context * ctx, int key_id, int i);
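
The ggml.h changes above come from the upstream ggml sync: new graph ops (WSP_GGML_OP_MUL_MAT_ID, WSP_GGML_OP_IM2COL, WSP_GGML_OP_PAD, WSP_GGML_OP_ARGSORT, WSP_GGML_OP_LEAKY_RELU), additional RoPE scaling parameters, and the fused wsp_ggml_soft_max_ext. The snippet below is not part of the diff; it is a minimal sketch of how one of the new ops could be exercised through the renamed API, assuming the usual ggml context/graph flow. The demo function name, tensor shapes, scale and thread count are made up for illustration.

    #include <math.h>
    #include "ggml.h"

    static void soft_max_ext_demo(void) {
        struct wsp_ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
        struct wsp_ggml_context * ctx = wsp_ggml_init(ip);

        // 8 logits per row, 4 rows; the mask is added to the scaled logits before the softmax
        struct wsp_ggml_tensor * a    = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 8, 4);
        struct wsp_ggml_tensor * mask = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 8, 4);
        wsp_ggml_set_f32(a, 1.0f);
        wsp_ggml_set_f32(mask, 0.0f);

        // fused soft_max(a*scale + mask); pass NULL as mask to skip it
        struct wsp_ggml_tensor * out = wsp_ggml_soft_max_ext(ctx, a, mask, 1.0f/sqrtf(64.0f));

        struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);
        wsp_ggml_build_forward_expand(gf, out);
        wsp_ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 4);

        wsp_ggml_free(ctx);
    }
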
package/cpp/rn-audioutils.cpp ADDED
@@ -0,0 +1,68 @@
+ #include "rn-audioutils.h"
+ #include "rn-whisper-log.h"
+
+ namespace rnaudioutils {
+
+ std::vector<uint8_t> concat_short_buffers(const std::vector<short*>& buffers, const std::vector<int>& slice_n_samples) {
+ std::vector<uint8_t> output_data;
+
+ for (size_t i = 0; i < buffers.size(); i++) {
+ int size = slice_n_samples[i]; // Number of shorts
+ short* slice = buffers[i];
+
+ // Copy each short as two bytes
+ for (int j = 0; j < size; j++) {
+ output_data.push_back(static_cast<uint8_t>(slice[j] & 0xFF)); // Lower byte
+ output_data.push_back(static_cast<uint8_t>((slice[j] >> 8) & 0xFF)); // Higher byte
+ }
+ }
+
+ return output_data;
+ }
+
+ std::vector<uint8_t> remove_trailing_zeros(const std::vector<uint8_t>& audio_data) {
+ auto last = std::find_if(audio_data.rbegin(), audio_data.rend(), [](uint8_t byte) { return byte != 0; });
+ return std::vector<uint8_t>(audio_data.begin(), last.base());
+ }
+
+ void save_wav_file(const std::vector<uint8_t>& raw, const std::string& file) {
+ std::vector<uint8_t> data = remove_trailing_zeros(raw);
+
+ std::ofstream output(file, std::ios::binary);
+
+ if (!output.is_open()) {
+ RNWHISPER_LOG_ERROR("Failed to open file for writing: %s\n", file.c_str());
+ return;
+ }
+
+ // WAVE header
+ output.write("RIFF", 4);
+ int32_t chunk_size = 36 + static_cast<int32_t>(data.size());
+ output.write(reinterpret_cast<char*>(&chunk_size), sizeof(chunk_size));
+ output.write("WAVE", 4);
+ output.write("fmt ", 4);
+ int32_t sub_chunk_size = 16;
+ output.write(reinterpret_cast<char*>(&sub_chunk_size), sizeof(sub_chunk_size));
+ short audio_format = 1;
+ output.write(reinterpret_cast<char*>(&audio_format), sizeof(audio_format));
+ short num_channels = 1;
+ output.write(reinterpret_cast<char*>(&num_channels), sizeof(num_channels));
+ int32_t sample_rate = WHISPER_SAMPLE_RATE;
+ output.write(reinterpret_cast<char*>(&sample_rate), sizeof(sample_rate));
+ int32_t byte_rate = WHISPER_SAMPLE_RATE * 2;
+ output.write(reinterpret_cast<char*>(&byte_rate), sizeof(byte_rate));
+ short block_align = 2;
+ output.write(reinterpret_cast<char*>(&block_align), sizeof(block_align));
+ short bits_per_sample = 16;
+ output.write(reinterpret_cast<char*>(&bits_per_sample), sizeof(bits_per_sample));
+ output.write("data", 4);
+ int32_t sub_chunk2_size = static_cast<int32_t>(data.size());
+ output.write(reinterpret_cast<char*>(&sub_chunk2_size), sizeof(sub_chunk2_size));
+ output.write(reinterpret_cast<const char*>(data.data()), data.size());
+
+ output.close();
+
+ RNWHISPER_LOG_INFO("Saved audio file: %s\n", file.c_str());
+ }
+
+ } // namespace rnaudioutils
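
rn-audioutils.cpp and rn-audioutils.h appear to replace the platform-side audio helpers removed in this release (AudioUtils.java, RNWhisperAudioUtils.m). As a rough usage sketch, not taken from the package: two recorded 16-bit PCM slices are flattened to little-endian bytes and written out as a 16 kHz mono WAV. The slice contents and output path are invented.

    #include <cstdint>
    #include <vector>
    #include "rn-audioutils.h"

    static void dump_recording_demo() {
        std::vector<short> slice0(16000, 0);      // 1 s of silence at WHISPER_SAMPLE_RATE (16 kHz)
        std::vector<short> slice1(8000, 1000);    // 0.5 s of a constant sample value

        std::vector<short*> buffers   = { slice0.data(), slice1.data() };
        std::vector<int>    n_samples = { (int) slice0.size(), (int) slice1.size() };

        // Serialize each short as two little-endian bytes, then prepend a WAV header and save
        std::vector<uint8_t> bytes = rnaudioutils::concat_short_buffers(buffers, n_samples);
        rnaudioutils::save_wav_file(bytes, "/tmp/rnwhisper-demo.wav");
    }
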
package/cpp/rn-audioutils.h ADDED
@@ -0,0 +1,14 @@
+ #include <iostream>
+ #include <fstream>
+ #include <vector>
+ #include <cstdint>
+ #include <cstring>
+ #include <algorithm>
+ #include "whisper.h"
+
+ namespace rnaudioutils {
+
+ std::vector<uint8_t> concat_short_buffers(const std::vector<short*>& buffers, const std::vector<int>& slice_n_samples);
+ void save_wav_file(const std::vector<uint8_t>& raw, const std::string& file);
+
+ } // namespace rnaudioutils
package/cpp/rn-whisper-log.h ADDED
@@ -0,0 +1,11 @@
+ #if defined(__ANDROID__) && defined(RNWHISPER_ANDROID_ENABLE_LOGGING)
+ #include <android/log.h>
+ #define RNWHISPER_ANDROID_TAG "RNWHISPER_LOG_ANDROID"
+ #define RNWHISPER_LOG_INFO(...) __android_log_print(ANDROID_LOG_INFO , RNWHISPER_ANDROID_TAG, __VA_ARGS__)
+ #define RNWHISPER_LOG_WARN(...) __android_log_print(ANDROID_LOG_WARN , RNWHISPER_ANDROID_TAG, __VA_ARGS__)
+ #define RNWHISPER_LOG_ERROR(...) __android_log_print(ANDROID_LOG_ERROR, RNWHISPER_ANDROID_TAG, __VA_ARGS__)
+ #else
+ #define RNWHISPER_LOG_INFO(...) fprintf(stderr, __VA_ARGS__)
+ #define RNWHISPER_LOG_WARN(...) fprintf(stderr, __VA_ARGS__)
+ #define RNWHISPER_LOG_ERROR(...) fprintf(stderr, __VA_ARGS__)
+ #endif // __ANDROID__
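
rn-whisper-log.h gives the native code one set of printf-style logging macros. A small illustrative use, not from the package: on Android the output goes to logcat only when RNWHISPER_ANDROID_ENABLE_LOGGING is defined at compile time (for example via a -D compiler flag); otherwise the macros fall back to stderr. The function and values below are invented.

    #include "rn-whisper-log.h"

    static void log_transcribe_result(int job_id, int n_segments, double ms) {
        if (n_segments == 0) {
            RNWHISPER_LOG_WARN("transcribe: job %d produced no segments\n", job_id);
            return;
        }
        RNWHISPER_LOG_INFO("transcribe: job %d finished %d segments in %.1f ms\n", job_id, n_segments, ms);
    }
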
package/cpp/rn-whisper.cpp CHANGED
@@ -2,41 +2,11 @@
  #include <string>
  #include <vector>
  #include <unordered_map>
- #include "whisper.h"
+ #include "rn-whisper.h"

- extern "C" {
+ #define DEFAULT_MAX_AUDIO_SEC 30;

- std::unordered_map<int, bool> abort_map;
-
- bool* rn_whisper_assign_abort_map(int job_id) {
- abort_map[job_id] = false;
- return &abort_map[job_id];
- }
-
- void rn_whisper_remove_abort_map(int job_id) {
- if (abort_map.find(job_id) != abort_map.end()) {
- abort_map.erase(job_id);
- }
- }
-
- void rn_whisper_abort_transcribe(int job_id) {
- if (abort_map.find(job_id) != abort_map.end()) {
- abort_map[job_id] = true;
- }
- }
-
- bool rn_whisper_transcribe_is_aborted(int job_id) {
- if (abort_map.find(job_id) != abort_map.end()) {
- return abort_map[job_id];
- }
- return false;
- }
-
- void rn_whisper_abort_all_transcribe() {
- for (auto it = abort_map.begin(); it != abort_map.end(); ++it) {
- it->second = true;
- }
- }
+ namespace rnwhisper {

  void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
  const float rc = 1.0f / (2.0f * M_PI * cutoff);
@@ -51,42 +21,156 @@ void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate
  }
  }

- bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
- const int n_samples = pcmf32.size();
- const int n_samples_last = (sample_rate * last_ms) / 1000;
+ bool vad_simple_impl(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
+ const int n_samples = pcmf32.size();
+ const int n_samples_last = (sample_rate * last_ms) / 1000;

- if (n_samples_last >= n_samples) {
- // not enough samples - assume no speech
- return false;
- }
+ if (n_samples_last >= n_samples) {
+ // not enough samples - assume no speech
+ return false;
+ }

- if (freq_thold > 0.0f) {
- high_pass_filter(pcmf32, freq_thold, sample_rate);
- }
+ if (freq_thold > 0.0f) {
+ high_pass_filter(pcmf32, freq_thold, sample_rate);
+ }
+
+ float energy_all = 0.0f;
+ float energy_last = 0.0f;
+
+ for (int i = 0; i < n_samples; i++) {
+ energy_all += fabsf(pcmf32[i]);

- float energy_all = 0.0f;
- float energy_last = 0.0f;
+ if (i >= n_samples - n_samples_last) {
+ energy_last += fabsf(pcmf32[i]);
+ }
+ }
+
+ energy_all /= n_samples;
+ energy_last /= n_samples_last;

- for (int i = 0; i < n_samples; i++) {
- energy_all += fabsf(pcmf32[i]);
+ if (verbose) {
+ RNWHISPER_LOG_INFO("%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
+ }

- if (i >= n_samples - n_samples_last) {
- energy_last += fabsf(pcmf32[i]);
+ if (energy_last > vad_thold*energy_all) {
+ return false;
  }
- }

- energy_all /= n_samples;
- energy_last /= n_samples_last;
+ return true;
+ }

- if (verbose) {
- fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
- }
+ void job::set_realtime_params(
+ vad_params params,
+ int sec,
+ int slice_sec,
+ float min_sec,
+ const char* output_path
+ ) {
+ vad = params;
+ if (vad.vad_ms < 2000) vad.vad_ms = 2000;
+ audio_sec = sec > 0 ? sec : DEFAULT_MAX_AUDIO_SEC;
+ audio_slice_sec = slice_sec > 0 && slice_sec < audio_sec ? slice_sec : audio_sec;
+ audio_min_sec = min_sec >= 0.5 && min_sec <= audio_slice_sec ? min_sec : 1.0f;
+ audio_output_path = output_path;
+ }

- if (energy_last > vad_thold*energy_all) {
+ bool job::vad_simple(int slice_index, int n_samples, int n) {
+ if (!vad.use_vad) return true;
+
+ short* pcm = pcm_slices[slice_index];
+ int sample_size = (int) (WHISPER_SAMPLE_RATE * vad.vad_ms / 1000);
+ if (n_samples + n > sample_size) {
+ int start = n_samples + n - sample_size;
+ std::vector<float> pcmf32(sample_size);
+ for (int i = 0; i < sample_size; i++) {
+ pcmf32[i] = (float)pcm[i + start] / 32768.0f;
+ }
+ return vad_simple_impl(pcmf32, WHISPER_SAMPLE_RATE, vad.last_ms, vad.vad_thold, vad.freq_thold, vad.verbose);
+ }
  return false;
- }
+ }
+
+ void job::put_pcm_data(short* data, int slice_index, int n_samples, int n) {
+ if (pcm_slices.size() == slice_index) {
+ int n_slices = (int) (WHISPER_SAMPLE_RATE * audio_slice_sec);
+ pcm_slices.push_back(new short[n_slices]);
+ }
+ short* pcm = pcm_slices[slice_index];
+ for (int i = 0; i < n; i++) {
+ pcm[i + n_samples] = data[i];
+ }
+ }
+
+ float* job::pcm_slice_to_f32(int slice_index, int size) {
+ if (pcm_slices.size() > slice_index) {
+ float* pcmf32 = new float[size];
+ for (int i = 0; i < size; i++) {
+ pcmf32[i] = (float)pcm_slices[slice_index][i] / 32768.0f;
+ }
+ return pcmf32;
+ }
+ return nullptr;
+ }
+
+ bool job::is_aborted() {
+ return aborted;
+ }
+
+ void job::abort() {
+ aborted = true;
+ }
+
+ job::~job() {
+ RNWHISPER_LOG_INFO("rnwhisper::job::%s: job_id: %d\n", __func__, job_id);
+
+ for (size_t i = 0; i < pcm_slices.size(); i++) {
+ delete[] pcm_slices[i];
+ }
+ pcm_slices.clear();
+ }
+
+ std::unordered_map<int, job*> job_map;
+
+ void job_abort_all() {
+ for (auto it = job_map.begin(); it != job_map.end(); ++it) {
+ it->second->abort();
+ }
+ }
+
+ job* job_new(int job_id, struct whisper_full_params params) {
+ job* ctx = new job();
+ ctx->job_id = job_id;
+ ctx->params = params;
+
+ job_map[job_id] = ctx;
+
+ // Abort handler
+ params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
+ job *j = (job*)user_data;
+ return !j->is_aborted();
+ };
+ params.encoder_begin_callback_user_data = job_map[job_id];
+ params.abort_callback = [](void * user_data) {
+ job *j = (job*)user_data;
+ return j->is_aborted();
+ };
+ params.abort_callback_user_data = job_map[job_id];
+
+ return job_map[job_id];
+ }
+
+ job* job_get(int job_id) {
+ if (job_map.find(job_id) != job_map.end()) {
+ return job_map[job_id];
+ }
+ return nullptr;
+ }

- return true;
+ void job_remove(int job_id) {
+ if (job_map.find(job_id) != job_map.end()) {
+ delete job_map[job_id];
+ }
+ job_map.erase(job_id);
  }

- }
+ }
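
rn-whisper.cpp replaces the old abort_map bookkeeping with a small job registry in the rnwhisper namespace. A hedged sketch of the lifecycle, not code from the package: a job is created per transcription, can be aborted by id from elsewhere, and is removed when done. The job id, helper name and call site are invented; error handling is omitted.

    #include "rn-whisper.h"
    #include "whisper.h"

    static int run_job_demo(struct whisper_context * wctx, const float * pcmf32, int n_samples) {
        whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

        // job_new stores the job in the internal map and sets up encoder_begin/abort
        // callbacks that poll the job's aborted flag.
        rnwhisper::job * j = rnwhisper::job_new(/* job_id */ 100, params);

        // Elsewhere, rnwhisper::job_get(100)->abort() or rnwhisper::job_abort_all()
        // flips that flag to request cancellation.
        int ret = whisper_full(wctx, j->params, pcmf32, n_samples);

        rnwhisper::job_remove(100);   // deletes the job and frees any PCM slices it held
        return ret;
    }
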
package/cpp/rn-whisper.h CHANGED
@@ -1,17 +1,50 @@
+ #ifndef RNWHISPER_H
+ #define RNWHISPER_H

- #ifdef __cplusplus
  #include <string>
- #include <whisper.h>
- extern "C" {
- #endif
-
- bool* rn_whisper_assign_abort_map(int job_id);
- void rn_whisper_remove_abort_map(int job_id);
- void rn_whisper_abort_transcribe(int job_id);
- bool rn_whisper_transcribe_is_aborted(int job_id);
- void rn_whisper_abort_all_transcribe();
- bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose);
-
- #ifdef __cplusplus
- }
- #endif
+ #include <vector>
+ #include "whisper.h"
+ #include "rn-whisper-log.h"
+ #include "rn-audioutils.h"
+
+ namespace rnwhisper {
+
+ struct vad_params {
+ bool use_vad = false;
+ float vad_thold = 0.6f;
+ float freq_thold = 100.0f;
+ int vad_ms = 2000;
+ int last_ms = 1000;
+ bool verbose = false;
+ };
+
+ struct job {
+ int job_id;
+ bool aborted = false;
+ whisper_full_params params;
+
+ ~job();
+ bool is_aborted();
+ void abort();
+
+ // Realtime transcription only:
+ vad_params vad;
+ int audio_sec = 0;
+ int audio_slice_sec = 0;
+ float audio_min_sec = 0;
+ const char* audio_output_path = nullptr;
+ std::vector<short *> pcm_slices;
+ void set_realtime_params(vad_params vad, int sec, int slice_sec, float min_sec, const char* output_path);
+ bool vad_simple(int slice_index, int n_samples, int n);
+ void put_pcm_data(short* pcm, int slice_index, int n_samples, int n);
+ float* pcm_slice_to_f32(int slice_index, int size);
+ };
+
+ void job_abort_all();
+ job* job_new(int job_id, struct whisper_full_params params);
+ void job_remove(int job_id);
+ job* job_get(int job_id);
+
+ } // namespace rnwhisper
+
+ #endif // RNWHISPER_H
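
The realtime fields on rnwhisper::job look like they pull the slice buffering and simple energy VAD that the platform layers handled before into shared C++ code. A rough sketch of how the pieces fit together, not code from the package; the slice index, sample counts and limits below are invented.

    #include "rn-whisper.h"

    // Feed one chunk of 16-bit PCM into slice 0 of a job and ask whether the
    // recent window has gone quiet enough to transcribe the slice so far.
    static bool feed_and_check_vad(rnwhisper::job * j, short * chunk, int n_stored, int n_new) {
        rnwhisper::vad_params vp;
        vp.use_vad = true;   // keep the defaults: vad_thold 0.6, freq_thold 100 Hz, 2000 ms window

        // 30 s max audio, 25 s per slice, at least 1 s before transcribing, no WAV output
        j->set_realtime_params(vp, 30, 25, 1.0f, nullptr);

        j->put_pcm_data(chunk, /* slice_index */ 0, n_stored, n_new);

        // Returns false until enough audio is buffered for the VAD window
        return j->vad_simple(0, n_stored, n_new);
    }
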