whisper.rn 0.4.0-rc.4 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +5 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/WhisperContext.java +51 -133
  6. package/android/src/main/jni-utils.h +76 -0
  7. package/android/src/main/jni.cpp +187 -112
  8. package/cpp/README.md +1 -1
  9. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  10. package/cpp/coreml/whisper-encoder.h +4 -0
  11. package/cpp/coreml/whisper-encoder.mm +4 -2
  12. package/cpp/ggml-alloc.c +55 -19
  13. package/cpp/ggml-alloc.h +7 -0
  14. package/cpp/ggml-backend-impl.h +46 -21
  15. package/cpp/ggml-backend.c +563 -156
  16. package/cpp/ggml-backend.h +62 -17
  17. package/cpp/ggml-impl.h +1 -1
  18. package/cpp/ggml-metal-whisper.metal +1010 -253
  19. package/cpp/ggml-metal.h +7 -1
  20. package/cpp/ggml-metal.m +618 -187
  21. package/cpp/ggml-quants.c +64 -59
  22. package/cpp/ggml-quants.h +40 -40
  23. package/cpp/ggml.c +751 -1466
  24. package/cpp/ggml.h +90 -25
  25. package/cpp/rn-audioutils.cpp +68 -0
  26. package/cpp/rn-audioutils.h +14 -0
  27. package/cpp/rn-whisper-log.h +11 -0
  28. package/cpp/rn-whisper.cpp +141 -59
  29. package/cpp/rn-whisper.h +47 -15
  30. package/cpp/whisper.cpp +1635 -928
  31. package/cpp/whisper.h +55 -10
  32. package/ios/RNWhisper.mm +7 -7
  33. package/ios/RNWhisperAudioUtils.h +0 -2
  34. package/ios/RNWhisperAudioUtils.m +0 -56
  35. package/ios/RNWhisperContext.h +3 -11
  36. package/ios/RNWhisperContext.mm +62 -134
  37. package/lib/commonjs/version.json +1 -1
  38. package/lib/module/version.json +1 -1
  39. package/package.json +6 -5
  40. package/src/version.json +1 -1
package/cpp/ggml.h CHANGED
@@ -244,11 +244,10 @@
244
244
  #define WSP_GGML_ASSERT(x) \
245
245
  do { \
246
246
  if (!(x)) { \
247
- fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
248
- fflush(stderr); \
249
247
  fflush(stdout); \
248
+ fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
250
249
  wsp_ggml_print_backtrace(); \
251
- exit(1); \
250
+ abort(); \
252
251
  } \
253
252
  } while (0)
254
253
 
@@ -284,6 +283,20 @@
284
283
  const type prefix##3 = (pointer)->array[3]; \
