whisper.rn 0.3.6 → 0.3.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. package/README.md +28 -0
  2. package/android/src/main/java/com/rnwhisper/AudioUtils.java +119 -0
  3. package/android/src/main/java/com/rnwhisper/WhisperContext.java +74 -39
  4. package/android/src/main/jni.cpp +45 -12
  5. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +26 -0
  6. package/cpp/rn-whisper.cpp +51 -0
  7. package/cpp/rn-whisper.h +2 -1
  8. package/ios/RNWhisper.mm +81 -22
  9. package/ios/RNWhisper.xcodeproj/project.pbxproj +27 -3
  10. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  11. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +5 -0
  12. package/ios/RNWhisperAudioSessionUtils.h +13 -0
  13. package/ios/RNWhisperAudioSessionUtils.m +85 -0
  14. package/ios/RNWhisperAudioUtils.h +9 -0
  15. package/ios/RNWhisperAudioUtils.m +83 -0
  16. package/ios/RNWhisperContext.h +1 -0
  17. package/ios/RNWhisperContext.mm +101 -28
  18. package/lib/commonjs/AudioSessionIos.js +91 -0
  19. package/lib/commonjs/AudioSessionIos.js.map +1 -0
  20. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  21. package/lib/commonjs/index.js +82 -14
  22. package/lib/commonjs/index.js.map +1 -1
  23. package/lib/module/AudioSessionIos.js +83 -0
  24. package/lib/module/AudioSessionIos.js.map +1 -0
  25. package/lib/module/NativeRNWhisper.js.map +1 -1
  26. package/lib/module/index.js +77 -14
  27. package/lib/module/index.js.map +1 -1
  28. package/lib/typescript/AudioSessionIos.d.ts +54 -0
  29. package/lib/typescript/AudioSessionIos.d.ts.map +1 -0
  30. package/lib/typescript/NativeRNWhisper.d.ts +8 -0
  31. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  32. package/lib/typescript/index.d.ts +62 -4
  33. package/lib/typescript/index.d.ts.map +1 -1
  34. package/package.json +1 -1
  35. package/src/AudioSessionIos.ts +90 -0
  36. package/src/NativeRNWhisper.ts +11 -1
  37. package/src/index.ts +178 -28
package/README.md CHANGED
@@ -99,6 +99,34 @@ subscribe(evt => {
99
99
  })
