whisper.rn 0.4.0-rc.1 → 0.4.0-rc.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +21 -1
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -92
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +86 -40
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +85 -131
  7. package/android/src/main/jni-utils.h +76 -0
  8. package/android/src/main/jni.cpp +226 -109
  9. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  10. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  11. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  12. package/cpp/coreml/whisper-encoder.h +4 -0
  13. package/cpp/coreml/whisper-encoder.mm +5 -3
  14. package/cpp/ggml-alloc.c +797 -400
  15. package/cpp/ggml-alloc.h +60 -10
  16. package/cpp/ggml-backend-impl.h +255 -0
  17. package/cpp/ggml-backend-reg.cpp +582 -0
  18. package/cpp/ggml-backend.cpp +2002 -0
  19. package/cpp/ggml-backend.h +354 -0
  20. package/cpp/ggml-common.h +1851 -0
  21. package/cpp/ggml-cpp.h +39 -0
  22. package/cpp/ggml-cpu-aarch64.cpp +4247 -0
  23. package/cpp/ggml-cpu-aarch64.h +8 -0
  24. package/cpp/ggml-cpu-impl.h +531 -0
  25. package/cpp/ggml-cpu-quants.c +12245 -0
  26. package/cpp/ggml-cpu-quants.h +63 -0
  27. package/cpp/ggml-cpu-traits.cpp +36 -0
  28. package/cpp/ggml-cpu-traits.h +38 -0
  29. package/cpp/ggml-cpu.c +14792 -0
  30. package/cpp/ggml-cpu.cpp +653 -0
  31. package/cpp/ggml-cpu.h +137 -0
  32. package/cpp/ggml-impl.h +567 -0
  33. package/cpp/ggml-metal-impl.h +288 -0
  34. package/cpp/ggml-metal.h +24 -43
  35. package/cpp/ggml-metal.m +4867 -1080
  36. package/cpp/ggml-opt.cpp +854 -0
  37. package/cpp/ggml-opt.h +216 -0
  38. package/cpp/ggml-quants.c +5238 -0
  39. package/cpp/ggml-quants.h +100 -0
  40. package/cpp/ggml-threading.cpp +12 -0
  41. package/cpp/ggml-threading.h +14 -0
  42. package/cpp/ggml-whisper.metallib +0 -0
  43. package/cpp/ggml.c +5106 -19431
  44. package/cpp/ggml.h +847 -669
  45. package/cpp/gguf.cpp +1329 -0
  46. package/cpp/gguf.h +202 -0
  47. package/cpp/rn-audioutils.cpp +68 -0
  48. package/cpp/rn-audioutils.h +14 -0
  49. package/cpp/rn-whisper-log.h +11 -0
  50. package/cpp/rn-whisper.cpp +221 -52
  51. package/cpp/rn-whisper.h +50 -15
  52. package/cpp/whisper.cpp +3174 -1533
  53. package/cpp/whisper.h +176 -44
  54. package/ios/RNWhisper.mm +139 -46
  55. package/ios/RNWhisperAudioUtils.h +1 -2
  56. package/ios/RNWhisperAudioUtils.m +18 -67
  57. package/ios/RNWhisperContext.h +11 -8
  58. package/ios/RNWhisperContext.mm +195 -150
  59. package/jest/mock.js +15 -2
  60. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  61. package/lib/commonjs/index.js +76 -28
  62. package/lib/commonjs/index.js.map +1 -1
  63. package/lib/commonjs/version.json +1 -1
  64. package/lib/module/NativeRNWhisper.js.map +1 -1
  65. package/lib/module/index.js +76 -28
  66. package/lib/module/index.js.map +1 -1
  67. package/lib/module/version.json +1 -1
  68. package/lib/typescript/NativeRNWhisper.d.ts +13 -4
  69. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  70. package/lib/typescript/index.d.ts +37 -5
  71. package/lib/typescript/index.d.ts.map +1 -1
  72. package/package.json +9 -7
  73. package/src/NativeRNWhisper.ts +20 -4
  74. package/src/index.ts +98 -42
  75. package/src/version.json +1 -1
  76. package/whisper-rn.podspec +13 -20
  77. package/cpp/README.md +0 -4
  78. package/cpp/ggml-metal.metal +0 -2353
