whisper.rn 0.4.0-rc.3 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +7 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -135
  7. package/android/src/main/jni-utils.h +76 -0
  8. package/android/src/main/jni.cpp +188 -109
  9. package/cpp/README.md +1 -1
  10. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  11. package/cpp/coreml/whisper-encoder.h +4 -0
  12. package/cpp/coreml/whisper-encoder.mm +4 -2
  13. package/cpp/ggml-alloc.c +451 -282
  14. package/cpp/ggml-alloc.h +74 -8
  15. package/cpp/ggml-backend-impl.h +112 -0
  16. package/cpp/ggml-backend.c +1357 -0
  17. package/cpp/ggml-backend.h +181 -0
  18. package/cpp/ggml-impl.h +243 -0
  19. package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +1556 -329
  20. package/cpp/ggml-metal.h +28 -1
  21. package/cpp/ggml-metal.m +1128 -308
  22. package/cpp/ggml-quants.c +7382 -0
  23. package/cpp/ggml-quants.h +224 -0
  24. package/cpp/ggml.c +3848 -5245
  25. package/cpp/ggml.h +353 -155
  26. package/cpp/rn-audioutils.cpp +68 -0
  27. package/cpp/rn-audioutils.h +14 -0
  28. package/cpp/rn-whisper-log.h +11 -0
  29. package/cpp/rn-whisper.cpp +141 -59
  30. package/cpp/rn-whisper.h +47 -15
  31. package/cpp/whisper.cpp +1750 -964
  32. package/cpp/whisper.h +97 -15
  33. package/ios/RNWhisper.mm +15 -9
  34. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
  35. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  36. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  37. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
  38. package/ios/RNWhisperAudioUtils.h +0 -2
  39. package/ios/RNWhisperAudioUtils.m +0 -56
  40. package/ios/RNWhisperContext.h +8 -12
  41. package/ios/RNWhisperContext.mm +132 -138
  42. package/jest/mock.js +1 -1
  43. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  44. package/lib/commonjs/index.js +28 -9
  45. package/lib/commonjs/index.js.map +1 -1
  46. package/lib/commonjs/version.json +1 -1
  47. package/lib/module/NativeRNWhisper.js.map +1 -1
  48. package/lib/module/index.js +28 -9
  49. package/lib/module/index.js.map +1 -1
  50. package/lib/module/version.json +1 -1
  51. package/lib/typescript/NativeRNWhisper.d.ts +7 -1
  52. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  53. package/lib/typescript/index.d.ts +7 -2
  54. package/lib/typescript/index.d.ts.map +1 -1
  55. package/package.json +6 -5
  56. package/src/NativeRNWhisper.ts +8 -1
  57. package/src/index.ts +29 -17
  58. package/src/version.json +1 -1
  59. package/whisper-rn.podspec +1 -2
package/README.md CHANGED
@@ -25,19 +25,19 @@ npm install whisper.rn
25
25
 
26
26
  Please re-run `npx pod-install` again.
27
27
 
28
- #### Android
29
-
30
28
  If you want to use `medium` or `large` model, the [Extended Virtual Addressing](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_extended-virtual-addressing) capability is recommended to enable on iOS project.
31
29
 
32
- For Android, it's recommended to use `ndkVersion = "24.0.8215888"` (or above) in your root project build configuration for Apple Silicon Macs. Otherwise please follow this trobleshooting [issue](./TROUBLESHOOTING.md#android-got-build-error-unknown-host-cpu-architecture-arm64-on-apple-silicon-macs).
30
+ #### Android
33
31
 
34
- Don't forget to add proguard rule if it's enabled in project (android/app/proguard-rules.pro):
32
+ Add proguard rule if it's enabled in project (android/app/proguard-rules.pro):
35
33
 
36
34
  ```proguard
37
35
  # whisper.rn
38
36
  -keep class com.rnwhisper.** { *; }
39
37
  ```
40
38
 
39
+ For build, it's recommended to use `ndkVersion = "24.0.8215888"` (or above) in your root project build configuration for Apple Silicon Macs. Otherwise please follow this troubleshooting [issue](./TROUBLESHOOTING.md#android-got-build-error-unknown-host-cpu-architecture-arm64-on-apple-silicon-macs).
40
+
41
41
  #### Expo
