whisper.rn 0.4.0-rc.7 → 0.4.0-rc.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. package/android/src/main/CMakeLists.txt +2 -1
  2. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
  3. package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
  4. package/android/src/main/java/com/rnwhisper/WhisperContext.java +20 -3
  5. package/android/src/main/jni.cpp +29 -1
  6. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  7. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  8. package/cpp/coreml/whisper-encoder.mm +1 -1
  9. package/cpp/ggml-aarch64.c +3209 -0
  10. package/cpp/ggml-aarch64.h +39 -0
  11. package/cpp/ggml-alloc.c +732 -494
  12. package/cpp/ggml-alloc.h +47 -63
  13. package/cpp/ggml-backend-impl.h +162 -47
  14. package/cpp/ggml-backend.cpp +2635 -0
  15. package/cpp/ggml-backend.h +216 -71
  16. package/cpp/ggml-common.h +1853 -0
  17. package/cpp/ggml-cpu-impl.h +614 -0
  18. package/cpp/ggml-impl.h +144 -178
  19. package/cpp/ggml-metal.h +14 -60
  20. package/cpp/ggml-metal.m +3437 -2097
  21. package/cpp/ggml-quants.c +12559 -4189
  22. package/cpp/ggml-quants.h +135 -212
  23. package/cpp/ggml-whisper.metallib +0 -0
  24. package/cpp/ggml.c +9029 -5219
  25. package/cpp/ggml.h +673 -338
  26. package/cpp/rn-whisper.cpp +91 -0
  27. package/cpp/rn-whisper.h +2 -0
  28. package/cpp/whisper.cpp +1476 -675
  29. package/cpp/whisper.h +84 -28
  30. package/ios/RNWhisper.mm +124 -37
  31. package/ios/RNWhisperAudioUtils.h +1 -0
  32. package/ios/RNWhisperAudioUtils.m +20 -13
  33. package/ios/RNWhisperContext.h +3 -2
  34. package/ios/RNWhisperContext.mm +41 -8
  35. package/jest/mock.js +9 -1
  36. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  37. package/lib/commonjs/index.js +48 -19
  38. package/lib/commonjs/index.js.map +1 -1
  39. package/lib/commonjs/version.json +1 -1
  40. package/lib/module/NativeRNWhisper.js.map +1 -1
  41. package/lib/module/index.js +48 -19
  42. package/lib/module/index.js.map +1 -1
  43. package/lib/module/version.json +1 -1
  44. package/lib/typescript/NativeRNWhisper.d.ts +6 -3
  45. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  46. package/lib/typescript/index.d.ts +25 -3
  47. package/lib/typescript/index.d.ts.map +1 -1
  48. package/package.json +6 -5
  49. package/src/NativeRNWhisper.ts +12 -3
  50. package/src/index.ts +63 -24
  51. package/src/version.json +1 -1
  52. package/whisper-rn.podspec +9 -2
  53. package/cpp/ggml-backend.c +0 -1357
  54. package/cpp/ggml-metal-whisper.metal +0 -4908