285
284
  WSP_GGML_UNUSED(prefix##3);
286
285
 
286
+ #define WSP_GGML_TENSOR_UNARY_OP_LOCALS \
287
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
288
+ WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
289
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
290
+ WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
291
+
292
+ #define WSP_GGML_TENSOR_BINARY_OP_LOCALS \
293
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
294
+ WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
295
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
296
+ WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
297
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
298
+ WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
299
+
287
300
  #ifdef __cplusplus
288
301
  extern "C" {
289
302
  #endif
@@ -382,6 +395,7 @@ extern "C" {
382
395
  WSP_GGML_OP_GROUP_NORM,
383
396
 
384
397
  WSP_GGML_OP_MUL_MAT,
398
+ WSP_GGML_OP_MUL_MAT_ID,
385
399
  WSP_GGML_OP_OUT_PROD,
386
400
 
387
401
  WSP_GGML_OP_SCALE,
@@ -403,18 +417,13 @@ extern "C" {
403
417
  WSP_GGML_OP_ROPE_BACK,
404
418
  WSP_GGML_OP_ALIBI,
405
419
  WSP_GGML_OP_CLAMP,
406
- WSP_GGML_OP_CONV_1D,
407
- WSP_GGML_OP_CONV_1D_STAGE_0, // internal
408
- WSP_GGML_OP_CONV_1D_STAGE_1, // internal
409
420
  WSP_GGML_OP_CONV_TRANSPOSE_1D,
410
- WSP_GGML_OP_CONV_2D,
411
- WSP_GGML_OP_CONV_2D_STAGE_0, // internal
412
- WSP_GGML_OP_CONV_2D_STAGE_1, // internal
421
+ WSP_GGML_OP_IM2COL,
413
422
  WSP_GGML_OP_CONV_TRANSPOSE_2D,
414
423
  WSP_GGML_OP_POOL_1D,
415
424
  WSP_GGML_OP_POOL_2D,
416
-
417
425
  WSP_GGML_OP_UPSCALE, // nearest interpolate
426
+ WSP_GGML_OP_ARGSORT,
418
427
 
419
428
  WSP_GGML_OP_FLASH_ATTN,
420
429
  WSP_GGML_OP_FLASH_FF,
@@ -454,7 +463,9 @@ extern "C" {
454
463
  WSP_GGML_UNARY_OP_GELU,
455
464
  WSP_GGML_UNARY_OP_GELU_QUICK,
456
465
  WSP_GGML_UNARY_OP_SILU,
457
- WSP_GGML_UNARY_OP_LEAKY
466
+ WSP_GGML_UNARY_OP_LEAKY,
467
+
468
+ WSP_GGML_UNARY_OP_COUNT,
458
469
  };
459
470
 
460
471
  enum wsp_ggml_object_type {
@@ -637,6 +648,9 @@ extern "C" {
637
648
  WSP_GGML_API const char * wsp_ggml_op_name (enum wsp_ggml_op op);
638
649
  WSP_GGML_API const char * wsp_ggml_op_symbol(enum wsp_ggml_op op);
639
650
 
651
+ WSP_GGML_API const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op);
652
+ WSP_GGML_API const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t); // unary or op name
653
+
640
654
  WSP_GGML_API size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor);
641
655
 
642
656
  WSP_GGML_API bool wsp_ggml_is_quantized(enum wsp_ggml_type type);
@@ -1033,6 +1047,15 @@ extern "C" {
1033
1047
  struct wsp_ggml_tensor * a,
1034
1048
  struct wsp_ggml_tensor * b);
1035
1049
 
1050
+ // indirect matrix multiplication
1051
+ // wsp_ggml_mul_mat_id(ctx, as, ids, id, b) ~= wsp_ggml_mul_mat(as[ids[id]], b)
1052
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
1053
+ struct wsp_ggml_context * ctx,
1054
+ struct wsp_ggml_tensor * as[],
1055
+ struct wsp_ggml_tensor * ids,
1056
+ int id,
1057
+ struct wsp_ggml_tensor * b);
1058
+
1036
1059
  // A: m columns, n rows,
1037
1060
  // B: p columns, n rows,
1038
1061
  // result is m columns, p rows
@@ -1288,6 +1311,14 @@ extern "C" {
1288
1311
  struct wsp_ggml_context * ctx,
1289
1312
  struct wsp_ggml_tensor * a);
1290
1313
 
1314
+ // fused soft_max(a*scale + mask)
1315
+ // mask is optional
1316
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
1317
+ struct wsp_ggml_context * ctx,
1318
+ struct wsp_ggml_tensor * a,
1319
+ struct wsp_ggml_tensor * mask,
1320
+ float scale);
1321
+
1291
1322
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back(
1292
1323
  struct wsp_ggml_context * ctx,
1293
1324
  struct wsp_ggml_tensor * a,
@@ -1376,8 +1407,13 @@ extern "C" {
1376
1407
  int n_dims,
1377
1408
  int mode,
1378
1409
  int n_ctx,
1410
+ int n_orig_ctx,
1379
1411
  float freq_base,
1380
1412
  float freq_scale,
1413
+ float ext_factor,
1414
+ float attn_factor,
1415
+ float beta_fast,
1416
+ float beta_slow,
1381
1417
  float xpos_base,
1382
1418
  bool xpos_down);
1383
1419
 
@@ -1398,6 +1434,18 @@ extern "C" {
1398
1434
  float min,
1399
1435
  float max);
1400
1436
 
1437
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col(
1438
+ struct wsp_ggml_context * ctx,
1439
+ struct wsp_ggml_tensor * a,
1440
+ struct wsp_ggml_tensor * b,
1441
+ int s0,
1442
+ int s1,
1443
+ int p0,
1444
+ int p1,
1445
+ int d0,
1446
+ int d1,
1447
+ bool is_2D);
1448
+
1401
1449
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
1402
1450
  struct wsp_ggml_context * ctx,
1403
1451
  struct wsp_ggml_tensor * a,
@@ -1501,6 +1549,23 @@ extern "C" {
1501
1549
  struct wsp_ggml_tensor * a,
1502
1550
  int scale_factor);
1503
1551
 
1552
+ // sort rows
1553
+ enum wsp_ggml_sort_order {
1554
+ WSP_GGML_SORT_ASC,
1555
+ WSP_GGML_SORT_DESC,
1556
+ };
1557
+
1558
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_argsort(
1559
+ struct wsp_ggml_context * ctx,
1560
+ struct wsp_ggml_tensor * a,
1561
+ enum wsp_ggml_sort_order order);
1562
+
1563
+ // top k elements per row
1564
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_top_k(
1565
+ struct wsp_ggml_context * ctx,
1566
+ struct wsp_ggml_tensor * a,
1567
+ int k);
1568
+
1504
1569
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn(
1505
1570
  struct wsp_ggml_context * ctx,
1506
1571
  struct wsp_ggml_tensor * q,
@@ -1562,7 +1627,6 @@ extern "C" {
1562
1627
  int kh);
1563
1628
 
1564
1629
  // used in sam
1565
-
1566
1630
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_rel_pos(
1567
1631
  struct wsp_ggml_context * ctx,
1568
1632
  struct wsp_ggml_tensor * a,
@@ -1737,7 +1801,7 @@ extern "C" {
1737
1801
  WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
1738
1802
  WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom (struct wsp_ggml_context * ctx, size_t size, bool grads);
1739
1803
  WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
1740
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_view (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int i0, int i1);
1804
+ WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_graph_view (struct wsp_ggml_cgraph * cgraph, int i0, int i1);
1741
1805
  WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
1742
1806
  WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // zero grads
1743
1807
  WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph);
@@ -1955,20 +2019,20 @@ extern "C" {
1955
2019
  // quantization
1956
2020
  //
1957
2021
 
1958
- // TODO: these would probably get removed in favor of the more general wsp_ggml_quantize_chunk
1959
- WSP_GGML_API size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
1960
- WSP_GGML_API size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
1961
- WSP_GGML_API size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
1962
- WSP_GGML_API size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
1963
- WSP_GGML_API size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
2022
+ // TODO: these would probably get removed in favor of the more general wsp_ggml_wsp_quantize_chunk
2023
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
2024
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
2025
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
2026
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
2027
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
1964
2028
 
1965
- WSP_GGML_API size_t wsp_ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
1966
- WSP_GGML_API size_t wsp_ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
1967
- WSP_GGML_API size_t wsp_ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
1968
- WSP_GGML_API size_t wsp_ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
1969
- WSP_GGML_API size_t wsp_ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
2029
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
2030
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
2031
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
2032
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
2033
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
1970
2034
 
1971
- WSP_GGML_API size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
2035
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
1972
2036
 
1973
2037
  //
1974
2038
  // gguf
@@ -2033,6 +2097,7 @@ extern "C" {
2033
2097
  WSP_GGML_API double wsp_gguf_get_val_f64 (const struct wsp_gguf_context * ctx, int key_id);
2034
2098
  WSP_GGML_API bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id);
2035
2099
  WSP_GGML_API const char * wsp_gguf_get_val_str (const struct wsp_gguf_context * ctx, int key_id);
2100
+ WSP_GGML_API const void * wsp_gguf_get_val_data(const struct wsp_gguf_context * ctx, int key_id);
2036
2101
  WSP_GGML_API int wsp_gguf_get_arr_n (const struct wsp_gguf_context * ctx, int key_id);
2037
2102
  WSP_GGML_API const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id);
2038
2103
  WSP_GGML_API const char * wsp_gguf_get_arr_str (const struct wsp_gguf_context * ctx, int key_id, int i);
@@ -0,0 +1,68 @@
1
+ #include "rn-audioutils.h"
2
+ #include "rn-whisper-log.h"
3
+
4
+ namespace rnaudioutils {
5
+
6
+ std::vector<uint8_t> concat_short_buffers(const std::vector<short*>& buffers, const std::vector<int>& slice_n_samples) {
7
+ std::vector<uint8_t> output_data;
8
+
9
+ for (size_t i = 0; i < buffers.size(); i++) {
10
+ int size = slice_n_samples[i]; // Number of shorts
11
+ short* slice = buffers[i];
12
+
13
+ // Copy each short as two bytes
14
+ for (int j = 0; j < size; j++) {
15
+ output_data.push_back(static_cast<uint8_t>(slice[j] & 0xFF)); // Lower byte
16
+ output_data.push_back(static_cast<uint8_t>((slice[j] >> 8) & 0xFF)); // Higher byte
17
+ }
18
+ }
19
+
20
+ return output_data;
21
+ }
22
+
23
+ std::vector<uint8_t> remove_trailing_zeros(const std::vector<uint8_t>& audio_data) {
24
+ auto last = std::find_if(audio_data.rbegin(), audio_data.rend(), [](uint8_t byte) { return byte != 0; });
25
+ return std::vector<uint8_t>(audio_data.begin(), last.base());
26
+ }
27
+
28
+ void save_wav_file(const std::vector<uint8_t>& raw, const std::string& file) {
29
+ std::vector<uint8_t> data = remove_trailing_zeros(raw);
30
+
31
+ std::ofstream output(file, std::ios::binary);
32
+
33
+ if (!output.is_open()) {
34
+ RNWHISPER_LOG_ERROR("Failed to open file for writing: %s\n", file.c_str());
35
+ return;
36
+ }
37
+
38
+ // WAVE header
39
+ output.write("RIFF", 4);
40
+ int32_t chunk_size = 36 + static_cast<int32_t>(data.size());
41
+ output.write(reinterpret_cast<char*>(&chunk_size), sizeof(chunk_size));
42
+ output.write("WAVE", 4);
43
+ output.write("fmt ", 4);
44
+ int32_t sub_chunk_size = 16;
45
+ output.write(reinterpret_cast<char*>(&sub_chunk_size), sizeof(sub_chunk_size));
46
+ short audio_format = 1;
47
+ output.write(reinterpret_cast<char*>(&audio_format), sizeof(audio_format));
48
+ short num_channels = 1;
49
+ output.write(reinterpret_cast<char*>(&num_channels), sizeof(num_channels));
50
+ int32_t sample_rate = WHISPER_SAMPLE_RATE;
51
+ output.write(reinterpret_cast<char*>(&sample_rate), sizeof(sample_rate));
52
+ int32_t byte_rate = WHISPER_SAMPLE_RATE * 2;
53
+ output.write(reinterpret_cast<char*>(&byte_rate), sizeof(byte_rate));
54
+ short block_align = 2;
55
+ output.write(reinterpret_cast<char*>(&block_align), sizeof(block_align));
56
+ short bits_per_sample = 16;
57
+ output.write(reinterpret_cast<char*>(&bits_per_sample), sizeof(bits_per_sample));
58
+ output.write("data", 4);
59
+ int32_t sub_chunk2_size = static_cast<int32_t>(data.size());
60
+ output.write(reinterpret_cast<char*>(&sub_chunk2_size), sizeof(sub_chunk2_size));
61
+ output.write(reinterpret_cast<const char*>(data.data()), data.size());
62
+
63
+ output.close();
64
+
65
+ RNWHISPER_LOG_INFO("Saved audio file: %s\n", file.c_str());
66
+ }
67
+
68
+ } // namespace rnaudioutils
@@ -0,0 +1,14 @@
1
+ #include <iostream>
2
+ #include <fstream>
3
+ #include <vector>
4
+ #include <cstdint>
5
+ #include <cstring>
6
+ #include <algorithm>
7
+ #include "whisper.h"
8
+
9
+ namespace rnaudioutils {
10
+
11
+ std::vector<uint8_t> concat_short_buffers(const std::vector<short*>& buffers, const std::vector<int>& slice_n_samples);
12
+ void save_wav_file(const std::vector<uint8_t>& raw, const std::string& file);
13
+
14
+ } // namespace rnaudioutils
@@ -0,0 +1,11 @@
1
#ifndef RNWHISPER_LOG_H
#define RNWHISPER_LOG_H

// Logging macros used by the rnwhisper sources. Logs go to logcat on Android builds with
// RNWHISPER_ANDROID_ENABLE_LOGGING defined; every other build writes to stderr.
#if defined(__ANDROID__) && defined(RNWHISPER_ANDROID_ENABLE_LOGGING)
#include <android/log.h>
#define RNWHISPER_ANDROID_TAG "RNWHISPER_LOG_ANDROID"
#define RNWHISPER_LOG_INFO(...)  __android_log_print(ANDROID_LOG_INFO , RNWHISPER_ANDROID_TAG, __VA_ARGS__)
#define RNWHISPER_LOG_WARN(...)  __android_log_print(ANDROID_LOG_WARN , RNWHISPER_ANDROID_TAG, __VA_ARGS__)
#define RNWHISPER_LOG_ERROR(...) __android_log_print(ANDROID_LOG_ERROR, RNWHISPER_ANDROID_TAG, __VA_ARGS__)
#else
// fprintf/stderr come from <cstdio>. Include it here rather than relying on the includer.
#include <cstdio>
#define RNWHISPER_LOG_INFO(...)  fprintf(stderr, __VA_ARGS__)
#define RNWHISPER_LOG_WARN(...)  fprintf(stderr, __VA_ARGS__)
#define RNWHISPER_LOG_ERROR(...) fprintf(stderr, __VA_ARGS__)
#endif // defined(__ANDROID__) && defined(RNWHISPER_ANDROID_ENABLE_LOGGING)

#endif // RNWHISPER_LOG_H
@@ -2,41 +2,11 @@
2
2
  #include <string>
3
3
  #include <vector>
4
4
  #include <unordered_map>
5
- #include "whisper.h"
5
+ #include "rn-whisper.h"
6
6
 
7
- extern "C" {
7
// Default realtime recording length in seconds, used when the caller passes sec <= 0.
// No trailing ';' so the macro can be used inside any expression.
#define DEFAULT_MAX_AUDIO_SEC 30
8
8
 
9
- std::unordered_map<int, bool> abort_map;
10
-
11
- bool* rn_whisper_assign_abort_map(int job_id) {
12
- abort_map[job_id] = false;
13
- return &abort_map[job_id];
14
- }
15
-
16
- void rn_whisper_remove_abort_map(int job_id) {
17
- if (abort_map.find(job_id) != abort_map.end()) {
18
- abort_map.erase(job_id);
19
- }
20
- }
21
-
22
- void rn_whisper_abort_transcribe(int job_id) {
23
- if (abort_map.find(job_id) != abort_map.end()) {
24
- abort_map[job_id] = true;
25
- }
26
- }
27
-
28
- bool rn_whisper_transcribe_is_aborted(int job_id) {
29
- if (abort_map.find(job_id) != abort_map.end()) {
30
- return abort_map[job_id];
31
- }
32
- return false;
33
- }
34
-
35
- void rn_whisper_abort_all_transcribe() {
36
- for (auto it = abort_map.begin(); it != abort_map.end(); ++it) {
37
- it->second = true;
38
- }
39
- }
9
+ namespace rnwhisper {
40
10
 
41
11
  void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
42
12
  const float rc = 1.0f / (2.0f * M_PI * cutoff);
@@ -51,42 +21,154 @@ void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate
51
21
  }
52
22
  }
53
23
 
54
- bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
55
- const int n_samples = pcmf32.size();
56
- const int n_samples_last = (sample_rate * last_ms) / 1000;
24
+ bool vad_simple_impl(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
25
+ const int n_samples = pcmf32.size();
26
+ const int n_samples_last = (sample_rate * last_ms) / 1000;
57
27
 
58
- if (n_samples_last >= n_samples) {
59
- // not enough samples - assume no speech
60
- return false;
61
- }
28
+ if (n_samples_last >= n_samples) {
29
+ // not enough samples - assume no speech
30
+ return false;
31
+ }
62
32
 
63
- if (freq_thold > 0.0f) {
64
- high_pass_filter(pcmf32, freq_thold, sample_rate);
65
- }
33
+ if (freq_thold > 0.0f) {
34
+ high_pass_filter(pcmf32, freq_thold, sample_rate);
35
+ }
36
+
37
+ float energy_all = 0.0f;
38
+ float energy_last = 0.0f;
39
+
40
+ for (int i = 0; i < n_samples; i++) {
41
+ energy_all += fabsf(pcmf32[i]);
66
42
 
67
- float energy_all = 0.0f;
68
- float energy_last = 0.0f;
43
+ if (i >= n_samples - n_samples_last) {
44
+ energy_last += fabsf(pcmf32[i]);
45
+ }
46
+ }
47
+
48
+ energy_all /= n_samples;
49
+ energy_last /= n_samples_last;
69
50
 
70
- for (int i = 0; i < n_samples; i++) {
71
- energy_all += fabsf(pcmf32[i]);
51
+ if (verbose) {
52
+ RNWHISPER_LOG_INFO("%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
53
+ }
72
54
 
73
- if (i >= n_samples - n_samples_last) {
74
- energy_last += fabsf(pcmf32[i]);
55
+ if (energy_last > vad_thold*energy_all) {
56
+ return false;
75
57
  }
76
- }
77
58
 
78
- energy_all /= n_samples;
79
- energy_last /= n_samples_last;
59
+ return true;
60
+ }
80
61
 
81
- if (verbose) {
82
- fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
83
- }
62
+ void job::set_realtime_params(
63
+ vad_params params,
64
+ int sec,
65
+ int slice_sec,
66
+ const char* output_path
67
+ ) {
68
+ vad = params;
69
+ if (vad.vad_ms < 2000) vad.vad_ms = 2000;
70
+ audio_sec = sec > 0 ? sec : DEFAULT_MAX_AUDIO_SEC;
71
+ audio_slice_sec = slice_sec > 0 && slice_sec < audio_sec ? slice_sec : audio_sec;
72
+ audio_output_path = output_path;
73
+ }
84
74
 
85
- if (energy_last > vad_thold*energy_all) {
75
+ bool job::vad_simple(int slice_index, int n_samples, int n) {
76
+ if (!vad.use_vad) return true;
77
+
78
+ short* pcm = pcm_slices[slice_index];
79
+ int sample_size = (int) (WHISPER_SAMPLE_RATE * vad.vad_ms / 1000);
80
+ if (n_samples + n > sample_size) {
81
+ int start = n_samples + n - sample_size;
82
+ std::vector<float> pcmf32(sample_size);
83
+ for (int i = 0; i < sample_size; i++) {
84
+ pcmf32[i] = (float)pcm[i + start] / 32768.0f;
85
+ }
86
+ return vad_simple_impl(pcmf32, WHISPER_SAMPLE_RATE, vad.last_ms, vad.vad_thold, vad.freq_thold, vad.verbose);
87
+ }
86
88
  return false;
87
- }
89
+ }
90
+
91
+ void job::put_pcm_data(short* data, int slice_index, int n_samples, int n) {
92
+ if (pcm_slices.size() == slice_index) {
93
+ int n_slices = (int) (WHISPER_SAMPLE_RATE * audio_slice_sec);
94
+ pcm_slices.push_back(new short[n_slices]);
95
+ }
96
+ short* pcm = pcm_slices[slice_index];
97
+ for (int i = 0; i < n; i++) {
98
+ pcm[i + n_samples] = data[i];
99
+ }
100
+ }
101
+
102
+ float* job::pcm_slice_to_f32(int slice_index, int size) {
103
+ if (pcm_slices.size() > slice_index) {
104
+ float* pcmf32 = new float[size];
105
+ for (int i = 0; i < size; i++) {
106
+ pcmf32[i] = (float)pcm_slices[slice_index][i] / 32768.0f;
107
+ }
108
+ return pcmf32;
109
+ }
110
+ return nullptr;
111
+ }
112
+
113
+ bool job::is_aborted() {
114
+ return aborted;
115
+ }
116
+
117
+ void job::abort() {
118
+ aborted = true;
119
+ }
120
+
121
+ job::~job() {
122
+ RNWHISPER_LOG_INFO("rnwhisper::job::%s: job_id: %d\n", __func__, job_id);
123
+
124
+ for (size_t i = 0; i < pcm_slices.size(); i++) {
125
+ delete[] pcm_slices[i];
126
+ }
127
+ pcm_slices.clear();
128
+ }
129
+
130
+ std::unordered_map<int, job*> job_map;
131
+
132
+ void job_abort_all() {
133
+ for (auto it = job_map.begin(); it != job_map.end(); ++it) {
134
+ it->second->abort();
135
+ }
136
+ }
137
+
138
+ job* job_new(int job_id, struct whisper_full_params params) {
139
+ job* ctx = new job();
140
+ ctx->job_id = job_id;
141
+ ctx->params = params;
142
+
143
+ job_map[job_id] = ctx;
144
+
145
+ // Abort handler
146
+ params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
147
+ job *j = (job*)user_data;
148
+ return !j->is_aborted();
149
+ };
150
+ params.encoder_begin_callback_user_data = job_map[job_id];
151
+ params.abort_callback = [](void * user_data) {
152
+ job *j = (job*)user_data;
153
+ return j->is_aborted();
154
+ };
155
+ params.abort_callback_user_data = job_map[job_id];
156
+
157
+ return job_map[job_id];
158
+ }
159
+
160
+ job* job_get(int job_id) {
161
+ if (job_map.find(job_id) != job_map.end()) {
162
+ return job_map[job_id];
163
+ }
164
+ return nullptr;
165
+ }
88
166
 
89
- return true;
167
+ void job_remove(int job_id) {
168
+ if (job_map.find(job_id) != job_map.end()) {
169
+ delete job_map[job_id];
170
+ }
171
+ job_map.erase(job_id);
90
172
  }
91
173
 
92
- }
174
+ }
package/cpp/rn-whisper.h CHANGED
@@ -1,17 +1,49 @@
1
+ #ifndef RNWHISPER_H
2
+ #define RNWHISPER_H
1
3
 
2
- #ifdef __cplusplus
3
4
  #include <string>
4
- #include <whisper.h>
5
- extern "C" {
6
- #endif
7
-
8
- bool* rn_whisper_assign_abort_map(int job_id);
9
- void rn_whisper_remove_abort_map(int job_id);
10
- void rn_whisper_abort_transcribe(int job_id);
11
- bool rn_whisper_transcribe_is_aborted(int job_id);
12
- void rn_whisper_abort_all_transcribe();
13
- bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose);
14
-
15
- #ifdef __cplusplus
16
- }
17
- #endif
5
+ #include <vector>
6
+ #include "whisper.h"
7
+ #include "rn-whisper-log.h"
8
+ #include "rn-audioutils.h"
9
+
10
+ namespace rnwhisper {
11
+
12
+ struct vad_params {
13
+ bool use_vad = false;
14
+ float vad_thold = 0.6f;
15
+ float freq_thold = 100.0f;
16
+ int vad_ms = 2000;
17
+ int last_ms = 1000;
18
+ bool verbose = false;
19
+ };
20
+
21
+ struct job {
22
+ int job_id;
23
+ bool aborted = false;
24
+ whisper_full_params params;
25
+
26
+ ~job();
27
+ bool is_aborted();
28
+ void abort();
29
+
30
+ // Realtime transcription only:
31
+ vad_params vad;
32
+ int audio_sec = 0;
33
+ int audio_slice_sec = 0;
34
+ const char* audio_output_path = nullptr;
35
+ std::vector<short *> pcm_slices;
36
+ void set_realtime_params(vad_params vad, int sec, int slice_sec, const char* output_path);
37
+ bool vad_simple(int slice_index, int n_samples, int n);
38
+ void put_pcm_data(short* pcm, int slice_index, int n_samples, int n);
39
+ float* pcm_slice_to_f32(int slice_index, int size);
40
+ };
41
+
42
+ void job_abort_all();
43
+ job* job_new(int job_id, struct whisper_full_params params);
44
+ void job_remove(int job_id);
45
+ job* job_get(int job_id);
46
+
47
+ } // namespace rnwhisper
48
+
49
+ #endif // RNWHISPER_H