whisper.rn 0.4.0-rc.3 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +7 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -135
  7. package/android/src/main/jni-utils.h +76 -0
  8. package/android/src/main/jni.cpp +188 -109
  9. package/cpp/README.md +1 -1
  10. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  11. package/cpp/coreml/whisper-encoder.h +4 -0
  12. package/cpp/coreml/whisper-encoder.mm +4 -2
  13. package/cpp/ggml-alloc.c +451 -282
  14. package/cpp/ggml-alloc.h +74 -8
  15. package/cpp/ggml-backend-impl.h +112 -0
  16. package/cpp/ggml-backend.c +1357 -0
  17. package/cpp/ggml-backend.h +181 -0
  18. package/cpp/ggml-impl.h +243 -0
  19. package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +1556 -329
  20. package/cpp/ggml-metal.h +28 -1
  21. package/cpp/ggml-metal.m +1128 -308
  22. package/cpp/ggml-quants.c +7382 -0
  23. package/cpp/ggml-quants.h +224 -0
  24. package/cpp/ggml.c +3848 -5245
  25. package/cpp/ggml.h +353 -155
  26. package/cpp/rn-audioutils.cpp +68 -0
  27. package/cpp/rn-audioutils.h +14 -0
  28. package/cpp/rn-whisper-log.h +11 -0
  29. package/cpp/rn-whisper.cpp +141 -59
  30. package/cpp/rn-whisper.h +47 -15
  31. package/cpp/whisper.cpp +1750 -964
  32. package/cpp/whisper.h +97 -15
  33. package/ios/RNWhisper.mm +15 -9
  34. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
  35. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  36. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  37. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
  38. package/ios/RNWhisperAudioUtils.h +0 -2
  39. package/ios/RNWhisperAudioUtils.m +0 -56
  40. package/ios/RNWhisperContext.h +8 -12
  41. package/ios/RNWhisperContext.mm +132 -138
  42. package/jest/mock.js +1 -1
  43. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  44. package/lib/commonjs/index.js +28 -9
  45. package/lib/commonjs/index.js.map +1 -1
  46. package/lib/commonjs/version.json +1 -1
  47. package/lib/module/NativeRNWhisper.js.map +1 -1
  48. package/lib/module/index.js +28 -9
  49. package/lib/module/index.js.map +1 -1
  50. package/lib/module/version.json +1 -1
  51. package/lib/typescript/NativeRNWhisper.d.ts +7 -1
  52. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  53. package/lib/typescript/index.d.ts +7 -2
  54. package/lib/typescript/index.d.ts.map +1 -1
  55. package/package.json +6 -5
  56. package/src/NativeRNWhisper.ts +8 -1
  57. package/src/index.ts +29 -17
  58. package/src/version.json +1 -1
  59. package/whisper-rn.podspec +1 -2
@@ -10,6 +10,7 @@
10
10
  #include "whisper.h"
11
11
  #include "rn-whisper.h"
12
12
  #include "ggml.h"
13
+ #include "jni-utils.h"
13
14
 
14
15
  #define UNUSED(x) (void)(x)
15
16
  #define TAG "JNI"