42
42
 
43
43
  You will need to prebuild the project before using it. See [Expo guide](https://docs.expo.io/guides/using-libraries/#using-a-library-in-a-expo-project) for more details.
@@ -91,7 +91,7 @@ subscribe(evt => {
91
91
  console.log(
92
92
  `Realtime transcribing: ${isCapturing ? 'ON' : 'OFF'}\n` +
93
93
  // The inference text result from audio record:
94
- `Result: ${data.result}\n\n` +
94
+ `Result: ${data.result}\n\n` +
95
95
  `Process time: ${processTime}ms\n` +
96
96
  `Recording time: ${recordingTime}ms`,
97
97
  )
@@ -220,7 +220,7 @@ In real world, we recommended to split the asset imports into another platform s
220
220
 
221
221
  The example app provide a simple UI for testing the functions.
222
222
 
223
- Used Whisper model: `tiny.en` in https://huggingface.co/ggerganov/whisper.cpp
223
+ Used Whisper model: `tiny.en` in https://huggingface.co/ggerganov/whisper.cpp
224
224
  Sample file: `jfk.wav` in https://github.com/ggerganov/whisper.cpp/tree/master/samples
225
225
 
226
226
  Please follow the [Development Workflow section of contributing guide](./CONTRIBUTING.md#development-workflow) to run the example app.
@@ -36,6 +36,10 @@ def reactNativeArchitectures() {
36
36
  }
37
37
 
38
38
  android {
39
+ def agpVersion = com.android.Version.ANDROID_GRADLE_PLUGIN_VERSION
40
+ if (agpVersion.tokenize('.')[0].toInteger() >= 7) {
41
+ namespace "com.rnwhisper"
42
+ }
39
43
  ndkVersion getExtOrDefault("ndkVersion")
40
44
  compileSdkVersion getExtOrIntegerDefault("compileSdkVersion")
41
45
 
@@ -9,7 +9,10 @@ set(
9
9
  SOURCE_FILES
10
10
  ${RNWHISPER_LIB_DIR}/ggml.c
11
11
  ${RNWHISPER_LIB_DIR}/ggml-alloc.c
12
+ ${RNWHISPER_LIB_DIR}/ggml-backend.c
13
+ ${RNWHISPER_LIB_DIR}/ggml-quants.c
12
14
  ${RNWHISPER_LIB_DIR}/whisper.cpp
15
+ ${RNWHISPER_LIB_DIR}/rn-audioutils.cpp
13
16
  ${RNWHISPER_LIB_DIR}/rn-whisper.cpp
14
17
  ${CMAKE_SOURCE_DIR}/jni.cpp
15
18
  )
@@ -31,6 +34,10 @@ function(build_library target_name)
31
34
  target_compile_options(${target_name} PRIVATE -mfpu=neon-vfpv4)
32
35
  endif ()
33
36
 
37
+ if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
38
+ target_compile_options(${target_name} PRIVATE -DRNWHISPER_ANDROID_ENABLE_LOGGING)
39
+ endif ()
40
+
34
41
  # NOTE: If you want to debug the native code, you can uncomment if and endif
35
42
  # if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")
36
43
 
@@ -2,14 +2,10 @@ package com.rnwhisper;
2
2
 
3
3
  import android.util.Log;
4
4
 
5
- import java.util.ArrayList;
6
- import java.lang.StringBuilder;
7
5
  import java.io.IOException;
8
6
  import java.io.FileReader;
9
7
  import java.io.ByteArrayOutputStream;
10
8
  import java.io.File;
11
- import java.io.FileOutputStream;
12
- import java.io.DataOutputStream;
13
9
  import java.io.IOException;
14
10
  import java.io.InputStream;
15
11
  import java.nio.ByteBuffer;
@@ -19,82 +15,6 @@ import java.nio.ShortBuffer;
19
15
  public class AudioUtils {
20
16
  private static final String NAME = "RNWhisperAudioUtils";
21
17
 
22
- private static final int SAMPLE_RATE = 16000;
23
-
24
- private static byte[] shortToByte(short[] shortInts) {
25
- int j = 0;
26
- int length = shortInts.length;
27
- byte[] byteData = new byte[length * 2];
28
- for (int i = 0; i < length; i++) {
29
- byteData[j++] = (byte) (shortInts[i] >>> 8);
30
- byteData[j++] = (byte) (shortInts[i] >>> 0);
31
- }
32
- return byteData;
33
- }
34
-
35
- public static byte[] concatShortBuffers(ArrayList<short[]> buffers) {
36
- int totalLength = 0;
37
- for (int i = 0; i < buffers.size(); i++) {
38
- totalLength += buffers.get(i).length;
39
- }
40
- byte[] result = new byte[totalLength * 2];
41
- int offset = 0;
42
- for (int i = 0; i < buffers.size(); i++) {
43
- byte[] bytes = shortToByte(buffers.get(i));
44
- System.arraycopy(bytes, 0, result, offset, bytes.length);
45
- offset += bytes.length;
46
- }
47
-
48
- return result;
49
- }
50
-
51
- private static byte[] removeTrailingZeros(byte[] audioData) {
52
- int i = audioData.length - 1;
53
- while (i >= 0 && audioData[i] == 0) {
54
- --i;
55
- }
56
- byte[] newData = new byte[i + 1];
57
- System.arraycopy(audioData, 0, newData, 0, i + 1);
58
- return newData;
59
- }
60
-
61
- public static void saveWavFile(byte[] rawData, String audioOutputFile) throws IOException {
62
- Log.d(NAME, "call saveWavFile");
63
- rawData = removeTrailingZeros(rawData);
64
- DataOutputStream output = null;
65
- try {
66
- output = new DataOutputStream(new FileOutputStream(audioOutputFile));
67
- // WAVE header
68
- // see http://ccrma.stanford.edu/courses/422/projects/WaveFormat/
69
- output.writeBytes("RIFF"); // chunk id
70
- output.writeInt(Integer.reverseBytes(36 + rawData.length)); // chunk size
71
- output.writeBytes("WAVE"); // format
72
- output.writeBytes("fmt "); // subchunk 1 id
73
- output.writeInt(Integer.reverseBytes(16)); // subchunk 1 size
74
- output.writeShort(Short.reverseBytes((short) 1)); // audio format (1 = PCM)
75
- output.writeShort(Short.reverseBytes((short) 1)); // number of channels
76
- output.writeInt(Integer.reverseBytes(SAMPLE_RATE)); // sample rate
77
- output.writeInt(Integer.reverseBytes(SAMPLE_RATE * 2)); // byte rate
78
- output.writeShort(Short.reverseBytes((short) 2)); // block align
79
- output.writeShort(Short.reverseBytes((short) 16)); // bits per sample
80
- output.writeBytes("data"); // subchunk 2 id
81
- output.writeInt(Integer.reverseBytes(rawData.length)); // subchunk 2 size
82
- // Audio data (conversion big endian -> little endian)
83
- short[] shorts = new short[rawData.length / 2];
84
- ByteBuffer.wrap(rawData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(shorts);
85
- ByteBuffer bytes = ByteBuffer.allocate(shorts.length * 2);
86
- for (short s : shorts) {
87
- bytes.putShort(s);
88
- }
89
- Log.d(NAME, "writing audio file: " + audioOutputFile);
90
- output.write(bytes.array());
91
- } finally {
92
- if (output != null) {
93
- output.close();
94
- }
95
- }
96
- }
97
-
98
18
  public static float[] decodeWaveFile(InputStream inputStream) throws IOException {
99
19
  ByteArrayOutputStream baos = new ByteArrayOutputStream();
100
20
  byte[] buffer = new byte[1024];
@@ -13,6 +13,7 @@ import com.facebook.react.bridge.ReactMethod;
13
13
  import com.facebook.react.bridge.LifecycleEventListener;
14
14
  import com.facebook.react.bridge.ReadableMap;
15
15
  import com.facebook.react.bridge.WritableMap;
16
+ import com.facebook.react.bridge.Arguments;
16
17
 
17
18
  import java.util.HashMap;
18
19
  import java.util.Random;
@@ -107,7 +108,11 @@ public class RNWhisper implements LifecycleEventListener {
107
108
  promise.reject(exception);
108
109
  return;
109
110
  }
110
- promise.resolve(id);
111
+ WritableMap result = Arguments.createMap();
112
+ result.putInt("contextId", id);
113
+ result.putBoolean("gpu", false);
114
+ result.putString("reasonNoGPU", "Currently not supported");
115
+ promise.resolve(result);
111
116
  tasks.remove(this);
112
117
  }
113
118
  }.execute();
@@ -42,7 +42,6 @@ public class WhisperContext {
42
42
  private AudioRecord recorder = null;
43
43
  private int bufferSize;
44
44
  private int nSamplesTranscribing = 0;
45
- private ArrayList<short[]> shortBufferSlices;
46
45
  // Remember number of samples in each slice
47
46
  private ArrayList<Integer> sliceNSamples;
48
47
  // Current buffer slice index
@@ -66,7 +65,6 @@ public class WhisperContext {
66
65
  }
67
66
 
68
67
  private void rewind() {
69
- shortBufferSlices = null;
70
68
  sliceNSamples = null;
71
69
  sliceIndex = 0;
72
70
  transcribeSliceIndex = 0;
@@ -79,41 +77,14 @@ public class WhisperContext {
79
77
  fullHandler = null;
80
78
  }
81
79
 
82
- private boolean vad(ReadableMap options, short[] shortBuffer, int nSamples, int n) {
83
- boolean isSpeech = true;
84
- if (!isTranscribing && options.hasKey("useVad") && options.getBoolean("useVad")) {
85
- int vadMs = options.hasKey("vadMs") ? options.getInt("vadMs") : 2000;
86
- if (vadMs < 2000) vadMs = 2000;
87
- int sampleSize = (int) (SAMPLE_RATE * vadMs / 1000);
88
- if (nSamples + n > sampleSize) {
89
- int start = nSamples + n - sampleSize;
90
- float[] audioData = new float[sampleSize];
91
- for (int i = 0; i < sampleSize; i++) {
92
- audioData[i] = shortBuffer[i + start] / 32768.0f;
93
- }
94
- float vadThold = options.hasKey("vadThold") ? (float) options.getDouble("vadThold") : 0.6f;
95
- float vadFreqThold = options.hasKey("vadFreqThold") ? (float) options.getDouble("vadFreqThold") : 0.6f;
96
- isSpeech = vadSimple(audioData, sampleSize, vadThold, vadFreqThold);
97
- } else {
98
- isSpeech = false;
99
- }
100
- }
101
- return isSpeech;
80
+ private boolean vad(int sliceIndex, int nSamples, int n) {
81
+ if (isTranscribing) return true;
82
+ return vadSimple(jobId, sliceIndex, nSamples, n);
102
83
  }
103
84
 
104
- private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) {
105
- String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null;
106
- if (audioOutputPath != null) {
107
- // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage
108
- Log.d(NAME, "Begin saving wav file to " + audioOutputPath);
109
- try {
110
- AudioUtils.saveWavFile(AudioUtils.concatShortBuffers(shortBufferSlices), audioOutputPath);
111
- } catch (IOException e) {
112
- Log.e(NAME, "Error saving wav file: " + e.getMessage());
113
- }
114
- }
115
-
85
+ private void finishRealtimeTranscribe(WritableMap result) {
116
86
  emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
87
+ finishRealtimeTranscribeJob(jobId, context, sliceNSamples.stream().mapToInt(i -> i).toArray());
117
88
  }
118
89
 
119
90
  public int startRealtimeTranscribe(int jobId, ReadableMap options) {
@@ -135,16 +106,12 @@ public class WhisperContext {
135
106
 
136
107
  int realtimeAudioSec = options.hasKey("realtimeAudioSec") ? options.getInt("realtimeAudioSec") : 0;
137
108
  final int audioSec = realtimeAudioSec > 0 ? realtimeAudioSec : DEFAULT_MAX_AUDIO_SEC;
138
-
139
109
  int realtimeAudioSliceSec = options.hasKey("realtimeAudioSliceSec") ? options.getInt("realtimeAudioSliceSec") : 0;
140
110
  final int audioSliceSec = realtimeAudioSliceSec > 0 && realtimeAudioSliceSec < audioSec ? realtimeAudioSliceSec : audioSec;
141
-
142
111
  isUseSlices = audioSliceSec < audioSec;
143
112
 
144
- String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null;
113
+ createRealtimeTranscribeJob(jobId, context, options);
145
114
 
146
- shortBufferSlices = new ArrayList<short[]>();
147
- shortBufferSlices.add(new short[audioSliceSec * SAMPLE_RATE]);
148
115
  sliceNSamples = new ArrayList<Integer>();
149
116
  sliceNSamples.add(0);
150
117
 
@@ -175,37 +142,29 @@ public class WhisperContext {
175
142
  nSamples == nSamplesTranscribing &&
176
143
  sliceIndex == transcribeSliceIndex
177
144
  ) {
178
- finishRealtimeTranscribe(options, Arguments.createMap());
145
+ finishRealtimeTranscribe(Arguments.createMap());
179
146
  } else if (!isTranscribing) {
180
- short[] shortBuffer = shortBufferSlices.get(sliceIndex);
181
- boolean isSpeech = vad(options, shortBuffer, nSamples, 0);
182
- if (!isSpeech) {
183
- finishRealtimeTranscribe(options, Arguments.createMap());
147
+ if (!vad(sliceIndex, nSamples, 0)) {
148
+ finishRealtimeTranscribe(Arguments.createMap());
184
149
  break;
185
150
  }
186
151
  isTranscribing = true;
187
- fullTranscribeSamples(options, true);
152
+ fullTranscribeSamples(true);
188
153
  }
189
154
  break;
190
155
  }
191
156
 
192
157
  // Append to buffer
193
- short[] shortBuffer = shortBufferSlices.get(sliceIndex);
194
158
  if (nSamples + n > audioSliceSec * SAMPLE_RATE) {
195
159
  Log.d(NAME, "next slice");
196
160
 
197
161
  sliceIndex++;
198
162
  nSamples = 0;
199
- shortBuffer = new short[audioSliceSec * SAMPLE_RATE];
200
- shortBufferSlices.add(shortBuffer);
201
163
  sliceNSamples.add(0);
202
164
  }
165
+ putPcmData(jobId, buffer, sliceIndex, nSamples, n);
203
166
 
204
- for (int i = 0; i < n; i++) {
205
- shortBuffer[nSamples + i] = buffer[i];
206
- }
207
-
208
- boolean isSpeech = vad(options, shortBuffer, nSamples, n);
167
+ boolean isSpeech = vad(sliceIndex, nSamples, n);
209
168
 
210
169
  nSamples += n;
211
170
  sliceNSamples.set(sliceIndex, nSamples);
@@ -217,7 +176,7 @@ public class WhisperContext {
217
176
  fullHandler = new Thread(new Runnable() {
218
177
  @Override
219
178
  public void run() {
220
- fullTranscribeSamples(options, false);
179
+ fullTranscribeSamples(false);
221
180
  }
222
181
  });
223
182
  fullHandler.start();
@@ -228,7 +187,7 @@ public class WhisperContext {
228
187
  }
229
188
 
230
189
  if (!isTranscribing) {
231
- finishRealtimeTranscribe(options, Arguments.createMap());
190
+ finishRealtimeTranscribe(Arguments.createMap());
232
191
  }
233
192
  if (fullHandler != null) {
234
193
  fullHandler.join(); // Wait for full transcribe to finish
@@ -246,26 +205,16 @@ public class WhisperContext {
246
205
  return state;
247
206
  }
248
207
 
249
- private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingCheck) {
208
+ private void fullTranscribeSamples(boolean skipCapturingCheck) {
250
209
  int nSamplesOfIndex = sliceNSamples.get(transcribeSliceIndex);
251
210
 
252
211
  if (!isCapturing && !skipCapturingCheck) return;
253
212
 
254
- short[] shortBuffer = shortBufferSlices.get(transcribeSliceIndex);
255
- int nSamples = sliceNSamples.get(transcribeSliceIndex);
256
-
257
213
  nSamplesTranscribing = nSamplesOfIndex;
258
-
259
- // convert I16 to F32
260
- float[] nSamplesBuffer32 = new float[nSamplesTranscribing];
261
- for (int i = 0; i < nSamplesTranscribing; i++) {
262
- nSamplesBuffer32[i] = shortBuffer[i] / 32768.0f;
263
- }
264
-
265
214
  Log.d(NAME, "Start transcribing realtime: " + nSamplesTranscribing);
266
215
 
267
216
  int timeStart = (int) System.currentTimeMillis();
268
- int code = full(jobId, options, nSamplesBuffer32, nSamplesTranscribing);
217
+ int code = fullWithJob(jobId, context, transcribeSliceIndex, nSamplesTranscribing);
269
218
  int timeEnd = (int) System.currentTimeMillis();
270
219
  int timeRecording = (int) (nSamplesTranscribing / SAMPLE_RATE * 1000);
271
220
 
@@ -278,7 +227,7 @@ public class WhisperContext {
278
227
 
279
228
  if (code == 0) {
280
229
  payload.putMap("data", getTextSegments(0, getTextSegmentCount(context)));
281
- } else {
230
+ } else if (code != -999) { // Not aborted
282
231
  payload.putString("error", "Transcribe failed with code " + code);
283
232
  }
284
233
 
@@ -297,12 +246,12 @@ public class WhisperContext {
297
246
  nSamplesTranscribing = 0;
298
247
  }
299
248
 
300
- boolean continueNeeded = !isCapturing && nSamplesTranscribing != nSamplesOfIndex;
249
+ boolean continueNeeded = !isCapturing && nSamplesTranscribing != nSamplesOfIndex && code != -999;
301
250
 
302
251
  if (isStopped && !continueNeeded) {
303
252
  payload.putBoolean("isCapturing", false);
304
253
  payload.putBoolean("isStoppedByAction", isStoppedByAction);
305
- finishRealtimeTranscribe(options, payload);
254
+ finishRealtimeTranscribe(payload);
306
255
  } else if (code == 0) {
307
256
  payload.putBoolean("isCapturing", true);
308
257
  emitTranscribeEvent("@RNWhisper_onRealtimeTranscribe", payload);
@@ -313,7 +262,7 @@ public class WhisperContext {
313
262
 
314
263
  if (continueNeeded) {
315
264
  // If no more capturing, continue transcribing until all slices are transcribed
316
- fullTranscribeSamples(options, true);
265
+ fullTranscribeSamples(true);
317
266
  } else if (isStopped) {
318
267
  // No next, cleanup
319
268
  rewind();
@@ -383,62 +332,30 @@ public class WhisperContext {
383
332
  this.jobId = jobId;
384
333
  isTranscribing = true;
385
334
  float[] audioData = AudioUtils.decodeWaveFile(inputStream);
386
- int code = full(jobId, options, audioData, audioData.length);
387
- isTranscribing = false;
388
- this.jobId = -1;
389
- if (code != 0) {
390
- throw new Exception("Failed to transcribe the file. Code: " + code);
391
- }
392
- WritableMap result = getTextSegments(0, getTextSegmentCount(context));
393
- result.putBoolean("isAborted", isStoppedByAction);
394
- return result;
395
- }
396
335
 
397
- private int full(int jobId, ReadableMap options, float[] audioData, int audioDataLen) {
398
336
  boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress");
399
337
  boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments");
400
- return fullTranscribe(
338
+ int code = fullWithNewJob(
401
339
  jobId,
402
340
  context,
403
341
  // float[] audio_data,
404
342
  audioData,
405
343
  // jint audio_data_len,
406
- audioDataLen,
407
- // jint n_threads,
408
- options.hasKey("maxThreads") ? options.getInt("maxThreads") : -1,
409
- // jint max_context,
410
- options.hasKey("maxContext") ? options.getInt("maxContext") : -1,
411
-
412
- // jint word_thold,
413
- options.hasKey("wordThold") ? options.getInt("wordThold") : -1,
414
- // jint max_len,
415
- options.hasKey("maxLen") ? options.getInt("maxLen") : -1,
416
- // jboolean token_timestamps,
417
- options.hasKey("tokenTimestamps") ? options.getBoolean("tokenTimestamps") : false,
418
-
419
- // jint offset,
420
- options.hasKey("offset") ? options.getInt("offset") : -1,
421
- // jint duration,
422
- options.hasKey("duration") ? options.getInt("duration") : -1,
423
- // jfloat temperature,
424
- options.hasKey("temperature") ? (float) options.getDouble("temperature") : -1.0f,
425
- // jfloat temperature_inc,
426
- options.hasKey("temperatureInc") ? (float) options.getDouble("temperatureInc") : -1.0f,
427
- // jint beam_size,
428
- options.hasKey("beamSize") ? options.getInt("beamSize") : -1,
429
- // jint best_of,
430
- options.hasKey("bestOf") ? options.getInt("bestOf") : -1,
431
- // jboolean speed_up,
432
- options.hasKey("speedUp") ? options.getBoolean("speedUp") : false,
433
- // jboolean translate,
434
- options.hasKey("translate") ? options.getBoolean("translate") : false,
435
- // jstring language,
436
- options.hasKey("language") ? options.getString("language") : "auto",
437
- // jstring prompt
438
- options.hasKey("prompt") ? options.getString("prompt") : null,
344
+ audioData.length,
345
+ // ReadableMap options,
346
+ options,
439
347
  // Callback callback
440
348
  hasProgressCallback || hasNewSegmentsCallback ? new Callback(this, hasProgressCallback, hasNewSegmentsCallback) : null
441
349
  );
350
+
351
+ isTranscribing = false;
352
+ this.jobId = -1;
353
+ if (code != 0 && code != 999) {
354
+ throw new Exception("Failed to transcribe the file. Code: " + code);
355
+ }
356
+ WritableMap result = getTextSegments(0, getTextSegmentCount(context));
357
+ result.putBoolean("isAborted", isStoppedByAction);
358
+ return result;
442
359
  }
443
360
 
444
361
  private WritableMap getTextSegments(int start, int count) {
@@ -557,31 +474,18 @@ public class WhisperContext {
557
474
  }
558
475
  }
559
476
 
560
-
477
+ // JNI methods
561
478
  protected static native long initContext(String modelPath);
562
479
  protected static native long initContextWithAsset(AssetManager assetManager, String modelPath);
563
480
  protected static native long initContextWithInputStream(PushbackInputStream inputStream);
564
- protected static native boolean vadSimple(float[] audio_data, int audio_data_len, float vad_thold, float vad_freq_thold);
565
- protected static native int fullTranscribe(
481
+ protected static native void freeContext(long contextPtr);
482
+
483
+ protected static native int fullWithNewJob(
566
484
  int job_id,
567
485
  long context,
568
486
  float[] audio_data,
569
487
  int audio_data_len,
570
- int n_threads,
571
- int max_context,
572
- int word_thold,
573
- int max_len,
574
- boolean token_timestamps,
575
- int offset,
576
- int duration,
577
- float temperature,
578
- float temperature_inc,
579
- int beam_size,
580
- int best_of,
581
- boolean speed_up,
582
- boolean translate,
583
- String language,
584
- String prompt,
488
+ ReadableMap options,
585
489
  Callback Callback
586
490
  );
587
491
  protected static native void abortTranscribe(int jobId);
@@ -590,5 +494,19 @@ public class WhisperContext {
590
494
  protected static native String getTextSegment(long context, int index);
591
495
  protected static native int getTextSegmentT0(long context, int index);
592
496
  protected static native int getTextSegmentT1(long context, int index);
593
- protected static native void freeContext(long contextPtr);
497
+
498
+ protected static native void createRealtimeTranscribeJob(
499
+ int job_id,
500
+ long context,
501
+ ReadableMap options
502
+ );
503
+ protected static native void finishRealtimeTranscribeJob(int job_id, long context, int[] sliceNSamples);
504
+ protected static native boolean vadSimple(int job_id, int slice_index, int n_samples, int n);
505
+ protected static native void putPcmData(int job_id, short[] buffer, int slice_index, int n_samples, int n);
506
+ protected static native int fullWithJob(
507
+ int job_id,
508
+ long context,
509
+ int slice_index,
510
+ int n_samples
511
+ );
594
512
  }
@@ -0,0 +1,76 @@
1
+ #include <jni.h>
2
+
3
+ // ReadableMap utils
4
+
5
+ namespace readablemap {
6
+
7
+ bool hasKey(JNIEnv *env, jobject readableMap, const char *key) {
8
+ jclass mapClass = env->GetObjectClass(readableMap);
9
+ jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z");
10
+ jstring jKey = env->NewStringUTF(key);
11
+ jboolean result = env->CallBooleanMethod(readableMap, hasKeyMethod, jKey);
12
+ env->DeleteLocalRef(jKey);
13
+ return result;
14
+ }
15
+
16
+ int getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) {
17
+ if (!hasKey(env, readableMap, key)) {
18
+ return defaultValue;
19
+ }
20
+ jclass mapClass = env->GetObjectClass(readableMap);
21
+ jmethodID getIntMethod = env->GetMethodID(mapClass, "getInt", "(Ljava/lang/String;)I");
22
+ jstring jKey = env->NewStringUTF(key);
23
+ jint result = env->CallIntMethod(readableMap, getIntMethod, jKey);
24
+ env->DeleteLocalRef(jKey);
25
+ return result;
26
+ }
27
+
28
+ bool getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) {
29
+ if (!hasKey(env, readableMap, key)) {
30
+ return defaultValue;
31
+ }
32
+ jclass mapClass = env->GetObjectClass(readableMap);
33
+ jmethodID getBoolMethod = env->GetMethodID(mapClass, "getBoolean", "(Ljava/lang/String;)Z");
34
+ jstring jKey = env->NewStringUTF(key);
35
+ jboolean result = env->CallBooleanMethod(readableMap, getBoolMethod, jKey);
36
+ env->DeleteLocalRef(jKey);
37
+ return result;
38
+ }
39
+
40
+ long getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) {
41
+ if (!hasKey(env, readableMap, key)) {
42
+ return defaultValue;
43
+ }
44
+ jclass mapClass = env->GetObjectClass(readableMap);
45
+ jmethodID getLongMethod = env->GetMethodID(mapClass, "getLong", "(Ljava/lang/String;)J");
46
+ jstring jKey = env->NewStringUTF(key);
47
+ jlong result = env->CallLongMethod(readableMap, getLongMethod, jKey);
48
+ env->DeleteLocalRef(jKey);
49
+ return result;
50
+ }
51
+
52
+ float getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) {
53
+ if (!hasKey(env, readableMap, key)) {
54
+ return defaultValue;
55
+ }
56
+ jclass mapClass = env->GetObjectClass(readableMap);
57
+ jmethodID getFloatMethod = env->GetMethodID(mapClass, "getDouble", "(Ljava/lang/String;)D");
58
+ jstring jKey = env->NewStringUTF(key);
59
+ jfloat result = env->CallDoubleMethod(readableMap, getFloatMethod, jKey);
60
+ env->DeleteLocalRef(jKey);
61
+ return result;
62
+ }
63
+
64
+ jstring getString(JNIEnv *env, jobject readableMap, const char *key, jstring defaultValue) {
65
+ if (!hasKey(env, readableMap, key)) {
66
+ return defaultValue;
67
+ }
68
+ jclass mapClass = env->GetObjectClass(readableMap);
69
+ jmethodID getStringMethod = env->GetMethodID(mapClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
70
+ jstring jKey = env->NewStringUTF(key);
71
+ jstring result = (jstring) env->CallObjectMethod(readableMap, getStringMethod, jKey);
72
+ env->DeleteLocalRef(jKey);
73
+ return result;
74
+ }
75
+
76
+ }