100
100
  ```
101
101
 
102
+ On iOS, you may need to change the Audio Session so that it can be used together with other audio playback, or to optimize the recording quality. For this we provide AudioSession utilities:
103
+
104
+ Option 1 - Use options in transcribeRealtime:
105
+ ```js
106
+ import { AudioSessionIos } from 'whisper.rn'
107
+
108
+ const { stop, subscribe } = await whisperContext.transcribeRealtime({
109
+ audioSessionOnStartIos: {
110
+ category: AudioSessionIos.Category.PlayAndRecord,
111
+ options: [AudioSessionIos.CategoryOption.MixWithOthers],
112
+ mode: AudioSessionIos.Mode.Default,
113
+ },
114
+ audioSessionOnStopIos: 'restore', // Or an AudioSessionSettingIos
115
+ })
116
+ ```
117
+
118
+ Option 2 - Manage the Audio Session anywhere:
119
+ ```js
120
+ import { AudioSessionIos } from 'whisper.rn'
121
+
122
+ await AudioSessionIos.setCategory(
123
+ AudioSessionIos.Category.PlayAndRecord, [AudioSessionIos.CategoryOption.MixWithOthers],
124
+ )
125
+ await AudioSessionIos.setMode(AudioSessionIos.Mode.Default)
126
+ await AudioSessionIos.setActive(true)
127
+ // Then you can start recording
128
+ ```
129
+
102
130
  In Android, you may need to request the microphone permission by [`PermissionAndroid`](https://reactnative.dev/docs/permissionsandroid).
103
131
 
104
132
  Please visit the [Documentation](docs/) for more details.
@@ -0,0 +1,119 @@
1
+ package com.rnwhisper;
2
+
3
+ import android.util.Log;
4
+
5
+ import java.util.ArrayList;
6
+ import java.lang.StringBuilder;
7
+ import java.io.IOException;
8
+ import java.io.FileReader;
9
+ import java.io.ByteArrayOutputStream;
10
+ import java.io.File;
11
+ import java.io.FileOutputStream;
12
+ import java.io.DataOutputStream;
13
+ import java.io.IOException;
14
+ import java.io.InputStream;
15
+ import java.nio.ByteBuffer;
16
+ import java.nio.ByteOrder;
17
+ import java.nio.ShortBuffer;
18
+
19
+ public class AudioUtils {
20
+ private static final String NAME = "RNWhisperAudioUtils";
21
+
22
+ private static final int SAMPLE_RATE = 16000;
23
+
24
+ private static byte[] shortToByte(short[] shortInts) {
25
+ int j = 0;
26
+ int length = shortInts.length;
27
+ byte[] byteData = new byte[length * 2];
28
+ for (int i = 0; i < length; i++) {
29
+ byteData[j++] = (byte) (shortInts[i] >>> 8);
30
+ byteData[j++] = (byte) (shortInts[i] >>> 0);
31
+ }
32
+ return byteData;
33
+ }
34
+
35
+ public static byte[] concatShortBuffers(ArrayList<short[]> buffers) {
36
+ int totalLength = 0;
37
+ for (int i = 0; i < buffers.size(); i++) {
38
+ totalLength += buffers.get(i).length;
39
+ }
40
+ byte[] result = new byte[totalLength * 2];
41
+ int offset = 0;
42
+ for (int i = 0; i < buffers.size(); i++) {
43
+ byte[] bytes = shortToByte(buffers.get(i));
44
+ System.arraycopy(bytes, 0, result, offset, bytes.length);
45
+ offset += bytes.length;
46
+ }
47
+
48
+ return result;
49
+ }
50
+
51
+ private static byte[] removeTrailingZeros(byte[] audioData) {
52
+ int i = audioData.length - 1;
53
+ while (i >= 0 && audioData[i] == 0) {
54
+ --i;
55
+ }
56
+ byte[] newData = new byte[i + 1];
57
+ System.arraycopy(audioData, 0, newData, 0, i + 1);
58
+ return newData;
59
+ }
60
+
61
+ public static void saveWavFile(byte[] rawData, String audioOutputFile) throws IOException {
62
+ Log.d(NAME, "call saveWavFile");
63
+ rawData = removeTrailingZeros(rawData);
64
+ DataOutputStream output = null;
65
+ try {
66
+ output = new DataOutputStream(new FileOutputStream(audioOutputFile));
67
+ // WAVE header
68
+ // see http://ccrma.stanford.edu/courses/422/projects/WaveFormat/
69
+ output.writeBytes("RIFF"); // chunk id
70
+ output.writeInt(Integer.reverseBytes(36 + rawData.length)); // chunk size
71
+ output.writeBytes("WAVE"); // format
72
+ output.writeBytes("fmt "); // subchunk 1 id
73
+ output.writeInt(Integer.reverseBytes(16)); // subchunk 1 size
74
+ output.writeShort(Short.reverseBytes((short) 1)); // audio format (1 = PCM)
75
+ output.writeShort(Short.reverseBytes((short) 1)); // number of channels
76
+ output.writeInt(Integer.reverseBytes(SAMPLE_RATE)); // sample rate
77
+ output.writeInt(Integer.reverseBytes(SAMPLE_RATE * 2)); // byte rate
78
+ output.writeShort(Short.reverseBytes((short) 2)); // block align
79
+ output.writeShort(Short.reverseBytes((short) 16)); // bits per sample
80
+ output.writeBytes("data"); // subchunk 2 id
81
+ output.writeInt(Integer.reverseBytes(rawData.length)); // subchunk 2 size
82
+ // Audio data (conversion big endian -> little endian)
83
+ short[] shorts = new short[rawData.length / 2];
84
+ ByteBuffer.wrap(rawData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(shorts);
85
+ ByteBuffer bytes = ByteBuffer.allocate(shorts.length * 2);
86
+ for (short s : shorts) {
87
+ bytes.putShort(s);
88
+ }
89
+ Log.d(NAME, "writing audio file: " + audioOutputFile);
90
+ output.write(bytes.array());
91
+ } finally {
92
+ if (output != null) {
93
+ output.close();
94
+ }
95
+ }
96
+ }
97
+
98
+ public static float[] decodeWaveFile(InputStream inputStream) throws IOException {
99
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
100
+ byte[] buffer = new byte[1024];
101
+ int bytesRead;
102
+ while ((bytesRead = inputStream.read(buffer)) != -1) {
103
+ baos.write(buffer, 0, bytesRead);
104
+ }
105
+ ByteBuffer byteBuffer = ByteBuffer.wrap(baos.toByteArray());
106
+ byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
107
+ byteBuffer.position(44);
108
+ ShortBuffer shortBuffer = byteBuffer.asShortBuffer();
109
+ short[] shortArray = new short[shortBuffer.limit()];
110
+ shortBuffer.get(shortArray);
111
+ float[] floatArray = new float[shortArray.length];
112
+ for (int i = 0; i < shortArray.length; i++) {
113
+ floatArray[i] = ((float) shortArray[i]) / 32767.0f;
114
+ floatArray[i] = Math.max(floatArray[i], -1f);
115
+ floatArray[i] = Math.min(floatArray[i], 1f);
116
+ }
117
+ return floatArray;
118
+ }
119
+ }
@@ -14,22 +14,15 @@ import android.media.AudioFormat;
14
14
  import android.media.AudioRecord;
15
15
  import android.media.MediaRecorder.AudioSource;
16
16
 
17
- import java.util.Random;
18
17
  import java.util.ArrayList;
19
18
  import java.lang.StringBuilder;
20
- import java.io.File;
21
19
  import java.io.BufferedReader;
22
20
  import java.io.IOException;
23
21
  import java.io.FileReader;
24
- import java.io.ByteArrayOutputStream;
25
22
  import java.io.File;
26
- import java.io.FileInputStream;
27
23
  import java.io.IOException;
28
24
  import java.io.InputStream;
29
25
  import java.io.PushbackInputStream;
30
- import java.nio.ByteBuffer;
31
- import java.nio.ByteOrder;
32
- import java.nio.ShortBuffer;
33
26
 
34
27
  public class WhisperContext {
35
28
  public static final String NAME = "RNWhisperContext";
@@ -86,6 +79,27 @@ public class WhisperContext {
86
79
  fullHandler = null;
87
80
  }
88
81
 
82
+ private boolean vad(ReadableMap options, short[] shortBuffer, int nSamples, int n) {
83
+ boolean isSpeech = true;
84
+ if (!isTranscribing && options.hasKey("useVad") && options.getBoolean("useVad")) {
85
+ int vadSec = options.hasKey("vadMs") ? options.getInt("vadMs") / 1000 : 2;
86
+ int sampleSize = vadSec * SAMPLE_RATE;
87
+ if (nSamples + n > sampleSize) {
88
+ int start = nSamples + n - sampleSize;
89
+ float[] audioData = new float[sampleSize];
90
+ for (int i = 0; i < sampleSize; i++) {
91
+ audioData[i] = shortBuffer[i + start] / 32768.0f;
92
+ }
93
+ float vadThold = options.hasKey("vadThold") ? (float) options.getDouble("vadThold") : 0.6f;
94
+ float vadFreqThold = options.hasKey("vadFreqThold") ? (float) options.getDouble("vadFreqThold") : 0.6f;
95
+ isSpeech = vadSimple(audioData, sampleSize, vadThold, vadFreqThold);
96
+ } else {
97
+ isSpeech = false;
98
+ }
99
+ }
100
+ return isSpeech;
101
+ }
102
+
89
103
  public int startRealtimeTranscribe(int jobId, ReadableMap options) {
90
104
  if (isCapturing || isTranscribing) {
91
105
  return -100;
@@ -111,6 +125,8 @@ public class WhisperContext {
111
125
 
112
126
  isUseSlices = audioSliceSec < audioSec;
113
127
 
128
+ String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null;
129
+
114
130
  shortBufferSlices = new ArrayList<short[]>();
115
131
  shortBufferSlices.add(new short[audioSliceSec * SAMPLE_RATE]);
116
132
  sliceNSamples = new ArrayList<Integer>();
@@ -145,6 +161,12 @@ public class WhisperContext {
145
161
  ) {
146
162
  emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
147
163
  } else if (!isTranscribing) {
164
+ short[] shortBuffer = shortBufferSlices.get(sliceIndex);
165
+ boolean isSpeech = vad(options, shortBuffer, nSamples, 0);
166
+ if (!isSpeech) {
167
+ emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
168
+ break;
169
+ }
148
170
  isTranscribing = true;
149
171
  fullTranscribeSamples(options, true);
150
172
  }
@@ -166,9 +188,14 @@ public class WhisperContext {
166
188
  for (int i = 0; i < n; i++) {
167
189
  shortBuffer[nSamples + i] = buffer[i];
168
190
  }
191
+
192
+ boolean isSpeech = vad(options, shortBuffer, nSamples, n);
193
+
169
194
  nSamples += n;
170
195
  sliceNSamples.set(sliceIndex, nSamples);
171
196
 
197
+ if (!isSpeech) continue;
198
+
172
199
  if (!isTranscribing && nSamples > SAMPLE_RATE / 2) {
173
200
  isTranscribing = true;
174
201
  fullHandler = new Thread(new Runnable() {
@@ -183,6 +210,9 @@ public class WhisperContext {
183
210
  Log.e(NAME, "Error transcribing realtime: " + e.getMessage());
184
211
  }
185
212
  }
213
+ // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage
214
+ Log.d(NAME, "Begin saving wav file to " + audioOutputPath);
215
+ AudioUtils.saveWavFile(AudioUtils.concatShortBuffers(shortBufferSlices), audioOutputPath);
186
216
  if (!isTranscribing) {
187
217
  emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
188
218
  }
@@ -233,7 +263,7 @@ public class WhisperContext {
233
263
  payload.putInt("sliceIndex", transcribeSliceIndex);
234
264
 
235
265
  if (code == 0) {
236
- payload.putMap("data", getTextSegments());
266
+ payload.putMap("data", getTextSegments(0, getTextSegmentCount(context)));
237
267
  } else {
238
268
  payload.putString("error", "Transcribe failed with code " + code);
239
269
  }
@@ -293,16 +323,41 @@ public class WhisperContext {
293
323
  eventEmitter.emit("@RNWhisper_onTranscribeProgress", event);
294
324
  }
295
325
 
296
- private static class ProgressCallback {
326
+ private void emitNewSegments(WritableMap result) {
327
+ WritableMap event = Arguments.createMap();
328
+ event.putInt("contextId", WhisperContext.this.id);
329
+ event.putInt("jobId", jobId);
330
+ event.putMap("result", result);
331
+ eventEmitter.emit("@RNWhisper_onTranscribeNewSegments", event);
332
+ }
333
+
334
+ private static class Callback {
297
335
  WhisperContext context;
336
+ boolean emitProgressNeeded = false;
337
+ boolean emitNewSegmentsNeeded = false;
338
+ int totalNNew = 0;
298
339
 
299
- public ProgressCallback(WhisperContext context) {
340
+ public Callback(WhisperContext context, boolean emitProgressNeeded, boolean emitNewSegmentsNeeded) {
300
341
  this.context = context;
342
+ this.emitProgressNeeded = emitProgressNeeded;
343
+ this.emitNewSegmentsNeeded = emitNewSegmentsNeeded;
301
344
  }
302
345
 
303
346
  void onProgress(int progress) {
347
+ if (!emitProgressNeeded) return;
304
348
  context.emitProgress(progress);
305
349
  }
350
+
351
+ void onNewSegments(int nNew) {
352
+ Log.d(NAME, "onNewSegments: " + nNew);
353
+ totalNNew += nNew;
354
+ if (!emitNewSegmentsNeeded) return;
355
+
356
+ WritableMap result = context.getTextSegments(totalNNew - nNew, totalNNew);
357
+ result.putInt("nNew", nNew);
358
+ result.putInt("totalNNew", totalNNew);
359
+ context.emitNewSegments(result);
360
+ }
306
361
  }
307
362
 
308
363
  public WritableMap transcribeInputStream(int jobId, InputStream inputStream, ReadableMap options) throws IOException, Exception {
@@ -313,19 +368,21 @@ public class WhisperContext {
313
368
 
314
369
  this.jobId = jobId;
315
370
  isTranscribing = true;
316
- float[] audioData = decodeWaveFile(inputStream);
371
+ float[] audioData = AudioUtils.decodeWaveFile(inputStream);
317
372
  int code = full(jobId, options, audioData, audioData.length);
318
373
  isTranscribing = false;
319
374
  this.jobId = -1;
320
375
  if (code != 0) {
321
376
  throw new Exception("Failed to transcribe the file. Code: " + code);
322
377
  }
323
- WritableMap result = getTextSegments();
378
+ WritableMap result = getTextSegments(0, getTextSegmentCount(context));
324
379
  result.putBoolean("isAborted", isStoppedByAction);
325
380
  return result;
326
381
  }
327
382
 
328
383
  private int full(int jobId, ReadableMap options, float[] audioData, int audioDataLen) {
384
+ boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress");
385
+ boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments");
329
386
  return fullTranscribe(
330
387
  jobId,
331
388
  context,
@@ -365,13 +422,12 @@ public class WhisperContext {
365
422
  options.hasKey("language") ? options.getString("language") : "auto",
366
423
  // jstring prompt
367
424
  options.hasKey("prompt") ? options.getString("prompt") : null,
368
- // ProgressCallback progressCallback
369
- options.hasKey("onProgress") && options.getBoolean("onProgress") ? new ProgressCallback(this) : null
425
+ // Callback callback
426
+ hasProgressCallback || hasNewSegmentsCallback ? new Callback(this, hasProgressCallback, hasNewSegmentsCallback) : null
370
427
  );
371
428
  }
372
429
 
373
- private WritableMap getTextSegments() {
374
- Integer count = getTextSegmentCount(context);
430
+ private WritableMap getTextSegments(int start, int count) {
375
431
  StringBuilder builder = new StringBuilder();
376
432
 
377
433
  WritableMap data = Arguments.createMap();
@@ -424,28 +480,6 @@ public class WhisperContext {
424
480
  freeContext(context);
425
481
  }
426
482
 
427
- public static float[] decodeWaveFile(InputStream inputStream) throws IOException {
428
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
429
- byte[] buffer = new byte[1024];
430
- int bytesRead;
431
- while ((bytesRead = inputStream.read(buffer)) != -1) {
432
- baos.write(buffer, 0, bytesRead);
433
- }
434
- ByteBuffer byteBuffer = ByteBuffer.wrap(baos.toByteArray());
435
- byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
436
- byteBuffer.position(44);
437
- ShortBuffer shortBuffer = byteBuffer.asShortBuffer();
438
- short[] shortArray = new short[shortBuffer.limit()];
439
- shortBuffer.get(shortArray);
440
- float[] floatArray = new float[shortArray.length];
441
- for (int i = 0; i < shortArray.length; i++) {
442
- floatArray[i] = ((float) shortArray[i]) / 32767.0f;
443
- floatArray[i] = Math.max(floatArray[i], -1f);
444
- floatArray[i] = Math.min(floatArray[i], 1f);
445
- }
446
- return floatArray;
447
- }
448
-
449
483
  static {
450
484
  Log.d(NAME, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
451
485
  boolean loadVfpv4 = false;
@@ -513,6 +547,7 @@ public class WhisperContext {
513
547
  protected static native long initContext(String modelPath);
514
548
  protected static native long initContextWithAsset(AssetManager assetManager, String modelPath);
515
549
  protected static native long initContextWithInputStream(PushbackInputStream inputStream);
550
+ protected static native boolean vadSimple(float[] audio_data, int audio_data_len, float vad_thold, float vad_freq_thold);
516
551
  protected static native int fullTranscribe(
517
552
  int job_id,
518
553
  long context,
@@ -533,7 +568,7 @@ public class WhisperContext {
533
568
  boolean translate,
534
569
  String language,
535
570
  String prompt,
536
- ProgressCallback progressCallback
571
+ Callback Callback
537
572
  );
538
573
  protected static native void abortTranscribe(int jobId);
539
574
  protected static native void abortAllTranscribe();
@@ -6,6 +6,7 @@
6
6
  #include <sys/sysinfo.h>
7
7
  #include <string>
8
8
  #include <thread>
9
+ #include <vector>
9
10
  #include "whisper.h"
10
11
  #include "rn-whisper.h"
11
12
  #include "ggml.h"
@@ -184,9 +185,30 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream(
184
185
  return reinterpret_cast<jlong>(context);
185
186
  }
186
187
 
187
- struct progress_callback_context {
188
+ JNIEXPORT jboolean JNICALL
189
+ Java_com_rnwhisper_WhisperContext_vadSimple(
190
+ JNIEnv *env,
191
+ jobject thiz,
192
+ jfloatArray audio_data,
193
+ jint audio_data_len,
194
+ jfloat vad_thold,
195
+ jfloat vad_freq_thold
196
+ ) {
197
+ UNUSED(thiz);
198
+
199
+ std::vector<float> samples(audio_data_len);
200
+ jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);
201
+ for (int i = 0; i < audio_data_len; i++) {
202
+ samples[i] = audio_data_arr[i];
203
+ }
204
+ bool is_speech = rn_whisper_vad_simple(samples, WHISPER_SAMPLE_RATE, 1000, vad_thold, vad_freq_thold, false);
205
+ env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
206
+ return is_speech;
207
+ }
208
+
209
+ struct callback_context {
188
210
  JNIEnv *env;
189
- jobject progress_callback_instance;
211
+ jobject callback_instance;
190
212
  };
191
213
 
192
214
  JNIEXPORT jint JNICALL
@@ -212,7 +234,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
212
234
  jboolean translate,
213
235
  jstring language,
214
236
  jstring prompt,
215
- jobject progress_callback_instance
237
+ jobject callback_instance
216
238
  ) {
217
239
  UNUSED(thiz);
218
240
  struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
@@ -280,19 +302,30 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
280
302
  };
281
303
  params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(job_id);
282
304
 
283
- if (progress_callback_instance != nullptr) {
305
+ if (callback_instance != nullptr) {
306
+ callback_context *cb_ctx = new callback_context;
307
+ cb_ctx->env = env;
308
+ cb_ctx->callback_instance = env->NewGlobalRef(callback_instance);
309
+
284
310
  params.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
285
- progress_callback_context *cb_ctx = (progress_callback_context *)user_data;
311
+ callback_context *cb_ctx = (callback_context *)user_data;
286
312
  JNIEnv *env = cb_ctx->env;
287
- jobject progress_callback_instance = cb_ctx->progress_callback_instance;
288
- jclass progress_callback_class = env->GetObjectClass(progress_callback_instance);
289
- jmethodID onProgress = env->GetMethodID(progress_callback_class, "onProgress", "(I)V");
290
- env->CallVoidMethod(progress_callback_instance, onProgress, progress);
313
+ jobject callback_instance = cb_ctx->callback_instance;
314
+ jclass callback_class = env->GetObjectClass(callback_instance);
315
+ jmethodID onProgress = env->GetMethodID(callback_class, "onProgress", "(I)V");
316
+ env->CallVoidMethod(callback_instance, onProgress, progress);
291
317
  };
292
- progress_callback_context *cb_ctx = new progress_callback_context;
293
- cb_ctx->env = env;
294
- cb_ctx->progress_callback_instance = env->NewGlobalRef(progress_callback_instance);
295
318
  params.progress_callback_user_data = cb_ctx;
319
+
320
+ params.new_segment_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int n_new, void * user_data) {
321
+ callback_context *cb_ctx = (callback_context *)user_data;
322
+ JNIEnv *env = cb_ctx->env;
323
+ jobject callback_instance = cb_ctx->callback_instance;
324
+ jclass callback_class = env->GetObjectClass(callback_instance);
325
+ jmethodID onNewSegments = env->GetMethodID(callback_class, "onNewSegments", "(I)V");
326
+ env->CallVoidMethod(callback_instance, onNewSegments, n_new);
327
+ };
328
+ params.new_segment_callback_user_data = cb_ctx;
296
329
  }
297
330
 
298
331
  LOGI("About to reset timings");
@@ -6,6 +6,7 @@ import com.facebook.react.bridge.Promise;
6
6
  import com.facebook.react.bridge.ReactApplicationContext;
7
7
  import com.facebook.react.bridge.ReactMethod;
8
8
  import com.facebook.react.bridge.ReadableMap;
9
+ import com.facebook.react.bridge.ReadableArray;
9
10
  import com.facebook.react.module.annotations.ReactModule;
10
11
 
11
12
  import java.util.HashMap;
@@ -65,4 +66,29 @@ public class RNWhisperModule extends NativeRNWhisperSpec {
65
66
  public void releaseAllContexts(Promise promise) {
66
67
  rnwhisper.releaseAllContexts(promise);
67
68
  }
69
+
70
+ /*
71
+ * iOS Specific methods, left here for make the turbo module happy:
72
+ */
73
+
74
+ @ReactMethod
75
+ public void getAudioSessionCurrentCategory(Promise promise) {
76
+ promise.resolve(null);
77
+ }
78
+ @ReactMethod
79
+ public void getAudioSessionCurrentMode(Promise promise) {
80
+ promise.resolve(null);
81
+ }
82
+ @ReactMethod
83
+ public void setAudioSessionCategory(String category, ReadableArray options, Promise promise) {
84
+ promise.resolve(null);
85
+ }
86
+ @ReactMethod
87
+ public void setAudioSessionMode(String mode, Promise promise) {
88
+ promise.resolve(null);
89
+ }
90
+ @ReactMethod
91
+ public void setAudioSessionActive(boolean active, Promise promise) {
92
+ promise.resolve(null);
93
+ }
68
94
  }
@@ -38,4 +38,55 @@ void rn_whisper_abort_all_transcribe() {
38
38
  }
39
39
  }
40
40
 
41
+ void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
42
+ const float rc = 1.0f / (2.0f * M_PI * cutoff);
43
+ const float dt = 1.0f / sample_rate;
44
+ const float alpha = dt / (rc + dt);
45
+
46
+ float y = data[0];
47
+
48
+ for (size_t i = 1; i < data.size(); i++) {
49
+ y = alpha * (y + data[i] - data[i - 1]);
50
+ data[i] = y;
51
+ }
52
+ }
53
+
54
+ bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
55
+ const int n_samples = pcmf32.size();
56
+ const int n_samples_last = (sample_rate * last_ms) / 1000;
57
+
58
+ if (n_samples_last >= n_samples) {
59
+ // not enough samples - assume no speech
60
+ return false;
61
+ }
62
+
63
+ if (freq_thold > 0.0f) {
64
+ high_pass_filter(pcmf32, freq_thold, sample_rate);
65
+ }
66
+
67
+ float energy_all = 0.0f;
68
+ float energy_last = 0.0f;
69
+
70
+ for (int i = 0; i < n_samples; i++) {
71
+ energy_all += fabsf(pcmf32[i]);
72
+
73
+ if (i >= n_samples - n_samples_last) {
74
+ energy_last += fabsf(pcmf32[i]);
75
+ }
76
+ }
77
+
78
+ energy_all /= n_samples;
79
+ energy_last /= n_samples_last;
80
+
81
+ if (verbose) {
82
+ fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
83
+ }
84
+
85
+ if (energy_last > vad_thold*energy_all) {
86
+ return false;
87
+ }
88
+
89
+ return true;
90
+ }
91
+
41
92
  }
package/cpp/rn-whisper.h CHANGED
@@ -10,7 +10,8 @@ void rn_whisper_remove_abort_map(int job_id);
10
10
  void rn_whisper_abort_transcribe(int job_id);
11
11
  bool rn_whisper_transcribe_is_aborted(int job_id);
12
12
  void rn_whisper_abort_all_transcribe();
13
+ bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose);
13
14
 
14
15
  #ifdef __cplusplus
15
16
  }
16
- #endif
17
+ #endif