@@ -9,8 +9,9 @@ set(
9
9
  SOURCE_FILES
10
10
  ${RNWHISPER_LIB_DIR}/ggml.c
11
11
  ${RNWHISPER_LIB_DIR}/ggml-alloc.c
12
- ${RNWHISPER_LIB_DIR}/ggml-backend.c
12
+ ${RNWHISPER_LIB_DIR}/ggml-backend.cpp
13
13
  ${RNWHISPER_LIB_DIR}/ggml-quants.c
14
+ ${RNWHISPER_LIB_DIR}/ggml-aarch64.c
14
15
  ${RNWHISPER_LIB_DIR}/whisper.cpp
15
16
  ${RNWHISPER_LIB_DIR}/rn-audioutils.cpp
16
17
  ${RNWHISPER_LIB_DIR}/rn-whisper.cpp
@@ -2,8 +2,6 @@ package com.rnwhisper;
2
2
 
3
3
  import android.util.Log;
4
4
 
5
- import java.io.IOException;
6
- import java.io.FileReader;
7
5
  import java.io.ByteArrayOutputStream;
8
6
  import java.io.File;
9
7
  import java.io.IOException;
@@ -11,23 +9,22 @@ import java.io.InputStream;
11
9
  import java.nio.ByteBuffer;
12
10
  import java.nio.ByteOrder;
13
11
  import java.nio.ShortBuffer;
12
+ import java.util.Base64;
13
+
14
+ import java.util.Arrays;
14
15
 
15
16
  public class AudioUtils {
16
17
  private static final String NAME = "RNWhisperAudioUtils";
17
18
 
18
- public static float[] decodeWaveFile(InputStream inputStream) throws IOException {
19
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
20
- byte[] buffer = new byte[1024];
21
- int bytesRead;
22
- while ((bytesRead = inputStream.read(buffer)) != -1) {
23
- baos.write(buffer, 0, bytesRead);
24
- }
25
- ByteBuffer byteBuffer = ByteBuffer.wrap(baos.toByteArray());
19
+ private static float[] bufferToFloatArray(byte[] buffer, Boolean cutHeader) {
20
+ ByteBuffer byteBuffer = ByteBuffer.wrap(buffer);
26
21
  byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
27
- byteBuffer.position(44);
28
22
  ShortBuffer shortBuffer = byteBuffer.asShortBuffer();
29
23
  short[] shortArray = new short[shortBuffer.limit()];
30
24
  shortBuffer.get(shortArray);
25
+ if (cutHeader) {
26
+ shortArray = Arrays.copyOfRange(shortArray, 44, shortArray.length);
27
+ }
31
28
  float[] floatArray = new float[shortArray.length];
32
29
  for (int i = 0; i < shortArray.length; i++) {
33
30
  floatArray[i] = ((float) shortArray[i]) / 32767.0f;
@@ -36,4 +33,22 @@ public class AudioUtils {
36
33
  }
37
34
  return floatArray;
38
35
  }
39
- }
36
+
37
+ public static float[] decodeWaveFile(InputStream inputStream) throws IOException {
38
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
39
+ byte[] buffer = new byte[1024];
40
+ int bytesRead;
41
+ while ((bytesRead = inputStream.read(buffer)) != -1) {
42
+ baos.write(buffer, 0, bytesRead);
43
+ }
44
+ return bufferToFloatArray(baos.toByteArray(), true);
45
+ }
46
+
47
+ public static float[] decodeWaveData(String dataBase64) throws IOException {
48
+ return bufferToFloatArray(Base64.getDecoder().decode(dataBase64), true);
49
+ }
50
+
51
+ public static float[] decodePcmData(String dataBase64) {
52
+ return bufferToFloatArray(Base64.getDecoder().decode(dataBase64), false);
53
+ }
54
+ }
@@ -19,6 +19,7 @@ import java.util.HashMap;
19
19
  import java.util.Random;
20
20
  import java.io.File;
21
21
  import java.io.FileInputStream;
22
+ import java.io.InputStream;
22
23
  import java.io.PushbackInputStream;
23
24
 
24
25
  public class RNWhisper implements LifecycleEventListener {
@@ -119,44 +120,16 @@ public class RNWhisper implements LifecycleEventListener {
119
120
  tasks.put(task, "initContext");
120
121
  }
121
122
 
122
- public void transcribeFile(double id, double jobId, String filePath, ReadableMap options, Promise promise) {
123
- final WhisperContext context = contexts.get((int) id);
124
- if (context == null) {
125
- promise.reject("Context not found");
126
- return;
127
- }
128
- if (context.isCapturing()) {
129
- promise.reject("The context is in realtime transcribe mode");
130
- return;
131
- }
132
- if (context.isTranscribing()) {
133
- promise.reject("Context is already transcribing");
134
- return;
135
- }
123
+ private AsyncTask transcribe(WhisperContext context, double jobId, final float[] audioData, final ReadableMap options, Promise promise) {
136
124
  AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
137
125
  private Exception exception;
138
126
 
139
127
  @Override
140
128
  protected WritableMap doInBackground(Void... voids) {
141
129
  try {
142
- String waveFilePath = filePath;
143
-
144
- if (filePath.startsWith("http://") || filePath.startsWith("https://")) {
145
- waveFilePath = downloader.downloadFile(filePath);
146
- }
147
-
148
- int resId = getResourceIdentifier(waveFilePath);
149
- if (resId > 0) {
150
- return context.transcribeInputStream(
151
- (int) jobId,
152
- reactContext.getResources().openRawResource(resId),
153
- options
154
- );
155
- }
156
-
157
- return context.transcribeInputStream(
130
+ return context.transcribe(
158
131
  (int) jobId,
159
- new FileInputStream(new File(waveFilePath)),
132
+ audioData,
160
133
  options
161
134
  );
162
135
  } catch (Exception e) {
@@ -175,7 +148,66 @@ public class RNWhisper implements LifecycleEventListener {
175
148
  tasks.remove(this);
176
149
  }
177
150
  }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
178
- tasks.put(task, "transcribeFile-" + id);
151
+ return task;
152
+ }
153
+
154
+ public void transcribeFile(double id, double jobId, String filePathOrBase64, ReadableMap options, Promise promise) {
155
+ final WhisperContext context = contexts.get((int) id);
156
+ if (context == null) {
157
+ promise.reject("Context not found");
158
+ return;
159
+ }
160
+ if (context.isCapturing()) {
161
+ promise.reject("The context is in realtime transcribe mode");
162
+ return;
163
+ }
164
+ if (context.isTranscribing()) {
165
+ promise.reject("Context is already transcribing");
166
+ return;
167
+ }
168
+
169
+ String waveFilePath = filePathOrBase64;
170
+ try {
171
+ if (filePathOrBase64.startsWith("http://") || filePathOrBase64.startsWith("https://")) {
172
+ waveFilePath = downloader.downloadFile(filePathOrBase64);
173
+ }
174
+
175
+ float[] audioData;
176
+ int resId = getResourceIdentifier(waveFilePath);
177
+ if (resId > 0) {
178
+ audioData = AudioUtils.decodeWaveFile(reactContext.getResources().openRawResource(resId));
179
+ } else if (filePathOrBase64.startsWith("data:audio/wav;base64,")) {
180
+ audioData = AudioUtils.decodeWaveData(filePathOrBase64);
181
+ } else {
182
+ audioData = AudioUtils.decodeWaveFile(new FileInputStream(new File(waveFilePath)));
183
+ }
184
+
185
+ AsyncTask task = transcribe(context, jobId, audioData, options, promise);
186
+ tasks.put(task, "transcribeFile-" + id);
187
+ } catch (Exception e) {
188
+ promise.reject(e);
189
+ }
190
+ }
191
+
192
+ public void transcribeData(double id, double jobId, String dataBase64, ReadableMap options, Promise promise) {
193
+ final WhisperContext context = contexts.get((int) id);
194
+ if (context == null) {
195
+ promise.reject("Context not found");
196
+ return;
197
+ }
198
+ if (context.isCapturing()) {
199
+ promise.reject("The context is in realtime transcribe mode");
200
+ return;
201
+ }
202
+ if (context.isTranscribing()) {
203
+ promise.reject("Context is already transcribing");
204
+ return;
205
+ }
206
+
207
+ float[] audioData = AudioUtils.decodePcmData(dataBase64);
208
+ AsyncTask task = transcribe(context, jobId, audioData, options, promise);
209
+
210
+ tasks.put(task, "transcribeData-" + id);
179
211
  }
180
212
 
181
213
  public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) {
@@ -211,7 +243,7 @@ public class RNWhisper implements LifecycleEventListener {
211
243
  context.stopTranscribe((int) jobId);
212
244
  AsyncTask completionTask = null;
213
245
  for (AsyncTask task : tasks.keySet()) {
214
- if (tasks.get(task).equals("transcribeFile-" + id)) {
246
+ if (tasks.get(task).equals("transcribeFile-" + id) || tasks.get(task).equals("transcribeData-" + id)) {
215
247
  task.get();
216
248
  break;
217
249
  }
@@ -235,6 +267,15 @@ public class RNWhisper implements LifecycleEventListener {
235
267
  tasks.put(task, "abortTranscribe-" + id);
236
268
  }
237
269
 
270
+ public void bench(double id, double nThreads, Promise promise) {
271
+ final WhisperContext context = contexts.get((int) id);
272
+ if (context == null) {
273
+ promise.reject("Context not found");
274
+ return;
275
+ }
276
+ promise.resolve(context.bench((int) nThreads));
277
+ }
278
+
238
279
  public void releaseContext(double id, Promise promise) {
239
280
  final int contextId = (int) id;
240
281
  AsyncTask task = new AsyncTask<Void, Void, Void>() {
@@ -250,7 +291,7 @@ public class RNWhisper implements LifecycleEventListener {
250
291
  context.stopCurrentTranscribe();
251
292
  AsyncTask completionTask = null;
252
293
  for (AsyncTask task : tasks.keySet()) {
253
- if (tasks.get(task).equals("transcribeFile-" + contextId)) {
294
+ if (tasks.get(task).equals("transcribeFile-" + contextId) || tasks.get(task).equals("transcribeData-" + contextId)) {
254
295
  task.get();
255
296
  break;
256
297
  }
@@ -53,6 +53,7 @@ public class WhisperContext {
53
53
  private boolean isCapturing = false;
54
54
  private boolean isStoppedByAction = false;
55
55
  private boolean isTranscribing = false;
56
+ private boolean isTdrzEnable = false;
56
57
  private Thread rootFullHandler = null;
57
58
  private Thread fullHandler = null;
58
59
 
@@ -73,6 +74,7 @@ public class WhisperContext {
73
74
  isCapturing = false;
74
75
  isStoppedByAction = false;
75
76
  isTranscribing = false;
77
+ isTdrzEnable = false;
76
78
  rootFullHandler = null;
77
79
  fullHandler = null;
78
80
  }
@@ -113,6 +115,8 @@ public class WhisperContext {
113
115
  double realtimeAudioMinSec = options.hasKey("realtimeAudioMinSec") ? options.getDouble("realtimeAudioMinSec") : 0;
114
116
  final double audioMinSec = realtimeAudioMinSec > 0.5 && realtimeAudioMinSec <= audioSliceSec ? realtimeAudioMinSec : 1;
115
117
 
118
+ this.isTdrzEnable = options.hasKey("tdrzEnable") && options.getBoolean("tdrzEnable");
119
+
116
120
  createRealtimeTranscribeJob(jobId, context, options);
117
121
 
118
122
  sliceNSamples = new ArrayList<Integer>();
@@ -328,15 +332,15 @@ public class WhisperContext {
328
332
  }
329
333
  }
330
334
 
331
- public WritableMap transcribeInputStream(int jobId, InputStream inputStream, ReadableMap options) throws IOException, Exception {
335
+ public WritableMap transcribe(int jobId, float[] audioData, ReadableMap options) throws IOException, Exception {
332
336
  if (isCapturing || isTranscribing) {
333
337
  throw new Exception("Context is already in capturing or transcribing");
334
338
  }
335
339
  rewind();
336
-
337
340
  this.jobId = jobId;
341
+ this.isTdrzEnable = options.hasKey("tdrzEnable") && options.getBoolean("tdrzEnable");
342
+
338
343
  isTranscribing = true;
339
- float[] audioData = AudioUtils.decodeWaveFile(inputStream);
340
344
 
341
345
  boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress");
342
346
  boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments");
@@ -368,8 +372,15 @@ public class WhisperContext {
368
372
 
369
373
  WritableMap data = Arguments.createMap();
370
374
  WritableArray segments = Arguments.createArray();
375
+
371
376
  for (int i = 0; i < count; i++) {
372
377
  String text = getTextSegment(context, i);
378
+
379
+ // If tdrzEnable is enabled and speaker turn is detected
380
+ if (this.isTdrzEnable && getTextSegmentSpeakerTurnNext(context, i)) {
381
+ text += " [SPEAKER_TURN]";
382
+ }
383
+
373
384
  builder.append(text);
374
385
 
375
386
  WritableMap segment = Arguments.createMap();
@@ -411,6 +422,10 @@ public class WhisperContext {
411
422
  stopTranscribe(this.jobId);
412
423
  }
413
424
 
425
+ public String bench(int n_threads) {
426
+ return bench(context, n_threads);
427
+ }
428
+
414
429
  public void release() {
415
430
  stopCurrentTranscribe();
416
431
  freeContext(context);
@@ -499,6 +514,7 @@ public class WhisperContext {
499
514
  protected static native String getTextSegment(long context, int index);
500
515
  protected static native int getTextSegmentT0(long context, int index);
501
516
  protected static native int getTextSegmentT1(long context, int index);
517
+ protected static native boolean getTextSegmentSpeakerTurnNext(long context, int index);
502
518
 
503
519
  protected static native void createRealtimeTranscribeJob(
504
520
  int job_id,
@@ -514,4 +530,5 @@ public class WhisperContext {
514
530
  int slice_index,
515
531
  int n_samples
516
532
  );
533
+ protected static native String bench(long context, int n_threads);
517
534
  }
@@ -155,6 +155,8 @@ Java_com_rnwhisper_WhisperContext_initContext(
155
155
  JNIEnv *env, jobject thiz, jstring model_path_str) {
156
156
  UNUSED(thiz);
157
157
  struct whisper_context_params cparams;
158
+ cparams.dtw_token_timestamps = false;
159
+
158
160
  struct whisper_context *context = nullptr;
159
161
  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
160
162
  context = whisper_init_from_file_with_params(model_path_chars, cparams);
@@ -171,6 +173,8 @@ Java_com_rnwhisper_WhisperContext_initContextWithAsset(
171
173
  ) {
172
174
  UNUSED(thiz);
173
175
  struct whisper_context_params cparams;
176
+ cparams.dtw_token_timestamps = false;
177
+
174
178
  struct whisper_context *context = nullptr;
175
179
  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
176
180
  context = whisper_init_from_asset(env, asset_manager, model_path_chars, cparams);
@@ -186,6 +190,8 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream(
186
190
  ) {
187
191
  UNUSED(thiz);
188
192
  struct whisper_context_params cparams;
193
+ cparams.dtw_token_timestamps = false;
194
+
189
195
  struct whisper_context *context = nullptr;
190
196
  context = whisper_init_from_input_stream(env, input_stream, cparams);
191
197
  return reinterpret_cast<jlong>(context);
@@ -206,8 +212,8 @@ struct whisper_full_params createFullParams(JNIEnv *env, jobject options) {
206
212
  int n_threads = readablemap::getInt(env, options, "maxThreads", default_n_threads);
207
213
  params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
208
214
  params.translate = readablemap::getBool(env, options, "translate", false);
209
- params.speed_up = readablemap::getBool(env, options, "speedUp", false);
210
215
  params.token_timestamps = readablemap::getBool(env, options, "tokenTimestamps", false);
216
+ params.tdrz_enable = readablemap::getBool(env, options, "tdrzEnable", false);
211
217
  params.offset_ms = 0;
212
218
  params.no_context = true;
213
219
  params.single_segment = false;
@@ -493,4 +499,26 @@ Java_com_rnwhisper_WhisperContext_freeContext(
493
499
  whisper_free(context);
494
500
  }
495
501
 
502
+ JNIEXPORT jboolean JNICALL
503
+ Java_com_rnwhisper_WhisperContext_getTextSegmentSpeakerTurnNext(
504
+ JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
505
+ UNUSED(env);
506
+ UNUSED(thiz);
507
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
508
+ return whisper_full_get_segment_speaker_turn_next(context, index);
509
+ }
510
+
511
+ JNIEXPORT jstring JNICALL
512
+ Java_com_rnwhisper_WhisperContext_bench(
513
+ JNIEnv *env,
514
+ jobject thiz,
515
+ jlong context_ptr,
516
+ jint n_threads
517
+ ) {
518
+ UNUSED(thiz);
519
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
520
+ std::string result = rnwhisper::bench(context, n_threads);
521
+ return env->NewStringUTF(result.c_str());
522
+ }
523
+
496
524
  } // extern "C"
@@ -47,6 +47,11 @@ public class RNWhisperModule extends NativeRNWhisperSpec {
47
47
  rnwhisper.transcribeFile(id, jobId, filePath, options, promise);
48
48
  }
49
49
 
50
+ @ReactMethod
51
+ public void transcribeData(double id, double jobId, String dataBase64, ReadableMap options, Promise promise) {
52
+ rnwhisper.transcribeData(id, jobId, dataBase64, options, promise);
53
+ }
54
+
50
55
  @ReactMethod
51
56
  public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) {
52
57
  rnwhisper.startRealtimeTranscribe(id, jobId, options, promise);
@@ -57,6 +62,11 @@ public class RNWhisperModule extends NativeRNWhisperSpec {
57
62
  rnwhisper.abortTranscribe(contextId, jobId, promise);
58
63
  }
59
64
 
65
+ @ReactMethod
66
+ public void bench(double id, double nThreads, Promise promise) {
67
+ rnwhisper.bench(id, nThreads, promise);
68
+ }
69
+
60
70
  @ReactMethod
61
71
  public void releaseContext(double id, Promise promise) {
62
72
  rnwhisper.releaseContext(id, promise);
@@ -47,6 +47,11 @@ public class RNWhisperModule extends ReactContextBaseJavaModule {
47
47
  rnwhisper.transcribeFile(id, jobId, filePath, options, promise);
48
48
  }
49
49
 
50
+ @ReactMethod
51
+ public void transcribeData(double id, double jobId, String dataBase64, ReadableMap options, Promise promise) {
52
+ rnwhisper.transcribeData(id, jobId, dataBase64, options, promise);
53
+ }
54
+
50
55
  @ReactMethod
51
56
  public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) {
52
57
  rnwhisper.startRealtimeTranscribe(id, jobId, options, promise);
@@ -57,6 +62,11 @@ public class RNWhisperModule extends ReactContextBaseJavaModule {
57
62
  rnwhisper.abortTranscribe(contextId, jobId, promise);
58
63
  }
59
64
 
65
+ @ReactMethod
66
+ public void bench(double id, double nThreads, Promise promise) {
67
+ rnwhisper.bench(id, nThreads, promise);
68
+ }
69
+
60
70
  @ReactMethod
61
71
  public void releaseContext(double id, Promise promise) {
62
72
  rnwhisper.releaseContext(id, promise);
@@ -24,7 +24,7 @@ struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
24
24
 
25
25
  // select which device to run the Core ML model on
26
26
  MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
27
- //config.computeUnits = MLComputeUnitsCPUAndGPU;
27
+ // config.computeUnits = MLComputeUnitsCPUAndGPU;
28
28
  //config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;
29
29
  config.computeUnits = MLComputeUnitsAll;
30
30