whisper.rn 0.3.0-rc.5 → 0.3.0-rc.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -8,9 +8,12 @@ React Native binding of [whisper.cpp](https://github.com/ggerganov/whisper.cpp).
8
8
 
9
9
  [whisper.cpp](https://github.com/ggerganov/whisper.cpp): High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model
10
10
 
11
- <img src="https://user-images.githubusercontent.com/3001525/225511664-8b2ba3ec-864d-4f55-bcb0-447aef168a32.jpeg" width="500" />
11
+ ## Screenshots
12
12
 
13
- > Run example with release mode on iPhone 13 Pro Max
13
+ | <img src="https://github.com/mybigday/whisper.rn/assets/3001525/2fea7b2d-c911-44fb-9afc-8efc7b594446" width="300" /> | <img src="https://github.com/mybigday/whisper.rn/assets/3001525/a5005a6c-44f7-4db9-95e8-0fd951a2e147" width="300" /> |
14
+ | :------------------------------------------: | :------------------------------------------: |
15
+ | iOS: Tested on iPhone 13 Pro Max | Android: Tested on Pixel 6 |
16
+ | (tiny.en, Core ML enabled) | (tiny.en, armv8.2-a+fp16) |
14
17
 
15
18
  ## Installation
16
19
 
@@ -48,7 +51,6 @@ import { initWhisper } from 'whisper.rn'
48
51
 
49
52
  const whisperContext = await initWhisper({
50
53
  filePath: 'file://.../ggml-tiny.en.bin',
51
- isBundleAsset: false, // Set to true if you want to load the model from bundle resources, the filePath will be like `ggml-tiny.en.bin`
52
54
  })
53
55
 
54
56
  const sampleFilePath = 'file://.../sample.wav'
@@ -81,6 +83,44 @@ In Android, you may need to request the microphone permission by [`PermissionAnd
81
83
 
82
84
  Please visit the [Documentation](docs/) for more details.
83
85
 
86
+ ## Usage with assets
87
+
88
+ You can also use the model file / audio file from assets:
89
+
90
+ ```js
91
+ import { initWhisper } from 'whisper.rn'
92
+
93
+ const whisperContext = await initWhisper({
94
+ filePath: require('../assets/ggml-tiny.en.bin'),
95
+ })
96
+
97
+ const { stop, promise } =
98
+ whisperContext.transcribe(require('../assets/sample.wav'), options)
99
+
100
+ // ...
101
+ ```
102
+
103
+ This requires editing the `metro.config.js` to support assets:
104
+
105
+ ```js
106
+ // ...
107
+ const defaultAssetExts = require('metro-config/src/defaults/defaults').assetExts
108
+
109
+ module.exports = {
110
+ // ...
111
+ resolver: {
112
+ // ...
113
+ assetExts: [
114
+ ...defaultAssetExts,
115
+ 'bin', // whisper.rn: ggml model binary
116
+ 'mil', // whisper.rn: CoreML model asset
117
+ ]
118
+ },
119
+ }
120
+ ```
121
+
122
+ Please note that it will significantly increase the size of the app in release mode.
123
+
84
124
  ## Core ML support
85
125
 
86
126
  __*Platform: iOS 15.0+, tvOS 15.0+*__
@@ -91,25 +131,44 @@ The `.mlmodelc` model files is load depend on the ggml model file path. For exam
91
131
 
92
132
  Currently there is no official way to get the Core ML models by URL, you will need to convert Core ML models by yourself. Please see [Core ML Support](https://github.com/ggerganov/whisper.cpp#core-ml-support) of whisper.cpp for more details.
93
133
 
94
- During the `.mlmodelc` is a directory, you will need to download 5 files:
134
+ Since the `.mlmodelc` is a directory, you will need to download 5 files (3 required):
95
135
 
96
136
  ```json5
97
137
  [
98
138
  'model.mil',
99
- 'metadata.json',
100
139
  'coremldata.bin',
101
140
  'weights/weight.bin',
102
- 'analytics/coremldata.bin',
141
+ // Not required:
142
+ // 'metadata.json', 'analytics/coremldata.bin',
103
143
  ]
104
144
  ```
105
145
 
106
- Or just add them to your app's bundle resources, like the example app does, but this would increase the app size significantly.
146
+ Or just use `require` to bundle them in your app, as the example app does, but this would increase the app size significantly.
147
+
148
+ ```js
149
+ const whisperContext = await initWhisper({
150
+ filePath: require('../assets/ggml-tiny.en.bin'),
151
+ coreMLModelAsset:
152
+ Platform.OS === 'ios'
153
+ ? {
154
+ filename: 'ggml-tiny.en-encoder.mlmodelc',
155
+ assets: [
156
+ require('../assets/ggml-tiny.en-encoder.mlmodelc/weights/weight.bin'),
157
+ require('../assets/ggml-tiny.en-encoder.mlmodelc/model.mil'),
158
+ require('../assets/ggml-tiny.en-encoder.mlmodelc/coremldata.bin'),
159
+ ],
160
+ }
161
+ : undefined,
162
+ })
163
+ ```
164
+
165
+ In the real world, we recommend splitting the asset imports into another platform-specific file (e.g. `context-opts.ios.js`) to avoid including these unused files in the bundle for Android.
107
166
 
108
167
  ## Run with example
109
168
 
110
- The example app is using [react-native-fs](https://github.com/itinance/react-native-fs) to download the model file and audio file.
169
+ The example app provides a simple UI for testing the functions.
111
170
 
112
- Model: `base.en` in https://huggingface.co/datasets/ggerganov/whisper.cpp
171
+ Used Whisper model: `tiny.en` in https://huggingface.co/datasets/ggerganov/whisper.cpp
113
172
  Sample file: `jfk.wav` in https://github.com/ggerganov/whisper.cpp/tree/master/samples
114
173
 
115
174
  To test transcription performance more accurately, you can run the app in Release mode.
@@ -130,6 +189,10 @@ jest.mock('whisper.rn', () => require('whisper.rn/jest/mock'))
130
189
 
131
190
  See the [contributing guide](CONTRIBUTING.md) to learn how to contribute to the repository and the development workflow.
132
191
 
192
+ ## Troubleshooting
193
+
194
+ See the [troubleshooting guide](TROUBLESHOOTING.md) if you encounter any problems while using `whisper.rn`.
195
+
133
196
  ## License
134
197
 
135
198
  MIT
@@ -0,0 +1,83 @@
1
+ package com.rnwhisper;
2
+
3
+ import android.content.Context;
4
+
5
+ import java.io.BufferedInputStream;
6
+ import java.io.FileOutputStream;
7
+ import java.io.File;
8
+ import java.io.InputStream;
9
+ import java.io.OutputStream;
10
+ import java.net.URL;
11
+ import java.net.URLConnection;
12
+
13
+ /**
14
+ * NOTE: This is a simple downloader,
15
+ * its main purpose is to support loading assets in RN Debug mode,
16
+ * so it's a very crude implementation.
17
+ *
18
+ * If you want to use file download in production to load model / audio files,
19
+ * I would recommend using react-native-fs or expo-file-system to manage the files.
20
+ */
21
+ public class Downloader {
22
+ private static Context context;
23
+
24
+ public Downloader(Context context) {
25
+ this.context = context;
26
+ }
27
+
28
+ private String getDir() {
29
+ String dir = context.getCacheDir().getAbsolutePath() + "/rnwhisper_debug_assets/";
30
+ File file = new File(dir);
31
+ if (!file.exists()) {
32
+ file.mkdirs();
33
+ }
34
+ return dir;
35
+ }
36
+
37
+ private boolean fileExists(String filename) {
38
+ File file = new File(getDir() + filename);
39
+ return file.exists();
40
+ }
41
+
42
+ public String downloadFile(String urlPath) throws Exception {
43
+ String filename = urlPath.substring(urlPath.lastIndexOf('/') + 1);
44
+ if (filename.contains("?")) {
45
+ filename = filename.substring(0, filename.indexOf("?"));
46
+ }
47
+ String filepath = getDir() + filename;
48
+ if (fileExists(filename)) {
49
+ return filepath;
50
+ }
51
+ try {
52
+ URL url = new URL(urlPath);
53
+ URLConnection connection = url.openConnection();
54
+ connection.connect();
55
+ InputStream input = new BufferedInputStream(url.openStream());
56
+ OutputStream output = new FileOutputStream(filepath);
57
+ byte data[] = new byte[1024];
58
+ int count;
59
+ while ((count = input.read(data)) != -1) {
60
+ output.write(data, 0, count);
61
+ }
62
+ output.flush();
63
+ output.close();
64
+ input.close();
65
+ } catch (Exception e) {
66
+ throw e;
67
+ }
68
+ return filepath;
69
+ }
70
+
71
+ private void deleteFile(File fileOrDir) {
72
+ if (fileOrDir.isDirectory()) {
73
+ for (File child : fileOrDir.listFiles()) {
74
+ deleteFile(child);
75
+ }
76
+ }
77
+ fileOrDir.delete();
78
+ }
79
+
80
+ public void clearCache() {
81
+ deleteFile(new File(getDir()));
82
+ }
83
+ }
@@ -26,6 +26,7 @@ import java.io.File;
26
26
  import java.io.FileInputStream;
27
27
  import java.io.IOException;
28
28
  import java.io.InputStream;
29
+ import java.io.PushbackInputStream;
29
30
  import java.nio.ByteBuffer;
30
31
  import java.nio.ByteOrder;
31
32
  import java.nio.ShortBuffer;
@@ -281,10 +282,10 @@ public class WhisperContext {
281
282
  eventEmitter.emit(eventName, event);
282
283
  }
283
284
 
284
- public WritableMap transcribeFile(int jobId, String filePath, ReadableMap options) throws IOException, Exception {
285
+ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, ReadableMap options) throws IOException, Exception {
285
286
  this.jobId = jobId;
286
287
  isTranscribing = true;
287
- float[] audioData = decodeWaveFile(new File(filePath));
288
+ float[] audioData = decodeWaveFile(inputStream);
288
289
  int code = full(jobId, options, audioData, audioData.length);
289
290
  isTranscribing = false;
290
291
  this.jobId = -1;
@@ -383,14 +384,12 @@ public class WhisperContext {
383
384
  freeContext(context);
384
385
  }
385
386
 
386
- public static float[] decodeWaveFile(File file) throws IOException {
387
+ public static float[] decodeWaveFile(InputStream inputStream) throws IOException {
387
388
  ByteArrayOutputStream baos = new ByteArrayOutputStream();
388
- try (InputStream inputStream = new FileInputStream(file)) {
389
- byte[] buffer = new byte[1024];
390
- int bytesRead;
391
- while ((bytesRead = inputStream.read(buffer)) != -1) {
392
- baos.write(buffer, 0, bytesRead);
393
- }
389
+ byte[] buffer = new byte[1024];
390
+ int bytesRead;
391
+ while ((bytesRead = inputStream.read(buffer)) != -1) {
392
+ baos.write(buffer, 0, bytesRead);
394
393
  }
395
394
  ByteBuffer byteBuffer = ByteBuffer.wrap(baos.toByteArray());
396
395
  byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
@@ -472,6 +471,7 @@ public class WhisperContext {
472
471
 
473
472
  protected static native long initContext(String modelPath);
474
473
  protected static native long initContextWithAsset(AssetManager assetManager, String modelPath);
474
+ protected static native long initContextWithInputStream(PushbackInputStream inputStream);
475
475
  protected static native int fullTranscribe(
476
476
  int job_id,
477
477
  long context,
@@ -20,6 +20,97 @@ static inline int min(int a, int b) {
20
20
  return (a < b) ? a : b;
21
21
  }
22
22
 
23
+ // Load model from input stream (used for drawable / raw resources)
24
+ struct input_stream_context {
25
+ JNIEnv *env;
26
+ jobject input_stream;
27
+ };
28
+
29
+ static size_t input_stream_read(void *ctx, void *output, size_t read_size) {
30
+ input_stream_context *context = (input_stream_context *)ctx;
31
+ JNIEnv *env = context->env;
32
+ jobject input_stream = context->input_stream;
33
+ jclass input_stream_class = env->GetObjectClass(input_stream);
34
+
35
+ jbyteArray buffer = env->NewByteArray(read_size);
36
+ jint bytes_read = env->CallIntMethod(
37
+ input_stream,
38
+ env->GetMethodID(input_stream_class, "read", "([B)I"),
39
+ buffer
40
+ );
41
+
42
+ if (bytes_read > 0) {
43
+ env->GetByteArrayRegion(buffer, 0, bytes_read, (jbyte *) output);
44
+ }
45
+
46
+ env->DeleteLocalRef(buffer);
47
+
48
+ return bytes_read;
49
+ }
50
+
51
+ static bool input_stream_is_eof(void *ctx) {
52
+ input_stream_context *context = (input_stream_context *)ctx;
53
+ JNIEnv *env = context->env;
54
+ jobject input_stream = context->input_stream;
55
+
56
+ jclass input_stream_class = env->GetObjectClass(input_stream);
57
+
58
+ jbyteArray buffer = env->NewByteArray(1);
59
+ jint bytes_read = env->CallIntMethod(
60
+ input_stream,
61
+ env->GetMethodID(input_stream_class, "read", "([B)I"),
62
+ buffer
63
+ );
64
+
65
+ bool is_eof = (bytes_read == -1);
66
+ if (!is_eof) {
67
+ // If we successfully read a byte, "unread" it by pushing it back into the stream.
68
+ env->CallVoidMethod(
69
+ input_stream,
70
+ env->GetMethodID(input_stream_class, "unread", "([BII)V"),
71
+ buffer,
72
+ 0,
73
+ 1
74
+ );
75
+ }
76
+
77
+ env->DeleteLocalRef(buffer);
78
+
79
+ return is_eof;
80
+ }
81
+
82
+ static void input_stream_close(void *ctx) {
83
+ input_stream_context *context = (input_stream_context *)ctx;
84
+ JNIEnv *env = context->env;
85
+ jobject input_stream = context->input_stream;
86
+ jclass input_stream_class = env->GetObjectClass(input_stream);
87
+
88
+ env->CallVoidMethod(
89
+ input_stream,
90
+ env->GetMethodID(input_stream_class, "close", "()V")
91
+ );
92
+
93
+ env->DeleteGlobalRef(input_stream);
94
+ }
95
+
96
+ static struct whisper_context *whisper_init_from_input_stream(
97
+ JNIEnv *env,
98
+ jobject input_stream // PushbackInputStream
99
+ ) {
100
+ input_stream_context *context = new input_stream_context;
101
+ context->env = env;
102
+ context->input_stream = env->NewGlobalRef(input_stream);
103
+
104
+ whisper_model_loader loader = {
105
+ .context = context,
106
+ .read = &input_stream_read,
107
+ .eof = &input_stream_is_eof,
108
+ .close = &input_stream_close
109
+ };
110
+ return whisper_init(&loader);
111
+ }
112
+
113
+ // Load model from asset
23
114
  static size_t asset_read(void *ctx, void *output, size_t read_size) {
24
115
  return AAsset_read((AAsset *) ctx, output, read_size);
25
116
  }
@@ -81,6 +172,17 @@ Java_com_rnwhisper_WhisperContext_initContextWithAsset(
81
172
  return reinterpret_cast<jlong>(context);
82
173
  }
83
174
 
175
+ JNIEXPORT jlong JNICALL
176
+ Java_com_rnwhisper_WhisperContext_initContextWithInputStream(
177
+ JNIEnv *env,
178
+ jobject thiz,
179
+ jobject input_stream
180
+ ) {
181
+ UNUSED(thiz);
182
+ struct whisper_context *context = nullptr;
183
+ context = whisper_init_from_input_stream(env, input_stream);
184
+ return reinterpret_cast<jlong>(context);
185
+ }
84
186
 
85
187
  JNIEXPORT jint JNICALL
86
188
  Java_com_rnwhisper_WhisperContext_fullTranscribe(
@@ -17,17 +17,22 @@ import com.facebook.react.module.annotations.ReactModule;
17
17
 
18
18
  import java.util.HashMap;
19
19
  import java.util.Random;
20
+ import java.io.File;
21
+ import java.io.FileInputStream;
22
+ import java.io.PushbackInputStream;
20
23
 
21
24
  @ReactModule(name = RNWhisperModule.NAME)
22
25
  public class RNWhisperModule extends NativeRNWhisperSpec implements LifecycleEventListener {
23
26
  public static final String NAME = "RNWhisper";
24
27
 
25
28
  private ReactApplicationContext reactContext;
29
+ private Downloader downloader;
26
30
 
27
31
  public RNWhisperModule(ReactApplicationContext reactContext) {
28
32
  super(reactContext);
29
33
  reactContext.addLifecycleEventListener(this);
30
34
  this.reactContext = reactContext;
35
+ this.downloader = new Downloader(reactContext);
31
36
  }
32
37
 
33
38
  @Override
@@ -49,19 +54,48 @@ public class RNWhisperModule extends NativeRNWhisperSpec implements LifecycleEve
49
54
 
50
55
  private HashMap<Integer, WhisperContext> contexts = new HashMap<>();
51
56
 
57
+ private int getResourceIdentifier(String filePath) {
58
+ int identifier = reactContext.getResources().getIdentifier(
59
+ filePath,
60
+ "drawable",
61
+ reactContext.getPackageName()
62
+ );
63
+ if (identifier == 0) {
64
+ identifier = reactContext.getResources().getIdentifier(
65
+ filePath,
66
+ "raw",
67
+ reactContext.getPackageName()
68
+ );
69
+ }
70
+ return identifier;
71
+ }
72
+
52
73
  @ReactMethod
53
- public void initContext(final String modelPath, final boolean isBundleAsset, final Promise promise) {
74
+ public void initContext(final ReadableMap options, final Promise promise) {
54
75
  new AsyncTask<Void, Void, Integer>() {
55
76
  private Exception exception;
56
77
 
57
78
  @Override
58
79
  protected Integer doInBackground(Void... voids) {
59
80
  try {
81
+ String modelPath = options.getString("filePath");
82
+ boolean isBundleAsset = options.getBoolean("isBundleAsset");
83
+
84
+ String modelFilePath = modelPath;
85
+ if (!isBundleAsset && (modelPath.startsWith("http://") || modelPath.startsWith("https://"))) {
86
+ modelFilePath = downloader.downloadFile(modelPath);
87
+ }
88
+
60
89
  long context;
61
- if (isBundleAsset) {
62
- context = WhisperContext.initContextWithAsset(reactContext.getAssets(), modelPath);
90
+ int resId = getResourceIdentifier(modelFilePath);
91
+ if (resId > 0) {
92
+ context = WhisperContext.initContextWithInputStream(
93
+ new PushbackInputStream(reactContext.getResources().openRawResource(resId))
94
+ );
95
+ } else if (isBundleAsset) {
96
+ context = WhisperContext.initContextWithAsset(reactContext.getAssets(), modelFilePath);
63
97
  } else {
64
- context = WhisperContext.initContext(modelPath);
98
+ context = WhisperContext.initContext(modelFilePath);
65
99
  }
66
100
  if (context == 0) {
67
101
  throw new Exception("Failed to initialize context");
@@ -108,7 +142,26 @@ public class RNWhisperModule extends NativeRNWhisperSpec implements LifecycleEve
108
142
  @Override
109
143
  protected WritableMap doInBackground(Void... voids) {
110
144
  try {
111
- return context.transcribeFile((int) jobId, filePath, options);
145
+ String waveFilePath = filePath;
146
+
147
+ if (filePath.startsWith("http://") || filePath.startsWith("https://")) {
148
+ waveFilePath = downloader.downloadFile(filePath);
149
+ }
150
+
151
+ int resId = getResourceIdentifier(waveFilePath);
152
+ if (resId > 0) {
153
+ return context.transcribeInputStream(
154
+ (int) jobId,
155
+ reactContext.getResources().openRawResource(resId),
156
+ options
157
+ );
158
+ }
159
+
160
+ return context.transcribeInputStream(
161
+ (int) jobId,
162
+ new FileInputStream(new File(waveFilePath)),
163
+ options
164
+ );
112
165
  } catch (Exception e) {
113
166
  exception = e;
114
167
  return null;
@@ -228,5 +281,6 @@ public class RNWhisperModule extends NativeRNWhisperSpec implements LifecycleEve
228
281
  context.release();
229
282
  }
230
283
  contexts.clear();
284
+ downloader.clearCache();
231
285
  }
232
286
  }
@@ -18,17 +18,22 @@ import com.facebook.react.module.annotations.ReactModule;
18
18
 
19
19
  import java.util.HashMap;
20
20
  import java.util.Random;
21
+ import java.io.File;
22
+ import java.io.FileInputStream;
23
+ import java.io.PushbackInputStream;
21
24
 
22
25
  @ReactModule(name = RNWhisperModule.NAME)
23
26
  public class RNWhisperModule extends ReactContextBaseJavaModule implements LifecycleEventListener {
24
27
  public static final String NAME = "RNWhisper";
25
28
 
26
29
  private ReactApplicationContext reactContext;
30
+ private Downloader downloader;
27
31
 
28
32
  public RNWhisperModule(ReactApplicationContext reactContext) {
29
33
  super(reactContext);
30
34
  reactContext.addLifecycleEventListener(this);
31
35
  this.reactContext = reactContext;
36
+ this.downloader = new Downloader(reactContext);
32
37
  }
33
38
 
34
39
  @Override
@@ -39,19 +44,48 @@ public class RNWhisperModule extends ReactContextBaseJavaModule implements Lifec
39
44
 
40
45
  private HashMap<Integer, WhisperContext> contexts = new HashMap<>();
41
46
 
47
+ private int getResourceIdentifier(String filePath) {
48
+ int identifier = reactContext.getResources().getIdentifier(
49
+ filePath,
50
+ "drawable",
51
+ reactContext.getPackageName()
52
+ );
53
+ if (identifier == 0) {
54
+ identifier = reactContext.getResources().getIdentifier(
55
+ filePath,
56
+ "raw",
57
+ reactContext.getPackageName()
58
+ );
59
+ }
60
+ return identifier;
61
+ }
62
+
42
63
  @ReactMethod
43
- public void initContext(final String modelPath, final boolean isBundleAsset, final Promise promise) {
64
+ public void initContext(final ReadableMap options, final Promise promise) {
44
65
  new AsyncTask<Void, Void, Integer>() {
45
66
  private Exception exception;
46
67
 
47
68
  @Override
48
69
  protected Integer doInBackground(Void... voids) {
49
70
  try {
71
+ String modelPath = options.getString("filePath");
72
+ boolean isBundleAsset = options.getBoolean("isBundleAsset");
73
+
74
+ String modelFilePath = modelPath;
75
+ if (!isBundleAsset && (modelPath.startsWith("http://") || modelPath.startsWith("https://"))) {
76
+ modelFilePath = downloader.downloadFile(modelPath);
77
+ }
78
+
50
79
  long context;
51
- if (isBundleAsset) {
52
- context = WhisperContext.initContextWithAsset(reactContext.getAssets(), modelPath);
80
+ int resId = getResourceIdentifier(modelFilePath);
81
+ if (resId > 0) {
82
+ context = WhisperContext.initContextWithInputStream(
83
+ new PushbackInputStream(reactContext.getResources().openRawResource(resId))
84
+ );
85
+ } else if (isBundleAsset) {
86
+ context = WhisperContext.initContextWithAsset(reactContext.getAssets(), modelFilePath);
53
87
  } else {
54
- context = WhisperContext.initContext(modelPath);
88
+ context = WhisperContext.initContext(modelFilePath);
55
89
  }
56
90
  if (context == 0) {
57
91
  throw new Exception("Failed to initialize context");
@@ -98,7 +132,26 @@ public class RNWhisperModule extends ReactContextBaseJavaModule implements Lifec
98
132
  @Override
99
133
  protected WritableMap doInBackground(Void... voids) {
100
134
  try {
101
- return context.transcribeFile(jobId, filePath, options);
135
+ String waveFilePath = filePath;
136
+
137
+ if (filePath.startsWith("http://") || filePath.startsWith("https://")) {
138
+ waveFilePath = downloader.downloadFile(filePath);
139
+ }
140
+
141
+ int resId = getResourceIdentifier(waveFilePath);
142
+ if (resId > 0) {
143
+ return context.transcribeInputStream(
144
+ (int) jobId,
145
+ reactContext.getResources().openRawResource(resId),
146
+ options
147
+ );
148
+ }
149
+
150
+ return context.transcribeInputStream(
151
+ (int) jobId,
152
+ new FileInputStream(new File(waveFilePath)),
153
+ options
154
+ );
102
155
  } catch (Exception e) {
103
156
  exception = e;
104
157
  return null;
@@ -217,5 +270,6 @@ public class RNWhisperModule extends ReactContextBaseJavaModule implements Lifec
217
270
  context.release();
218
271
  }
219
272
  contexts.clear();
273
+ downloader.clearCache();
220
274
  }
221
275
  }
package/ios/RNWhisper.mm CHANGED
@@ -1,5 +1,6 @@
1
1
  #import "RNWhisper.h"
2
2
  #import "RNWhisperContext.h"
3
+ #import "RNWhisperDownloader.h"
3
4
  #include <stdlib.h>
4
5
  #include <string>
5
6
 
@@ -15,7 +16,7 @@ RCT_EXPORT_MODULE()
15
16
 
16
17
  + (BOOL)requiresMainQueueSetup
17
18
  {
18
- return YES;
19
+ return NO;
19
20
  }
20
21
 
21
22
  - (NSDictionary *)constantsToExport
@@ -35,8 +36,7 @@ RCT_EXPORT_MODULE()
35
36
  }
36
37
 
37
38
  RCT_REMAP_METHOD(initContext,
38
- withPath:(NSString *)modelPath
39
- withBundleResource:(BOOL)isBundleAsset
39
+ withOptions:(NSDictionary *)modelOptions
40
40
  withResolver:(RCTPromiseResolveBlock)resolve
41
41
  withRejecter:(RCTPromiseRejectBlock)reject)
42
42
  {
@@ -44,7 +44,26 @@ RCT_REMAP_METHOD(initContext,
44
44
  contexts = [[NSMutableDictionary alloc] init];
45
45
  }
46
46
 
47
+ NSString *modelPath = [modelOptions objectForKey:@"filePath"];
48
+ BOOL isBundleAsset = [[modelOptions objectForKey:@"isBundleAsset"] boolValue];
49
+
50
+ // To support debug assets in development mode
51
+ BOOL downloadCoreMLAssets = [[modelOptions objectForKey:@"downloadCoreMLAssets"] boolValue];
52
+ if (downloadCoreMLAssets) {
53
+ NSArray *coreMLAssets = [modelOptions objectForKey:@"coreMLAssets"];
54
+ // Download coreMLAssets ([{ uri, filepath }])
55
+ for (NSDictionary *coreMLAsset in coreMLAssets) {
56
+ NSString *path = coreMLAsset[@"uri"];
57
+ if ([path hasPrefix:@"http://"] || [path hasPrefix:@"https://"]) {
58
+ [RNWhisperDownloader downloadFile:path toFile:coreMLAsset[@"filepath"]];
59
+ }
60
+ }
61
+ }
62
+
47
63
  NSString *path = modelPath;
64
+ if ([path hasPrefix:@"http://"] || [path hasPrefix:@"https://"]) {
65
+ path = [RNWhisperDownloader downloadFile:path toFile:nil];
66
+ }
48
67
  if (isBundleAsset) {
49
68
  path = [[NSBundle mainBundle] pathForResource:modelPath ofType:nil];
50
69
  }
@@ -84,10 +103,13 @@ RCT_REMAP_METHOD(transcribeFile,
84
103
  return;
85
104
  }
86
105
 
87
- NSURL *url = [NSURL fileURLWithPath:waveFilePath];
106
+ NSString *path = waveFilePath;
107
+ if ([path hasPrefix:@"http://"] || [path hasPrefix:@"https://"]) {
108
+ path = [RNWhisperDownloader downloadFile:path toFile:nil];
109
+ }
88
110
 
89
111
  int count = 0;
90
- float *waveFile = [self decodeWaveFile:url count:&count];
112
+ float *waveFile = [self decodeWaveFile:path count:&count];
91
113
  if (waveFile == nil) {
92
114
  reject(@"whisper_error", @"Invalid file", nil);
93
115
  return;
@@ -185,8 +207,9 @@ RCT_REMAP_METHOD(releaseAllContexts,
185
207
  resolve(nil);
186
208
  }
187
209
 
188
- - (float *)decodeWaveFile:(NSURL*)fileURL count:(int *)count {
189
- NSData *fileData = [NSData dataWithContentsOfURL:fileURL];
210
+ - (float *)decodeWaveFile:(NSString*)filePath count:(int *)count {
211
+ NSURL *url = [NSURL fileURLWithPath:filePath];
212
+ NSData *fileData = [NSData dataWithContentsOfURL:url];
190
213
  if (fileData == nil) {
191
214
  return nil;
192
215
  }
@@ -219,6 +242,8 @@ RCT_REMAP_METHOD(releaseAllContexts,
219
242
 
220
243
  [contexts removeAllObjects];
221
244
  contexts = nil;
245
+
246
+ [RNWhisperDownloader clearCache];
222
247
  }
223
248
 
224
249
  #ifdef RCT_NEW_ARCH_ENABLED