@@ -10,6 +10,7 @@
10
10
  #include "whisper.h"
11
11
  #include "rn-whisper.h"
12
12
  #include "ggml.h"
13
+ #include "jni-utils.h"
13
14
 
14
15
  #define UNUSED(x) (void)(x)
15
16
  #define TAG "JNI"
@@ -96,7 +97,8 @@ static void input_stream_close(void *ctx) {
96
97
 
97
98
  static struct whisper_context *whisper_init_from_input_stream(
98
99
  JNIEnv *env,
99
- jobject input_stream // PushbackInputStream
100
+ jobject input_stream, // PushbackInputStream
101
+ struct whisper_context_params cparams
100
102
  ) {
101
103
  input_stream_context *context = new input_stream_context;
102
104
  context->env = env;
@@ -108,7 +110,7 @@ static struct whisper_context *whisper_init_from_input_stream(
108
110
  .eof = &input_stream_is_eof,
109
111
  .close = &input_stream_close
110
112
  };
111
- return whisper_init(&loader);
113
+ return whisper_init_with_params(&loader, cparams);
112
114
  }
113
115
 
114
116
  // Load model from asset
@@ -127,7 +129,8 @@ static void asset_close(void *ctx) {
127
129
  static struct whisper_context *whisper_init_from_asset(
128
130
  JNIEnv *env,
129
131
  jobject assetManager,
130
- const char *asset_path
132
+ const char *asset_path,
133
+ struct whisper_context_params cparams
131
134
  ) {
132
135
  LOGI("Loading model from asset '%s'\n", asset_path);
133
136
  AAssetManager *asset_manager = AAssetManager_fromJava(env, assetManager);
@@ -142,7 +145,7 @@ static struct whisper_context *whisper_init_from_asset(
142
145
  .eof = &asset_is_eof,
143
146
  .close = &asset_close
144
147
  };
145
- return whisper_init(&loader);
148
+ return whisper_init_with_params(&loader, cparams);
146
149
  }
147
150
 
148
151
  extern "C" {
@@ -151,9 +154,15 @@ JNIEXPORT jlong JNICALL
151
154
  Java_com_rnwhisper_WhisperContext_initContext(
152
155
  JNIEnv *env, jobject thiz, jstring model_path_str) {
153
156
  UNUSED(thiz);
157
+ struct whisper_context_params cparams;
158
+
159
+ // TODO: Expose dtw_token_timestamps and dtw_aheads_preset
160
+ cparams.dtw_token_timestamps = false;
161
+ // cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
162
+
154
163
  struct whisper_context *context = nullptr;
155
164
  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
156
- context = whisper_init_from_file(model_path_chars);
165
+ context = whisper_init_from_file_with_params(model_path_chars, cparams);
157
166
  env->ReleaseStringUTFChars(model_path_str, model_path_chars);
158
167
  return reinterpret_cast<jlong>(context);
159
168
  }
@@ -166,9 +175,15 @@ Java_com_rnwhisper_WhisperContext_initContextWithAsset(
166
175
  jstring model_path_str
167
176
  ) {
168
177
  UNUSED(thiz);
178
+ struct whisper_context_params cparams;
179
+
180
+ // TODO: Expose dtw_token_timestamps and dtw_aheads_preset
181
+ cparams.dtw_token_timestamps = false;
182
+ // cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
183
+
169
184
  struct whisper_context *context = nullptr;
170
185
  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
171
- context = whisper_init_from_asset(env, asset_manager, model_path_chars);
186
+ context = whisper_init_from_asset(env, asset_manager, model_path_chars, cparams);
172
187
  env->ReleaseStringUTFChars(model_path_str, model_path_chars);
173
188
  return reinterpret_cast<jlong>(context);
174
189
  }
@@ -180,30 +195,70 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream(
180
195
  jobject input_stream
181
196
  ) {
182
197
  UNUSED(thiz);
198
+ struct whisper_context_params cparams;
199
+
200
+ // TODO: Expose dtw_token_timestamps and dtw_aheads_preset
201
+ cparams.dtw_token_timestamps = false;
202
+ // cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
203
+
183
204
  struct whisper_context *context = nullptr;
184
- context = whisper_init_from_input_stream(env, input_stream);
205
+ context = whisper_init_from_input_stream(env, input_stream, cparams);
185
206
  return reinterpret_cast<jlong>(context);
186
207
  }
187
208
 
188
- JNIEXPORT jboolean JNICALL
189
- Java_com_rnwhisper_WhisperContext_vadSimple(
190
- JNIEnv *env,
191
- jobject thiz,
192
- jfloatArray audio_data,
193
- jint audio_data_len,
194
- jfloat vad_thold,
195
- jfloat vad_freq_thold
196
- ) {
197
- UNUSED(thiz);
198
209
 
199
- std::vector<float> samples(audio_data_len);
200
- jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);
201
- for (int i = 0; i < audio_data_len; i++) {
202
- samples[i] = audio_data_arr[i];
210
+ struct whisper_full_params createFullParams(JNIEnv *env, jobject options) {
211
+ struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
212
+
213
+ params.print_realtime = false;
214
+ params.print_progress = false;
215
+ params.print_timestamps = false;
216
+ params.print_special = false;
217
+
218
+ int max_threads = std::thread::hardware_concurrency();
219
+ // Use 2 threads by default on 4-core devices, 4 threads on more cores
220
+ int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
221
+ int n_threads = readablemap::getInt(env, options, "maxThreads", default_n_threads);
222
+ params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
223
+ params.translate = readablemap::getBool(env, options, "translate", false);
224
+ params.token_timestamps = readablemap::getBool(env, options, "tokenTimestamps", false);
225
+ params.tdrz_enable = readablemap::getBool(env, options, "tdrzEnable", false);
226
+ params.offset_ms = 0;
227
+ params.no_context = true;
228
+ params.single_segment = false;
229
+
230
+ int beam_size = readablemap::getInt(env, options, "beamSize", -1);
231
+ if (beam_size > -1) {
232
+ params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
233
+ params.beam_search.beam_size = beam_size;
203
234
  }
204
- bool is_speech = rn_whisper_vad_simple(samples, WHISPER_SAMPLE_RATE, 1000, vad_thold, vad_freq_thold, false);
205
- env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
206
- return is_speech;
235
+ int best_of = readablemap::getInt(env, options, "bestOf", -1);
236
+ if (best_of > -1) params.greedy.best_of = best_of;
237
+ int max_len = readablemap::getInt(env, options, "maxLen", -1);
238
+ if (max_len > -1) params.max_len = max_len;
239
+ int max_context = readablemap::getInt(env, options, "maxContext", -1);
240
+ if (max_context > -1) params.n_max_text_ctx = max_context;
241
+ int offset = readablemap::getInt(env, options, "offset", -1);
242
+ if (offset > -1) params.offset_ms = offset;
243
+ int duration = readablemap::getInt(env, options, "duration", -1);
244
+ if (duration > -1) params.duration_ms = duration;
245
+ int word_thold = readablemap::getInt(env, options, "wordThold", -1);
246
+ if (word_thold > -1) params.thold_pt = word_thold;
247
+ float temperature = readablemap::getFloat(env, options, "temperature", -1);
248
+ if (temperature > -1) params.temperature = temperature;
249
+ float temperature_inc = readablemap::getFloat(env, options, "temperatureInc", -1);
250
+ if (temperature_inc > -1) params.temperature_inc = temperature_inc;
251
+ jstring prompt = readablemap::getString(env, options, "prompt", nullptr);
252
+ if (prompt != nullptr) {
253
+ params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
254
+ env->DeleteLocalRef(prompt);
255
+ }
256
+ jstring language = readablemap::getString(env, options, "language", nullptr);
257
+ if (language != nullptr) {
258
+ params.language = env->GetStringUTFChars(language, nullptr);
259
+ env->DeleteLocalRef(language);
260
+ }
261
+ return params;
207
262
  }
208
263
 
209
264
  struct callback_context {
@@ -212,101 +267,23 @@ struct callback_context {
212
267
  };
213
268
 
214
269
  JNIEXPORT jint JNICALL
215
- Java_com_rnwhisper_WhisperContext_fullTranscribe(
270
+ Java_com_rnwhisper_WhisperContext_fullWithNewJob(
216
271
  JNIEnv *env,
217
272
  jobject thiz,
218
273
  jint job_id,
219
274
  jlong context_ptr,
220
275
  jfloatArray audio_data,
221
276
  jint audio_data_len,
222
- jint n_threads,
223
- jint max_context,
224
- int word_thold,
225
- int max_len,
226
- jboolean token_timestamps,
227
- jint offset,
228
- jint duration,
229
- jfloat temperature,
230
- jfloat temperature_inc,
231
- jint beam_size,
232
- jint best_of,
233
- jboolean speed_up,
234
- jboolean translate,
235
- jstring language,
236
- jstring prompt,
277
+ jobject options,
237
278
  jobject callback_instance
238
279
  ) {
239
280
  UNUSED(thiz);
240
281
  struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
241
282
  jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);
242
283
 
243
- int max_threads = std::thread::hardware_concurrency();
244
- // Use 2 threads by default on 4-core devices, 4 threads on more cores
245
- int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
246
-
247
284
  LOGI("About to create params");
248
285
 
249
- struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
250
-
251
- if (beam_size > -1) {
252
- params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
253
- params.beam_search.beam_size = beam_size;
254
- }
255
-
256
- params.print_realtime = false;
257
- params.print_progress = false;
258
- params.print_timestamps = false;
259
- params.print_special = false;
260
- params.translate = translate;
261
- const char *language_chars = env->GetStringUTFChars(language, nullptr);
262
- params.language = language_chars;
263
- params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
264
- params.speed_up = speed_up;
265
- params.offset_ms = 0;
266
- params.no_context = true;
267
- params.single_segment = false;
268
-
269
- if (max_len > -1) {
270
- params.max_len = max_len;
271
- }
272
- params.token_timestamps = token_timestamps;
273
-
274
- if (best_of > -1) {
275
- params.greedy.best_of = best_of;
276
- }
277
- if (max_context > -1) {
278
- params.n_max_text_ctx = max_context;
279
- }
280
- if (offset > -1) {
281
- params.offset_ms = offset;
282
- }
283
- if (duration > -1) {
284
- params.duration_ms = duration;
285
- }
286
- if (word_thold > -1) {
287
- params.thold_pt = word_thold;
288
- }
289
- if (temperature > -1) {
290
- params.temperature = temperature;
291
- }
292
- if (temperature_inc > -1) {
293
- params.temperature_inc = temperature_inc;
294
- }
295
- if (prompt != nullptr) {
296
- params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
297
- }
298
-
299
- // abort handlers
300
- params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
301
- bool is_aborted = *(bool*)user_data;
302
- return !is_aborted;
303
- };
304
- params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(job_id);
305
- params.abort_callback = [](void * user_data) {
306
- bool is_aborted = *(bool*)user_data;
307
- return is_aborted;
308
- };
309
- params.abort_callback_user_data = rn_whisper_assign_abort_map(job_id);
286
+ whisper_full_params params = createFullParams(env, options);
310
287
 
311
288
  if (callback_instance != nullptr) {
312
289
  callback_context *cb_ctx = new callback_context;
@@ -334,6 +311,8 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
334
311
  params.new_segment_callback_user_data = cb_ctx;
335
312
  }
336
313
 
314
+ rnwhisper::job* job = rnwhisper::job_new(job_id, params);
315
+
337
316
  LOGI("About to reset timings");
338
317
  whisper_reset_timings(context);
339
318
 
@@ -343,8 +322,123 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
343
322
  // whisper_print_timings(context);
344
323
  }
345
324
  env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
346
- env->ReleaseStringUTFChars(language, language_chars);
347
- rn_whisper_remove_abort_map(job_id);
325
+
326
+ if (job->is_aborted()) code = -999;
327
+ rnwhisper::job_remove(job_id);
328
+ return code;
329
+ }
330
+
331
+ JNIEXPORT void JNICALL
332
+ Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob(
333
+ JNIEnv *env,
334
+ jobject thiz,
335
+ jint job_id,
336
+ jlong context_ptr,
337
+ jobject options
338
+ ) {
339
+ whisper_full_params params = createFullParams(env, options);
340
+ rnwhisper::job* job = rnwhisper::job_new(job_id, params);
341
+ rnwhisper::vad_params vad;
342
+ vad.use_vad = readablemap::getBool(env, options, "useVad", false);
343
+ vad.vad_ms = readablemap::getInt(env, options, "vadMs", 2000);
344
+ vad.vad_thold = readablemap::getFloat(env, options, "vadThold", 0.6f);
345
+ vad.freq_thold = readablemap::getFloat(env, options, "vadFreqThold", 100.0f);
346
+
347
+ jstring audio_output_path = readablemap::getString(env, options, "audioOutputPath", nullptr);
348
+ const char* audio_output_path_str = nullptr;
349
+ if (audio_output_path != nullptr) {
350
+ audio_output_path_str = env->GetStringUTFChars(audio_output_path, nullptr);
351
+ env->DeleteLocalRef(audio_output_path);
352
+ }
353
+ job->set_realtime_params(
354
+ vad,
355
+ readablemap::getInt(env, options, "realtimeAudioSec", 0),
356
+ readablemap::getInt(env, options, "realtimeAudioSliceSec", 0),
357
+ readablemap::getFloat(env, options, "realtimeAudioMinSec", 0),
358
+ audio_output_path_str
359
+ );
360
+ }
361
+
362
+ JNIEXPORT void JNICALL
363
+ Java_com_rnwhisper_WhisperContext_finishRealtimeTranscribeJob(
364
+ JNIEnv *env,
365
+ jobject thiz,
366
+ jint job_id,
367
+ jlong context_ptr,
368
+ jintArray slice_n_samples
369
+ ) {
370
+ UNUSED(env);
371
+ UNUSED(thiz);
372
+ UNUSED(context_ptr);
373
+
374
+ rnwhisper::job *job = rnwhisper::job_get(job_id);
375
+ if (job->audio_output_path != nullptr) {
376
+ RNWHISPER_LOG_INFO("job->params.language: %s\n", job->params.language);
377
+ std::vector<int> slice_n_samples_vec;
378
+ jint *slice_n_samples_arr = env->GetIntArrayElements(slice_n_samples, nullptr);
379
+ slice_n_samples_vec = std::vector<int>(slice_n_samples_arr, slice_n_samples_arr + env->GetArrayLength(slice_n_samples));
380
+ env->ReleaseIntArrayElements(slice_n_samples, slice_n_samples_arr, JNI_ABORT);
381
+
382
+ // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage
383
+ rnaudioutils::save_wav_file(
384
+ rnaudioutils::concat_short_buffers(job->pcm_slices, slice_n_samples_vec),
385
+ job->audio_output_path
386
+ );
387
+ }
388
+ rnwhisper::job_remove(job_id);
389
+ }
390
+
391
+ JNIEXPORT jboolean JNICALL
392
+ Java_com_rnwhisper_WhisperContext_vadSimple(
393
+ JNIEnv *env,
394
+ jobject thiz,
395
+ jint job_id,
396
+ jint slice_index,
397
+ jint n_samples,
398
+ jint n
399
+ ) {
400
+ UNUSED(thiz);
401
+ rnwhisper::job* job = rnwhisper::job_get(job_id);
402
+ return job->vad_simple(slice_index, n_samples, n);
403
+ }
404
+
405
+ JNIEXPORT void JNICALL
406
+ Java_com_rnwhisper_WhisperContext_putPcmData(
407
+ JNIEnv *env,
408
+ jobject thiz,
409
+ jint job_id,
410
+ jshortArray pcm,
411
+ jint slice_index,
412
+ jint n_samples,
413
+ jint n
414
+ ) {
415
+ UNUSED(thiz);
416
+ rnwhisper::job* job = rnwhisper::job_get(job_id);
417
+ jshort *pcm_arr = env->GetShortArrayElements(pcm, nullptr);
418
+ job->put_pcm_data(pcm_arr, slice_index, n_samples, n);
419
+ env->ReleaseShortArrayElements(pcm, pcm_arr, JNI_ABORT);
420
+ }
421
+
422
+ JNIEXPORT jint JNICALL
423
+ Java_com_rnwhisper_WhisperContext_fullWithJob(
424
+ JNIEnv *env,
425
+ jobject thiz,
426
+ jint job_id,
427
+ jlong context_ptr,
428
+ jint slice_index,
429
+ jint n_samples
430
+ ) {
431
+ UNUSED(thiz);
432
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
433
+
434
+ rnwhisper::job* job = rnwhisper::job_get(job_id);
435
+ float* pcmf32 = job->pcm_slice_to_f32(slice_index, n_samples);
436
+ int code = whisper_full(context, job->params, pcmf32, n_samples);
437
+ free(pcmf32);
438
+ if (code == 0) {
439
+ // whisper_print_timings(context);
440
+ }
441
+ if (job->is_aborted()) code = -999;
348
442
  return code;
349
443
  }
350
444
 
@@ -355,7 +449,8 @@ Java_com_rnwhisper_WhisperContext_abortTranscribe(
355
449
  jint job_id
356
450
  ) {
357
451
  UNUSED(thiz);
358
- rn_whisper_abort_transcribe(job_id);
452
+ rnwhisper::job *job = rnwhisper::job_get(job_id);
453
+ if (job) job->abort();
359
454
  }
360
455
 
361
456
  JNIEXPORT void JNICALL
@@ -364,7 +459,7 @@ Java_com_rnwhisper_WhisperContext_abortAllTranscribe(
364
459
  jobject thiz
365
460
  ) {
366
461
  UNUSED(thiz);
367
- rn_whisper_abort_all_transcribe();
462
+ rnwhisper::job_abort_all();
368
463
  }
369
464
 
370
465
  JNIEXPORT jint JNICALL
@@ -413,4 +508,26 @@ Java_com_rnwhisper_WhisperContext_freeContext(
413
508
  whisper_free(context);
414
509
  }
415
510
 
511
+ JNIEXPORT jboolean JNICALL
512
+ Java_com_rnwhisper_WhisperContext_getTextSegmentSpeakerTurnNext(
513
+ JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
514
+ UNUSED(env);
515
+ UNUSED(thiz);
516
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
517
+ return whisper_full_get_segment_speaker_turn_next(context, index);
518
+ }
519
+
520
+ JNIEXPORT jstring JNICALL
521
+ Java_com_rnwhisper_WhisperContext_bench(
522
+ JNIEnv *env,
523
+ jobject thiz,
524
+ jlong context_ptr,
525
+ jint n_threads
526
+ ) {
527
+ UNUSED(thiz);
528
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
529
+ std::string result = rnwhisper::bench(context, n_threads);
530
+ return env->NewStringUTF(result.c_str());
531
+ }
532
+
416
533
  } // extern "C"
@@ -47,6 +47,11 @@ public class RNWhisperModule extends NativeRNWhisperSpec {
47
47
  rnwhisper.transcribeFile(id, jobId, filePath, options, promise);
48
48
  }
49
49
 
50
+ @ReactMethod
51
+ public void transcribeData(double id, double jobId, String dataBase64, ReadableMap options, Promise promise) {
52
+ rnwhisper.transcribeData(id, jobId, dataBase64, options, promise);
53
+ }
54
+
50
55
  @ReactMethod
51
56
  public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) {
52
57
  rnwhisper.startRealtimeTranscribe(id, jobId, options, promise);
@@ -57,6 +62,11 @@ public class RNWhisperModule extends NativeRNWhisperSpec {
57
62
  rnwhisper.abortTranscribe(contextId, jobId, promise);
58
63
  }
59
64
 
65
+ @ReactMethod
66
+ public void bench(double id, double nThreads, Promise promise) {
67
+ rnwhisper.bench(id, nThreads, promise);
68
+ }
69
+
60
70
  @ReactMethod
61
71
  public void releaseContext(double id, Promise promise) {
62
72
  rnwhisper.releaseContext(id, promise);
@@ -47,6 +47,11 @@ public class RNWhisperModule extends ReactContextBaseJavaModule {
47
47
  rnwhisper.transcribeFile(id, jobId, filePath, options, promise);
48
48
  }
49
49
 
50
+ @ReactMethod
51
+ public void transcribeData(double id, double jobId, String dataBase64, ReadableMap options, Promise promise) {
52
+ rnwhisper.transcribeData(id, jobId, dataBase64, options, promise);
53
+ }
54
+
50
55
  @ReactMethod
51
56
  public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) {
52
57
  rnwhisper.startRealtimeTranscribe(id, jobId, options, promise);
@@ -57,6 +62,11 @@ public class RNWhisperModule extends ReactContextBaseJavaModule {
57
62
  rnwhisper.abortTranscribe(contextId, jobId, promise);
58
63
  }
59
64
 
65
+ @ReactMethod
66
+ public void bench(double id, double nThreads, Promise promise) {
67
+ rnwhisper.bench(id, nThreads, promise);
68
+ }
69
+
60
70
  @ReactMethod
61
71
  public void releaseContext(double id, Promise promise) {
62
72
  rnwhisper.releaseContext(id, promise);
@@ -123,7 +123,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
123
123
 
124
124
  /**
125
125
  Make a prediction using the convenience interface
126
- @param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
126
+ @param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
127
127
  @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
128
128
  @return the prediction as whisper_encoder_implOutput
129
129
  */
@@ -3,6 +3,8 @@
3
3
  // Code is derived from the work of Github user @wangchou
4
4
  // ref: https://github.com/wangchou/callCoreMLFromCpp
5
5
 
6
+ #include <stdint.h>
7
+
6
8
  #if __cplusplus
7
9
  extern "C" {
8
10
  #endif
@@ -14,6 +16,8 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx);
14
16
 
15
17
  void whisper_coreml_encode(
16
18
  const whisper_coreml_context * ctx,
19
+ int64_t n_ctx,
20
+ int64_t n_mel,
17
21
  float * mel,
18
22
  float * out);
19
23
 
@@ -24,7 +24,7 @@ struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
24
24
 
25
25
  // select which device to run the Core ML model on
26
26
  MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
27
- //config.computeUnits = MLComputeUnitsCPUAndGPU;
27
+ // config.computeUnits = MLComputeUnitsCPUAndGPU;
28
28
  //config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;
29
29
  config.computeUnits = MLComputeUnitsAll;
30
30
 
@@ -48,13 +48,15 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx) {
48
48
 
49
49
  void whisper_coreml_encode(
50
50
  const whisper_coreml_context * ctx,
51
+ int64_t n_ctx,
52
+ int64_t n_mel,
51
53
  float * mel,
52
54
  float * out) {
53
55
  MLMultiArray * inMultiArray = [
54
56
  [MLMultiArray alloc] initWithDataPointer: mel
55
- shape: @[@1, @80, @3000]
57
+ shape: @[@1, @(n_mel), @(n_ctx)]
56
58
  dataType: MLMultiArrayDataTypeFloat32
57
- strides: @[@(240000), @(3000), @1]
59
+ strides: @[@(n_ctx*n_mel), @(n_ctx), @1]
58
60
  deallocator: nil
59
61
  error: nil
60
62
  ];