@@ -96,7 +97,8 @@ static void input_stream_close(void *ctx) {
96
97
 
97
98
  static struct whisper_context *whisper_init_from_input_stream(
98
99
  JNIEnv *env,
99
- jobject input_stream // PushbackInputStream
100
+ jobject input_stream, // PushbackInputStream
101
+ struct whisper_context_params cparams
100
102
  ) {
101
103
  input_stream_context *context = new input_stream_context;
102
104
  context->env = env;
@@ -108,7 +110,7 @@ static struct whisper_context *whisper_init_from_input_stream(
108
110
  .eof = &input_stream_is_eof,
109
111
  .close = &input_stream_close
110
112
  };
111
- return whisper_init(&loader);
113
+ return whisper_init_with_params(&loader, cparams);
112
114
  }
113
115
 
114
116
  // Load model from asset
@@ -127,7 +129,8 @@ static void asset_close(void *ctx) {
127
129
  static struct whisper_context *whisper_init_from_asset(
128
130
  JNIEnv *env,
129
131
  jobject assetManager,
130
- const char *asset_path
132
+ const char *asset_path,
133
+ struct whisper_context_params cparams
131
134
  ) {
132
135
  LOGI("Loading model from asset '%s'\n", asset_path);
133
136
  AAssetManager *asset_manager = AAssetManager_fromJava(env, assetManager);
@@ -142,7 +145,7 @@ static struct whisper_context *whisper_init_from_asset(
142
145
  .eof = &asset_is_eof,
143
146
  .close = &asset_close
144
147
  };
145
- return whisper_init(&loader);
148
+ return whisper_init_with_params(&loader, cparams);
146
149
  }
147
150
 
148
151
  extern "C" {
@@ -151,9 +154,10 @@ JNIEXPORT jlong JNICALL
151
154
  Java_com_rnwhisper_WhisperContext_initContext(
152
155
  JNIEnv *env, jobject thiz, jstring model_path_str) {
153
156
  UNUSED(thiz);
157
+ struct whisper_context_params cparams;
154
158
  struct whisper_context *context = nullptr;
155
159
  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
156
- context = whisper_init_from_file(model_path_chars);
160
+ context = whisper_init_from_file_with_params(model_path_chars, cparams);
157
161
  env->ReleaseStringUTFChars(model_path_str, model_path_chars);
158
162
  return reinterpret_cast<jlong>(context);
159
163
  }
@@ -166,9 +170,10 @@ Java_com_rnwhisper_WhisperContext_initContextWithAsset(
166
170
  jstring model_path_str
167
171
  ) {
168
172
  UNUSED(thiz);
173
+ struct whisper_context_params cparams;
169
174
  struct whisper_context *context = nullptr;
170
175
  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
171
- context = whisper_init_from_asset(env, asset_manager, model_path_chars);
176
+ context = whisper_init_from_asset(env, asset_manager, model_path_chars, cparams);
172
177
  env->ReleaseStringUTFChars(model_path_str, model_path_chars);
173
178
  return reinterpret_cast<jlong>(context);
174
179
  }
@@ -180,30 +185,65 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream(
180
185
  jobject input_stream
181
186
  ) {
182
187
  UNUSED(thiz);
188
+ struct whisper_context_params cparams;
183
189
  struct whisper_context *context = nullptr;
184
- context = whisper_init_from_input_stream(env, input_stream);
190
+ context = whisper_init_from_input_stream(env, input_stream, cparams);
185
191
  return reinterpret_cast<jlong>(context);
186
192
  }
187
193
 
188
- JNIEXPORT jboolean JNICALL
189
- Java_com_rnwhisper_WhisperContext_vadSimple(
190
- JNIEnv *env,
191
- jobject thiz,
192
- jfloatArray audio_data,
193
- jint audio_data_len,
194
- jfloat vad_thold,
195
- jfloat vad_freq_thold
196
- ) {
197
- UNUSED(thiz);
198
194
 
199
- std::vector<float> samples(audio_data_len);
200
- jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);
201
- for (int i = 0; i < audio_data_len; i++) {
202
- samples[i] = audio_data_arr[i];
195
+ struct whisper_full_params createFullParams(JNIEnv *env, jobject options) {
196
+ struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
197
+
198
+ params.print_realtime = false;
199
+ params.print_progress = false;
200
+ params.print_timestamps = false;
201
+ params.print_special = false;
202
+
203
+ int max_threads = std::thread::hardware_concurrency();
204
+ // Use 2 threads by default on 4-core devices, 4 threads on more cores
205
+ int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
206
+ int n_threads = readablemap::getInt(env, options, "maxThreads", default_n_threads);
207
+ params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
208
+ params.translate = readablemap::getBool(env, options, "translate", false);
209
+ params.speed_up = readablemap::getBool(env, options, "speedUp", false);
210
+ params.token_timestamps = readablemap::getBool(env, options, "tokenTimestamps", false);
211
+ params.offset_ms = 0;
212
+ params.no_context = true;
213
+ params.single_segment = false;
214
+
215
+ int beam_size = readablemap::getInt(env, options, "beamSize", -1);
216
+ if (beam_size > -1) {
217
+ params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
218
+ params.beam_search.beam_size = beam_size;
203
219
  }
204
- bool is_speech = rn_whisper_vad_simple(samples, WHISPER_SAMPLE_RATE, 1000, vad_thold, vad_freq_thold, false);
205
- env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
206
- return is_speech;
220
+ int best_of = readablemap::getInt(env, options, "bestOf", -1);
221
+ if (best_of > -1) params.greedy.best_of = best_of;
222
+ int max_len = readablemap::getInt(env, options, "maxLen", -1);
223
+ if (max_len > -1) params.max_len = max_len;
224
+ int max_context = readablemap::getInt(env, options, "maxContext", -1);
225
+ if (max_context > -1) params.n_max_text_ctx = max_context;
226
+ int offset = readablemap::getInt(env, options, "offset", -1);
227
+ if (offset > -1) params.offset_ms = offset;
228
+ int duration = readablemap::getInt(env, options, "duration", -1);
229
+ if (duration > -1) params.duration_ms = duration;
230
+ int word_thold = readablemap::getInt(env, options, "wordThold", -1);
231
+ if (word_thold > -1) params.thold_pt = word_thold;
232
+ float temperature = readablemap::getFloat(env, options, "temperature", -1);
233
+ if (temperature > -1) params.temperature = temperature;
234
+ float temperature_inc = readablemap::getFloat(env, options, "temperatureInc", -1);
235
+ if (temperature_inc > -1) params.temperature_inc = temperature_inc;
236
+ jstring prompt = readablemap::getString(env, options, "prompt", nullptr);
237
+ if (prompt != nullptr) {
238
+ params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
239
+ env->DeleteLocalRef(prompt);
240
+ }
241
+ jstring language = readablemap::getString(env, options, "language", nullptr);
242
+ if (language != nullptr) {
243
+ params.language = env->GetStringUTFChars(language, nullptr);
244
+ env->DeleteLocalRef(language);
245
+ }
246
+ return params;
207
247
  }
208
248
 
209
249
  struct callback_context {
@@ -212,101 +252,23 @@ struct callback_context {
212
252
  };
213
253
 
214
254
  JNIEXPORT jint JNICALL
215
- Java_com_rnwhisper_WhisperContext_fullTranscribe(
255
+ Java_com_rnwhisper_WhisperContext_fullWithNewJob(
216
256
  JNIEnv *env,
217
257
  jobject thiz,
218
258
  jint job_id,
219
259
  jlong context_ptr,
220
260
  jfloatArray audio_data,
221
261
  jint audio_data_len,
222
- jint n_threads,
223
- jint max_context,
224
- int word_thold,
225
- int max_len,
226
- jboolean token_timestamps,
227
- jint offset,
228
- jint duration,
229
- jfloat temperature,
230
- jfloat temperature_inc,
231
- jint beam_size,
232
- jint best_of,
233
- jboolean speed_up,
234
- jboolean translate,
235
- jstring language,
236
- jstring prompt,
262
+ jobject options,
237
263
  jobject callback_instance
238
264
  ) {
239
265
  UNUSED(thiz);
240
266
  struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
241
267
  jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);
242
268
 
243
- int max_threads = std::thread::hardware_concurrency();
244
- // Use 2 threads by default on 4-core devices, 4 threads on more cores
245
- int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
246
-
247
269
  LOGI("About to create params");
248
270
 
249
- struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
250
-
251
- if (beam_size > -1) {
252
- params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
253
- params.beam_search.beam_size = beam_size;
254
- }
255
-
256
- params.print_realtime = false;
257
- params.print_progress = false;
258
- params.print_timestamps = false;
259
- params.print_special = false;
260
- params.translate = translate;
261
- const char *language_chars = env->GetStringUTFChars(language, nullptr);
262
- params.language = language_chars;
263
- params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
264
- params.speed_up = speed_up;
265
- params.offset_ms = 0;
266
- params.no_context = true;
267
- params.single_segment = false;
268
-
269
- if (max_len > -1) {
270
- params.max_len = max_len;
271
- }
272
- params.token_timestamps = token_timestamps;
273
-
274
- if (best_of > -1) {
275
- params.greedy.best_of = best_of;
276
- }
277
- if (max_context > -1) {
278
- params.n_max_text_ctx = max_context;
279
- }
280
- if (offset > -1) {
281
- params.offset_ms = offset;
282
- }
283
- if (duration > -1) {
284
- params.duration_ms = duration;
285
- }
286
- if (word_thold > -1) {
287
- params.thold_pt = word_thold;
288
- }
289
- if (temperature > -1) {
290
- params.temperature = temperature;
291
- }
292
- if (temperature_inc > -1) {
293
- params.temperature_inc = temperature_inc;
294
- }
295
- if (prompt != nullptr) {
296
- params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
297
- }
298
-
299
- // abort handlers
300
- params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
301
- bool is_aborted = *(bool*)user_data;
302
- return !is_aborted;
303
- };
304
- params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(job_id);
305
- params.abort_callback = [](void * user_data) {
306
- bool is_aborted = *(bool*)user_data;
307
- return is_aborted;
308
- };
309
- params.abort_callback_user_data = rn_whisper_assign_abort_map(job_id);
271
+ whisper_full_params params = createFullParams(env, options);
310
272
 
311
273
  if (callback_instance != nullptr) {
312
274
  callback_context *cb_ctx = new callback_context;
@@ -334,6 +296,8 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
334
296
  params.new_segment_callback_user_data = cb_ctx;
335
297
  }
336
298
 
299
+ rnwhisper::job* job = rnwhisper::job_new(job_id, params);
300
+
337
301
  LOGI("About to reset timings");
338
302
  whisper_reset_timings(context);
339
303
 
@@ -343,8 +307,122 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
343
307
  // whisper_print_timings(context);
344
308
  }
345
309
  env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
346
- env->ReleaseStringUTFChars(language, language_chars);
347
- rn_whisper_remove_abort_map(job_id);
310
+
311
+ if (job->is_aborted()) code = -999;
312
+ rnwhisper::job_remove(job_id);
313
+ return code;
314
+ }
315
+
316
+ JNIEXPORT void JNICALL
317
+ Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob(
318
+ JNIEnv *env,
319
+ jobject thiz,
320
+ jint job_id,
321
+ jlong context_ptr,
322
+ jobject options
323
+ ) {
324
+ whisper_full_params params = createFullParams(env, options);
325
+ rnwhisper::job* job = rnwhisper::job_new(job_id, params);
326
+ rnwhisper::vad_params vad;
327
+ vad.use_vad = readablemap::getBool(env, options, "useVad", false);
328
+ vad.vad_ms = readablemap::getInt(env, options, "vadMs", 2000);
329
+ vad.vad_thold = readablemap::getFloat(env, options, "vadThold", 0.6f);
330
+ vad.freq_thold = readablemap::getFloat(env, options, "vadFreqThold", 100.0f);
331
+
332
+ jstring audio_output_path = readablemap::getString(env, options, "audioOutputPath", nullptr);
333
+ const char* audio_output_path_str = nullptr;
334
+ if (audio_output_path != nullptr) {
335
+ audio_output_path_str = env->GetStringUTFChars(audio_output_path, nullptr);
336
+ env->DeleteLocalRef(audio_output_path);
337
+ }
338
+ job->set_realtime_params(
339
+ vad,
340
+ readablemap::getInt(env, options, "realtimeAudioSec", 0),
341
+ readablemap::getInt(env, options, "realtimeAudioSliceSec", 0),
342
+ audio_output_path_str
343
+ );
344
+ }
345
+
346
+ JNIEXPORT void JNICALL
347
+ Java_com_rnwhisper_WhisperContext_finishRealtimeTranscribeJob(
348
+ JNIEnv *env,
349
+ jobject thiz,
350
+ jint job_id,
351
+ jlong context_ptr,
352
+ jintArray slice_n_samples
353
+ ) {
354
+ UNUSED(env);
355
+ UNUSED(thiz);
356
+ UNUSED(context_ptr);
357
+
358
+ rnwhisper::job *job = rnwhisper::job_get(job_id);
359
+ if (job->audio_output_path != nullptr) {
360
+ RNWHISPER_LOG_INFO("job->params.language: %s\n", job->params.language);
361
+ std::vector<int> slice_n_samples_vec;
362
+ jint *slice_n_samples_arr = env->GetIntArrayElements(slice_n_samples, nullptr);
363
+ slice_n_samples_vec = std::vector<int>(slice_n_samples_arr, slice_n_samples_arr + env->GetArrayLength(slice_n_samples));
364
+ env->ReleaseIntArrayElements(slice_n_samples, slice_n_samples_arr, JNI_ABORT);
365
+
366
+ // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage
367
+ rnaudioutils::save_wav_file(
368
+ rnaudioutils::concat_short_buffers(job->pcm_slices, slice_n_samples_vec),
369
+ job->audio_output_path
370
+ );
371
+ }
372
+ rnwhisper::job_remove(job_id);
373
+ }
374
+
375
+ JNIEXPORT jboolean JNICALL
376
+ Java_com_rnwhisper_WhisperContext_vadSimple(
377
+ JNIEnv *env,
378
+ jobject thiz,
379
+ jint job_id,
380
+ jint slice_index,
381
+ jint n_samples,
382
+ jint n
383
+ ) {
384
+ UNUSED(thiz);
385
+ rnwhisper::job* job = rnwhisper::job_get(job_id);
386
+ return job->vad_simple(slice_index, n_samples, n);
387
+ }
388
+
389
+ JNIEXPORT void JNICALL
390
+ Java_com_rnwhisper_WhisperContext_putPcmData(
391
+ JNIEnv *env,
392
+ jobject thiz,
393
+ jint job_id,
394
+ jshortArray pcm,
395
+ jint slice_index,
396
+ jint n_samples,
397
+ jint n
398
+ ) {
399
+ UNUSED(thiz);
400
+ rnwhisper::job* job = rnwhisper::job_get(job_id);
401
+ jshort *pcm_arr = env->GetShortArrayElements(pcm, nullptr);
402
+ job->put_pcm_data(pcm_arr, slice_index, n_samples, n);
403
+ env->ReleaseShortArrayElements(pcm, pcm_arr, JNI_ABORT);
404
+ }
405
+
406
+ JNIEXPORT jint JNICALL
407
+ Java_com_rnwhisper_WhisperContext_fullWithJob(
408
+ JNIEnv *env,
409
+ jobject thiz,
410
+ jint job_id,
411
+ jlong context_ptr,
412
+ jint slice_index,
413
+ jint n_samples
414
+ ) {
415
+ UNUSED(thiz);
416
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
417
+
418
+ rnwhisper::job* job = rnwhisper::job_get(job_id);
419
+ float* pcmf32 = job->pcm_slice_to_f32(slice_index, n_samples);
420
+ int code = whisper_full(context, job->params, pcmf32, n_samples);
421
+ free(pcmf32);
422
+ if (code == 0) {
423
+ // whisper_print_timings(context);
424
+ }
425
+ if (job->is_aborted()) code = -999;
348
426
  return code;
349
427
  }
350
428
 
@@ -355,7 +433,8 @@ Java_com_rnwhisper_WhisperContext_abortTranscribe(
355
433
  jint job_id
356
434
  ) {
357
435
  UNUSED(thiz);
358
- rn_whisper_abort_transcribe(job_id);
436
+ rnwhisper::job *job = rnwhisper::job_get(job_id);
437
+ if (job) job->abort();
359
438
  }
360
439
 
361
440
  JNIEXPORT void JNICALL
@@ -364,7 +443,7 @@ Java_com_rnwhisper_WhisperContext_abortAllTranscribe(
364
443
  jobject thiz
365
444
  ) {
366
445
  UNUSED(thiz);
367
- rn_whisper_abort_all_transcribe();
446
+ rnwhisper::job_abort_all();
368
447
  }
369
448
 
370
449
  JNIEXPORT jint JNICALL
package/cpp/README.md CHANGED
@@ -1,4 +1,4 @@
1
1
  # Note
2
2
 
3
- - Only `rn-whisper.h` / `rn-whisper.cpp` are the specific files for this project, others are sync from [whisper.cpp](https://github.com/ggerganov/whisper.cpp).
3
+ - Only `rn-*` are the specific files for this project, others are sync from [whisper.cpp](https://github.com/ggerganov/whisper.cpp).
4
4
  - We can update the native source by using the [bootstrap](../scripts/bootstrap.sh) script.
@@ -123,7 +123,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
123
123
 
124
124
  /**
125
125
  Make a prediction using the convenience interface
126
- @param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
126
+ @param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
127
127
  @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
128
128
  @return the prediction as whisper_encoder_implOutput
129
129
  */
@@ -3,6 +3,8 @@
3
3
  // Code is derived from the work of Github user @wangchou
4
4
  // ref: https://github.com/wangchou/callCoreMLFromCpp
5
5
 
6
+ #include <stdint.h>
7
+
6
8
  #if __cplusplus
7
9
  extern "C" {
8
10
  #endif
@@ -14,6 +16,8 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx);
14
16
 
15
17
  void whisper_coreml_encode(
16
18
  const whisper_coreml_context * ctx,
19
+ int64_t n_ctx,
20
+ int64_t n_mel,
17
21
  float * mel,
18
22
  float * out);
19
23
 
@@ -48,13 +48,15 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx) {
48
48
 
49
49
  void whisper_coreml_encode(
50
50
  const whisper_coreml_context * ctx,
51
+ int64_t n_ctx,
52
+ int64_t n_mel,
51
53
  float * mel,
52
54
  float * out) {
53
55
  MLMultiArray * inMultiArray = [
54
56
  [MLMultiArray alloc] initWithDataPointer: mel
55
- shape: @[@1, @80, @3000]
57
+ shape: @[@1, @(n_mel), @(n_ctx)]
56
58
  dataType: MLMultiArrayDataTypeFloat32
57
- strides: @[@(240000), @(3000), @1]
59
+ strides: @[@(n_ctx*n_mel), @(n_ctx), @1]
58
60
  deallocator: nil
59
61
  error: nil
60
62
  ];