whisper.rn 0.4.0-rc.7 → 0.4.0-rc.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. package/android/src/main/CMakeLists.txt +2 -1
  2. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
  3. package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
  4. package/android/src/main/java/com/rnwhisper/WhisperContext.java +20 -3
  5. package/android/src/main/jni.cpp +29 -1
  6. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  7. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  8. package/cpp/coreml/whisper-encoder.mm +1 -1
  9. package/cpp/ggml-aarch64.c +3209 -0
  10. package/cpp/ggml-aarch64.h +39 -0
  11. package/cpp/ggml-alloc.c +732 -494
  12. package/cpp/ggml-alloc.h +47 -63
  13. package/cpp/ggml-backend-impl.h +162 -47
  14. package/cpp/ggml-backend.cpp +2635 -0
  15. package/cpp/ggml-backend.h +216 -71
  16. package/cpp/ggml-common.h +1853 -0
  17. package/cpp/ggml-cpu-impl.h +614 -0
  18. package/cpp/ggml-impl.h +144 -178
  19. package/cpp/ggml-metal.h +14 -60
  20. package/cpp/ggml-metal.m +3437 -2097
  21. package/cpp/ggml-quants.c +12559 -4189
  22. package/cpp/ggml-quants.h +135 -212
  23. package/cpp/ggml-whisper.metallib +0 -0
  24. package/cpp/ggml.c +9029 -5219
  25. package/cpp/ggml.h +673 -338
  26. package/cpp/rn-whisper.cpp +91 -0
  27. package/cpp/rn-whisper.h +2 -0
  28. package/cpp/whisper.cpp +1476 -675
  29. package/cpp/whisper.h +84 -28
  30. package/ios/RNWhisper.mm +124 -37
  31. package/ios/RNWhisperAudioUtils.h +1 -0
  32. package/ios/RNWhisperAudioUtils.m +20 -13
  33. package/ios/RNWhisperContext.h +3 -2
  34. package/ios/RNWhisperContext.mm +41 -8
  35. package/jest/mock.js +9 -1
  36. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  37. package/lib/commonjs/index.js +48 -19
  38. package/lib/commonjs/index.js.map +1 -1
  39. package/lib/commonjs/version.json +1 -1
  40. package/lib/module/NativeRNWhisper.js.map +1 -1
  41. package/lib/module/index.js +48 -19
  42. package/lib/module/index.js.map +1 -1
  43. package/lib/module/version.json +1 -1
  44. package/lib/typescript/NativeRNWhisper.d.ts +6 -3
  45. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  46. package/lib/typescript/index.d.ts +25 -3
  47. package/lib/typescript/index.d.ts.map +1 -1
  48. package/package.json +6 -5
  49. package/src/NativeRNWhisper.ts +12 -3
  50. package/src/index.ts +63 -24
  51. package/src/version.json +1 -1
  52. package/whisper-rn.podspec +9 -2
  53. package/cpp/ggml-backend.c +0 -1357
  54. package/cpp/ggml-metal-whisper.metal +0 -4908
package/cpp/whisper.h CHANGED
@@ -84,9 +84,48 @@ extern "C" {
     typedef int32_t whisper_token;
     typedef int32_t whisper_seq_id;

+    enum whisper_alignment_heads_preset {
+        WHISPER_AHEADS_NONE,
+        WHISPER_AHEADS_N_TOP_MOST,  // All heads from the N-top-most text-layers
+        WHISPER_AHEADS_CUSTOM,
+        WHISPER_AHEADS_TINY_EN,
+        WHISPER_AHEADS_TINY,
+        WHISPER_AHEADS_BASE_EN,
+        WHISPER_AHEADS_BASE,
+        WHISPER_AHEADS_SMALL_EN,
+        WHISPER_AHEADS_SMALL,
+        WHISPER_AHEADS_MEDIUM_EN,
+        WHISPER_AHEADS_MEDIUM,
+        WHISPER_AHEADS_LARGE_V1,
+        WHISPER_AHEADS_LARGE_V2,
+        WHISPER_AHEADS_LARGE_V3,
+        WHISPER_AHEADS_LARGE_V3_TURBO,
+    };
+
+    typedef struct whisper_ahead {
+        int n_text_layer;
+        int n_head;
+    } whisper_ahead;
+
+    typedef struct whisper_aheads {
+        size_t n_heads;
+        const whisper_ahead * heads;
+    } whisper_aheads;
+
     struct whisper_context_params {
         bool use_gpu;
         bool use_coreml;
+        bool flash_attn;
+        int  gpu_device;  // CUDA device
+
+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        bool dtw_token_timestamps;
+        enum whisper_alignment_heads_preset dtw_aheads_preset;
+
+        int dtw_n_top;
+        struct whisper_aheads dtw_aheads;
+
+        size_t dtw_mem_size; // TODO: remove
     };

     typedef struct whisper_token_data {
@@ -103,6 +142,11 @@ extern "C" {
         int64_t t0;        // start time of the token
         int64_t t1;        // end time of the token

+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        // do not use if you haven't computed token-level timestamps with dtw
+        // Roughly corresponds to the moment in audio in which the token was output
+        int64_t t_dtw;
+
         float vlen;        // voice length of the token
     } whisper_token_data;

@@ -196,6 +240,13 @@ extern "C" {
     // GPU, by caching compiled 'blobs' there.
     // Set to nullptr if not used.
     // Returns 0 on success. If OpenVINO is not enabled in build, this simply returns 1.
+    WHISPER_API int whisper_ctx_init_openvino_encoder_with_state(
+        struct whisper_context * ctx,
+        struct whisper_state * state,
+        const char * model_path,
+        const char * device,
+        const char * cache_dir);
+
     WHISPER_API int whisper_ctx_init_openvino_encoder(
         struct whisper_context * ctx,
         const char * model_path,
@@ -224,22 +275,6 @@ extern "C" {
             int n_samples,
             int n_threads);

-    // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
-    // The resulting spectrogram is stored inside the default state of the provided whisper context.
-    // Returns 0 on success
-    WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
-        struct whisper_context * ctx,
-        const float * samples,
-            int n_samples,
-            int n_threads);
-
-    WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
-        struct whisper_context * ctx,
-        struct whisper_state * state,
-        const float * samples,
-            int n_samples,
-            int n_threads);
-
     // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
     // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
     // n_mel must be 80
@@ -296,7 +331,7 @@ extern "C" {
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
     // Returns the number of tokens on success, no more than n_max_tokens
-    // Returns -1 on failure
+    // Returns a negative number on failure - the number of tokens that would have been returned
     // TODO: not sure if correct
     WHISPER_API int whisper_tokenize(
         struct whisper_context * ctx,
@@ -304,8 +339,12 @@ extern "C" {
         whisper_token * tokens,
         int n_max_tokens);

+    // Return the number of tokens in the provided text
+    // Equivalent to: -whisper_tokenize(ctx, text, NULL, 0)
+    int whisper_token_count(struct whisper_context * ctx, const char * text);
+
     // Largest language id (i.e. number of available languages - 1)
-    WHISPER_API int whisper_lang_max_id();
+    WHISPER_API int whisper_lang_max_id(void);

     // Return the id of the specified language, returns -1 if not found
     // Examples:
@@ -385,6 +424,24 @@ extern "C" {
     WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);

     // Performance information from the default state.
+    struct whisper_timings {
+        int64_t load_us;
+        int64_t t_start_us;
+        int32_t fail_p;
+        int32_t fail_h;
+        int64_t t_mel_us;
+        int32_t n_sample;
+        int32_t n_encode;
+        int32_t n_decode;
+        int32_t n_batchd;
+        int32_t n_prompt;
+        int64_t t_sample_us;
+        int64_t t_encode_us;
+        int64_t t_decode_us;
+        int64_t t_batchd_us;
+        int64_t t_prompt_us;
+    };
+    WHISPER_API struct whisper_timings * whisper_get_timings(struct whisper_context * ctx);
     WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
     WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);

@@ -412,11 +469,6 @@ extern "C" {
     // If it returns false, the computation is aborted
     typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);

-    // Abort callback
-    // If not NULL, called before ggml computation
-    // If it returns true, the computation is aborted
-    typedef bool (*whisper_abort_callback)(void * user_data);
-
     // Logits filter callback
     // Can be used to modify the logits before sampling
     // If not NULL, called after applying temperature to logits
@@ -458,15 +510,19 @@ extern "C" {

         // [EXPERIMENTAL] speed-up techniques
         // note: these can significantly reduce the quality of the output
-        bool speed_up;          // speed-up the audio by 2x using Phase Vocoder
         bool debug_mode;        // enable debug_mode provides extra info (eg. Dump log_mel)
         int  audio_ctx;         // overwrite the audio context size (0 = use default)

         // [EXPERIMENTAL] [TDRZ] tinydiarize
         bool tdrz_enable;       // enable tinydiarize speaker turn detection

+        // A regular expression that matches tokens to suppress
+        const char * suppress_regex;
+
         // tokens to provide to the whisper decoder as initial prompt
         // these are prepended to any existing text context from a previous call
+        // use whisper_tokenize() to convert text to tokens
+        // maximum of whisper_n_text_ctx()/2 tokens are used (typically 224)
         const char * initial_prompt;
         const whisper_token * prompt_tokens;
         int prompt_n_tokens;
@@ -513,7 +569,7 @@ extern "C" {
         void * encoder_begin_callback_user_data;

         // called each time before ggml computation starts
-        whisper_abort_callback abort_callback;
+        wsp_ggml_abort_callback abort_callback;
         void * abort_callback_user_data;

         // called by each decoder to filter obtained logits
@@ -527,10 +583,10 @@ extern "C" {
     };

     // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
-    WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref();
-    WHISPER_API struct whisper_context_params whisper_context_default_params(void);
+    WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref(void);
+    WHISPER_API struct whisper_context_params   whisper_context_default_params       (void);
     WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
-    WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
+    WHISPER_API struct whisper_full_params   whisper_full_default_params       (enum whisper_sampling_strategy strategy);

     // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
     // Not thread safe for same context
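At the JavaScript layer, the most visible of these whisper_context_params additions is flash_attn: the RNWhisper.mm diff below reads a useFlashAttn model option and forwards it when creating the context. A minimal TypeScript sketch of enabling it, assuming the option name exposed by package/src/index.ts matches the modelOptions key read on iOS:

import { initWhisper } from 'whisper.rn'

async function createContext() {
  // `useFlashAttn` (assumed JS option name, mirroring the modelOptions key
  // read in RNWhisper.mm below) maps to whisper_context_params.flash_attn.
  return initWhisper({
    filePath: 'ggml-tiny.en.bin', // placeholder path to a ggml model on device
    useFlashAttn: true,
  })
}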
package/ios/RNWhisper.mm CHANGED
@@ -50,6 +50,7 @@ RCT_REMAP_METHOD(initContext,
     BOOL isBundleAsset = [[modelOptions objectForKey:@"isBundleAsset"] boolValue];
     BOOL useGpu = [[modelOptions objectForKey:@"useGpu"] boolValue];
     BOOL useCoreMLIos = [[modelOptions objectForKey:@"useCoreMLIos"] boolValue];
+    BOOL useFlashAttn = [[modelOptions objectForKey:@"useFlashAttn"] boolValue];

     // For support debug assets in development mode
     BOOL downloadCoreMLAssets = [[modelOptions objectForKey:@"downloadCoreMLAssets"] boolValue];
@@ -79,6 +80,7 @@ RCT_REMAP_METHOD(initContext,
         contextId:contextId
         noCoreML:!useCoreMLIos
         noMetal:!useGpu
+        useFlashAttn:useFlashAttn
     ];
     if ([context getContext] == NULL) {
         reject(@"whisper_cpp_error", @"Failed to load the model", nil);
@@ -103,42 +105,17 @@
     ];
 }

-RCT_REMAP_METHOD(transcribeFile,
-    withContextId:(int)contextId
-    withJobId:(int)jobId
-    withWaveFile:(NSString *)waveFilePath
-    withOptions:(NSDictionary *)options
-    withResolver:(RCTPromiseResolveBlock)resolve
-    withRejecter:(RCTPromiseRejectBlock)reject)
+- (void)transcribeData:(RNWhisperContext *)context
+    withContextId:(int)contextId
+    withJobId:(int)jobId
+    withData:(float *)data
+    withDataCount:(int)count
+    withOptions:(NSDictionary *)options
+    withResolver:(RCTPromiseResolveBlock)resolve
+    withRejecter:(RCTPromiseRejectBlock)reject
 {
-    RNWhisperContext *context = contexts[[NSNumber numberWithInt:contextId]];
-
-    if (context == nil) {
-        reject(@"whisper_error", @"Context not found", nil);
-        return;
-    }
-    if ([context isCapturing]) {
-        reject(@"whisper_error", @"The context is in realtime transcribe mode", nil);
-        return;
-    }
-    if ([context isTranscribing]) {
-        reject(@"whisper_error", @"Context is already transcribing", nil);
-        return;
-    }
-
-    NSString *path = waveFilePath;
-    if ([path hasPrefix:@"http://"] || [path hasPrefix:@"https://"]) {
-        path = [RNWhisperDownloader downloadFile:path toFile:nil];
-    }
-
-    int count = 0;
-    float *waveFile = [RNWhisperAudioUtils decodeWaveFile:path count:&count];
-    if (waveFile == nil) {
-        reject(@"whisper_error", @"Invalid file", nil);
-        return;
-    }
-    [context transcribeFile:jobId
-        audioData:waveFile
+    [context transcribeData:jobId
+        audioData:data
         audioDataCount:count
         options:options
         onProgress: ^(int progress) {
@@ -171,11 +148,9 @@
         }
         onEnd: ^(int code) {
             if (code != 0 && code != 999) {
-                free(waveFile);
                 reject(@"whisper_cpp_error", [NSString stringWithFormat:@"Failed to transcribe the file. Code: %d", code], nil);
                 return;
             }
-            free(waveFile);
             NSMutableDictionary *result = [context getTextSegments];
             result[@"isAborted"] = @([context isStoppedByAction]);
             resolve(result);
@@ -183,6 +158,99 @@
     ];
 }

+RCT_REMAP_METHOD(transcribeFile,
+    withContextId:(int)contextId
+    withJobId:(int)jobId
+    withWaveFile:(NSString *)waveFilePathOrDataBase64
+    withOptions:(NSDictionary *)options
+    withResolver:(RCTPromiseResolveBlock)resolve
+    withRejecter:(RCTPromiseRejectBlock)reject)
+{
+    RNWhisperContext *context = contexts[[NSNumber numberWithInt:contextId]];
+
+    if (context == nil) {
+        reject(@"whisper_error", @"Context not found", nil);
+        return;
+    }
+    if ([context isCapturing]) {
+        reject(@"whisper_error", @"The context is in realtime transcribe mode", nil);
+        return;
+    }
+    if ([context isTranscribing]) {
+        reject(@"whisper_error", @"Context is already transcribing", nil);
+        return;
+    }
+
+    float *data = nil;
+    int count = 0;
+    if ([waveFilePathOrDataBase64 hasPrefix:@"http://"] || [waveFilePathOrDataBase64 hasPrefix:@"https://"]) {
+        NSString *path = [RNWhisperDownloader downloadFile:waveFilePathOrDataBase64 toFile:nil];
+        data = [RNWhisperAudioUtils decodeWaveFile:path count:&count];
+    } else if ([waveFilePathOrDataBase64 hasPrefix:@"data:audio/wav;base64,"]) {
+        NSData *waveData = [[NSData alloc] initWithBase64EncodedString:[waveFilePathOrDataBase64 substringFromIndex:22] options:0];
+        data = [RNWhisperAudioUtils decodeWaveData:waveData count:&count cutHeader:YES];
+    } else {
+        data = [RNWhisperAudioUtils decodeWaveFile:waveFilePathOrDataBase64 count:&count];
+    }
+    if (data == nil) {
+        reject(@"whisper_error", @"Invalid file", nil);
+        return;
+    }
+
+    [self transcribeData:context
+        withContextId:contextId
+        withJobId:jobId
+        withData:data
+        withDataCount:count
+        withOptions:options
+        withResolver:resolve
+        withRejecter:reject
+    ];
+}
+
+RCT_REMAP_METHOD(transcribeData,
+    withContextId:(int)contextId
+    withJobId:(int)jobId
+    withData:(NSString *)dataBase64 // pcm data base64 encoded
+    withOptions:(NSDictionary *)options
+    withResolver:(RCTPromiseResolveBlock)resolve
+    withRejecter:(RCTPromiseRejectBlock)reject)
+{
+    RNWhisperContext *context = contexts[[NSNumber numberWithInt:contextId]];
+
+    if (context == nil) {
+        reject(@"whisper_error", @"Context not found", nil);
+        return;
+    }
+    if ([context isCapturing]) {
+        reject(@"whisper_error", @"The context is in realtime transcribe mode", nil);
+        return;
+    }
+    if ([context isTranscribing]) {
+        reject(@"whisper_error", @"Context is already transcribing", nil);
+        return;
+    }
+
+    NSData *pcmData = [[NSData alloc] initWithBase64EncodedString:dataBase64 options:0];
+    int count = 0;
+    float *data = [RNWhisperAudioUtils decodeWaveData:pcmData count:&count cutHeader:NO];
+
+    if (data == nil) {
+        reject(@"whisper_error", @"Invalid data", nil);
+        return;
+    }
+
+    [self transcribeData:context
+        withContextId:contextId
+        withJobId:jobId
+        withData:data
+        withDataCount:count
+        withOptions:options
+        withResolver:resolve
+        withRejecter:reject
+    ];
+}
+
 RCT_REMAP_METHOD(startRealtimeTranscribe,
     withContextId:(int)contextId
     withJobId:(int)jobId
@@ -244,6 +312,25 @@ RCT_REMAP_METHOD(abortTranscribe,
     resolve(nil);
 }

+RCT_REMAP_METHOD(bench,
+    withContextId:(int)contextId
+    withMaxThreads:(int)maxThreads
+    withResolver:(RCTPromiseResolveBlock)resolve
+    withRejecter:(RCTPromiseRejectBlock)reject)
+{
+    RNWhisperContext *context = contexts[[NSNumber numberWithInt:contextId]];
+    if (context == nil) {
+        reject(@"whisper_error", @"Context not found", nil);
+        return;
+    }
+    if ([context isTranscribing]) {
+        reject(@"whisper_error", @"The context is transcribing", nil);
+        return;
+    }
+    NSString *result = [context bench:maxThreads];
+    resolve(result);
+}
+
 RCT_REMAP_METHOD(releaseContext,
     withContextId:(int)contextId
     withResolver:(RCTPromiseResolveBlock)resolve
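The refactor above separates decoding from transcription: transcribeFile now also accepts a data:audio/wav;base64, URI, and the new transcribeData method takes base64-encoded raw 16-bit PCM (decoded with cutHeader:NO, i.e. no WAV header expected). A hedged TypeScript sketch of calling it, assuming the transcribeData wrapper added in package/src/index.ts returns the same { stop, promise } shape as transcribe:

// `pcmBase64` must be base64-encoded 16-bit little-endian mono PCM at 16 kHz,
// matching what decodeWaveData(..., cutHeader:NO) expects on the native side.
async function transcribePcm(
  ctx: { transcribeData: (data: string, options: object) => { promise: Promise<{ result: string }> } },
  pcmBase64: string,
): Promise<string> {
  const { promise } = ctx.transcribeData(pcmBase64, { language: 'en' })
  const { result } = await promise
  return result
}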
package/ios/RNWhisperAudioUtils.h CHANGED
@@ -2,6 +2,7 @@

 @interface RNWhisperAudioUtils : NSObject

++ (float *)decodeWaveData:(NSData*)data count:(int *)count cutHeader:(BOOL)cutHeader;
 + (float *)decodeWaveFile:(NSString*)filePath count:(int *)count;

 @end
package/ios/RNWhisperAudioUtils.m CHANGED
@@ -3,25 +3,32 @@

 @implementation RNWhisperAudioUtils

++ (float *)decodeWaveData:(NSData*)data count:(int *)count cutHeader:(BOOL)cutHeader {
+    NSData *waveData = data;
+    if (cutHeader) {
+        // just cut 44 bytes from the beginning
+        waveData = [data subdataWithRange:NSMakeRange(44, [data length]-44)];
+    }
+    const short *shortArray = (const short *)[waveData bytes];
+    int shortCount = (int) ([waveData length] / sizeof(short));
+    float *floatArray = (float *) malloc(shortCount * sizeof(float));
+    for (NSInteger i = 0; i < shortCount; i++) {
+        float floatValue = ((float)shortArray[i]) / 32767.0;
+        floatValue = MAX(floatValue, -1.0);
+        floatValue = MIN(floatValue, 1.0);
+        floatArray[i] = floatValue;
+    }
+    *count = shortCount;
+    return floatArray;
+}
+
 + (float *)decodeWaveFile:(NSString*)filePath count:(int *)count {
     NSURL *url = [NSURL fileURLWithPath:filePath];
     NSData *fileData = [NSData dataWithContentsOfURL:url];
     if (fileData == nil) {
         return nil;
     }
-    NSMutableData *waveData = [[NSMutableData alloc] init];
-    [waveData appendData:[fileData subdataWithRange:NSMakeRange(44, [fileData length]-44)]];
-    const short *shortArray = (const short *)[waveData bytes];
-    int shortCount = (int) ([waveData length] / sizeof(short));
-    float *floatArray = (float *) malloc(shortCount * sizeof(float));
-    for (NSInteger i = 0; i < shortCount; i++) {
-        float floatValue = ((float)shortArray[i]) / 32767.0;
-        floatValue = MAX(floatValue, -1.0);
-        floatValue = MIN(floatValue, 1.0);
-        floatArray[i] = floatValue;
-    }
-    *count = shortCount;
-    return floatArray;
+    return [RNWhisperAudioUtils decodeWaveData:fileData count:count cutHeader:YES];
 }

 @end
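decodeWaveData is now the single decode path: it optionally strips the 44-byte canonical WAV header, then converts 16-bit samples to floats in [-1, 1] by dividing by 32767. Building a compatible payload in JS is just the inverse conversion; a sketch assuming the 'buffer' polyfill package is installed in the React Native app:

import { Buffer } from 'buffer'

// Convert float samples in [-1, 1] into the base64 16-bit PCM payload that
// decodeWaveData(..., cutHeader:NO) expects (mono, 16 kHz, little-endian).
function floatToPcmBase64(samples: Float32Array): string {
  const pcm = new Int16Array(samples.length)
  for (let i = 0; i < samples.length; i++) {
    const s = Math.max(-1, Math.min(1, samples[i]))
    pcm[i] = Math.round(s * 32767) // inverse of the / 32767.0 in decodeWaveData
  }
  return Buffer.from(pcm.buffer).toString('base64')
}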
package/ios/RNWhisperContext.h CHANGED
@@ -42,7 +42,7 @@ typedef struct {
     bool isMetalEnabled;
 }

-+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId noCoreML:(BOOL)noCoreML noMetal:(BOOL)noMetal;
++ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId noCoreML:(BOOL)noCoreML noMetal:(BOOL)noMetal useFlashAttn:(BOOL)useFlashAttn;
 - (bool)isMetalEnabled;
 - (NSString *)reasonNoMetal;
 - (struct whisper_context *)getContext;
@@ -50,7 +50,7 @@ typedef struct {
 - (OSStatus)transcribeRealtime:(int)jobId
     options:(NSDictionary *)options
     onTranscribe:(void (^)(int, NSString *, NSDictionary *))onTranscribe;
-- (void)transcribeFile:(int)jobId
+- (void)transcribeData:(int)jobId
     audioData:(float *)audioData
     audioDataCount:(int)audioDataCount
     options:(NSDictionary *)options
@@ -63,6 +63,7 @@ typedef struct {
 - (bool)isTranscribing;
 - (bool)isStoppedByAction;
 - (NSMutableDictionary *)getTextSegments;
+- (NSString *)bench:(int)maxThreads;
 - (void)invalidate;

 @end
package/ios/RNWhisperContext.mm CHANGED
@@ -10,12 +10,17 @@
     contextId:(int)contextId
     noCoreML:(BOOL)noCoreML
     noMetal:(BOOL)noMetal
+    useFlashAttn:(BOOL)useFlashAttn
 {
     RNWhisperContext *context = [[RNWhisperContext alloc] init];
     context->contextId = contextId;
     struct whisper_context_params cparams;
     NSString *reasonNoMetal = @"";
     cparams.use_gpu = !noMetal;
+    cparams.flash_attn = useFlashAttn;
+
+    // TODO: Figure out why it leads to re-init crash
+    cparams.dtw_token_timestamps = false;

     cparams.use_coreml = !noCoreML;
 #ifndef WHISPER_USE_COREML
@@ -116,6 +121,7 @@
     self->recordState.transcribeSliceIndex = 0;
     self->recordState.nSamplesTranscribing = 0;

+    self->recordState.sliceNSamples.clear();
     self->recordState.sliceNSamples.push_back(0);

     self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]);
@@ -202,7 +208,7 @@ void AudioInputCallback(void * inUserData,
         state->sliceNSamples.push_back(0);
     }

-    NSLog(@"[RNWhisper] Slice %d has %d samples", state->sliceIndex, nSamples);
+    NSLog(@"[RNWhisper] Slice %d has %d samples, put %d samples", state->sliceIndex, nSamples, n);

     state->job->put_pcm_data((short*) inBuffer->mAudioData, state->sliceIndex, nSamples, n);

@@ -352,9 +358,10 @@
 struct rnwhisper_segments_callback_data {
     void (^onNewSegments)(NSDictionary *);
     int total_n_new;
+    bool tdrzEnable;
 };

-- (void)transcribeFile:(int)jobId
+- (void)transcribeData:(int)jobId
     audioData:(float *)audioData
     audioDataCount:(int)audioDataCount
     options:(NSDictionary *)options
@@ -385,12 +392,18 @@ struct rnwhisper_segments_callback_data {
             NSMutableArray *segments = [[NSMutableArray alloc] init];
             for (int i = data->total_n_new - n_new; i < data->total_n_new; i++) {
                 const char * text_cur = whisper_full_get_segment_text(ctx, i);
-                text = [text stringByAppendingString:[NSString stringWithUTF8String:text_cur]];
+                NSMutableString *mutable_ns_text = [NSMutableString stringWithUTF8String:text_cur];
+
+                if (data->tdrzEnable && whisper_full_get_segment_speaker_turn_next(ctx, i)) {
+                    [mutable_ns_text appendString:@" [SPEAKER_TURN]"];
+                }
+
+                text = [text stringByAppendingString:mutable_ns_text];

                 const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
                 const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
                 NSDictionary *segment = @{
-                    @"text": [NSString stringWithUTF8String:text_cur],
+                    @"text": [NSString stringWithString:mutable_ns_text],
                     @"t0": [NSNumber numberWithLongLong:t0],
                     @"t1": [NSNumber numberWithLongLong:t1]
                 };
@@ -408,7 +421,8 @@ struct rnwhisper_segments_callback_data {
         };
         struct rnwhisper_segments_callback_data user_data = {
             .onNewSegments = onNewSegments,
-            .total_n_new = 0
+            .tdrzEnable = options[@"tdrzEnable"] && [options[@"tdrzEnable"] boolValue],
+            .total_n_new = 0,
         };
         params.new_segment_callback_user_data = &user_data;
     }
@@ -468,7 +482,6 @@ struct rnwhisper_segments_callback_data {
     params.print_progress = false;
     params.print_timestamps = false;
     params.print_special = false;
-    params.speed_up = options[@"speedUp"] != nil ? [options[@"speedUp"] boolValue] : false;
     params.translate = options[@"translate"] != nil ? [options[@"translate"] boolValue] : false;
     params.language = options[@"language"] != nil ? strdup([options[@"language"] UTF8String]) : "auto";
     params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
@@ -480,6 +493,7 @@ struct rnwhisper_segments_callback_data {
         params.max_len = [options[@"maxLen"] intValue];
     }
     params.token_timestamps = options[@"tokenTimestamps"] != nil ? [options[@"tokenTimestamps"] boolValue] : false;
+    params.tdrz_enable = options[@"tdrzEnable"] != nil ? [options[@"tdrzEnable"] boolValue] : false;

     if (options[@"bestOf"] != nil) {
         params.greedy.best_of = [options[@"bestOf"] intValue];
@@ -529,12 +543,21 @@ struct rnwhisper_segments_callback_data {
     NSMutableArray *segments = [[NSMutableArray alloc] init];
     for (int i = 0; i < n_segments; i++) {
         const char * text_cur = whisper_full_get_segment_text(self->ctx, i);
-        text = [text stringByAppendingString:[NSString stringWithUTF8String:text_cur]];
+        NSMutableString *mutable_ns_text = [NSMutableString stringWithUTF8String:text_cur];
+
+        // Simplified condition
+        if (self->recordState.options[@"tdrzEnable"] &&
+            [self->recordState.options[@"tdrzEnable"] boolValue] &&
+            whisper_full_get_segment_speaker_turn_next(self->ctx, i)) {
+            [mutable_ns_text appendString:@" [SPEAKER_TURN]"];
+        }
+
+        text = [text stringByAppendingString:mutable_ns_text];

         const int64_t t0 = whisper_full_get_segment_t0(self->ctx, i);
         const int64_t t1 = whisper_full_get_segment_t1(self->ctx, i);
         NSDictionary *segment = @{
-            @"text": [NSString stringWithUTF8String:text_cur],
+            @"text": [NSString stringWithString:mutable_ns_text],
             @"t0": [NSNumber numberWithLongLong:t0],
             @"t1": [NSNumber numberWithLongLong:t1]
         };
@@ -546,6 +569,16 @@ struct rnwhisper_segments_callback_data {
     return result;
 }

+- (NSString *)bench:(int)maxThreads {
+    const int n_threads = maxThreads > 0 ? maxThreads : 0;
+
+    const int max_threads = (int) [[NSProcessInfo processInfo] processorCount];
+    // Use 2 threads by default on 4-core devices, 4 threads on more cores
+    const int default_n_threads = max_threads == 4 ? 2 : MIN(4, max_threads);
+    NSString *result = [NSString stringWithUTF8String:rnwhisper::bench(self->ctx, n_threads).c_str()];
+    return result;
+}
+
 - (void)invalidate {
     [self stopCurrentTranscribe];
     whisper_free(self->ctx);
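With tdrz_enable set, both the file/data path and the realtime path now append " [SPEAKER_TURN]" to a segment's text whenever whisper.cpp reports a speaker change. A hedged sketch of consuming that marker from JS, using the tdrzEnable option key the Objective-C code reads above (requires a tinydiarize-capable model):

type Segment = { text: string; t0: number; t1: number }

async function logSpeakerTurns(
  ctx: { transcribe: (path: string, options: object) => { promise: Promise<{ segments: Segment[] }> } },
  filePath: string,
) {
  const { promise } = ctx.transcribe(filePath, { language: 'en', tdrzEnable: true })
  const { segments } = await promise
  for (const s of segments) {
    // The native side appends this marker when a speaker turn is detected.
    if (s.text.endsWith(' [SPEAKER_TURN]')) console.log('speaker turn near', s.t1)
  }
}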
package/jest/mock.js CHANGED
@@ -45,11 +45,19 @@ if (!NativeModules.RNWhisper) {
       })
     })
   }),
+  bench: jest.fn(() => Promise.resolve({
+    config: 'NEON',
+    nThreads: 1,
+    encodeMs: 1,
+    decodeMs: 1,
+    batchMs: 1,
+    promptMs: 1,
+  })),
   releaseContext: jest.fn(() => Promise.resolve()),
   releaseAllContexts: jest.fn(() => Promise.resolve()),

   // iOS AudioSession utils
-  getAudioSessionCurrentCategory: jest.fn(() => Promise.resolve({
+  getAudioSessionCurrentCategory: jest.fn(() => Promise.resolve({
     category: 'AVAudioSessionCategoryPlayAndRecord',
     options: [],
   })),
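The bench mock mirrors the bench method added on both platforms, so tests can exercise it without a device. A sketch of a Jest test using the package's documented mock, assuming the context wrapper in package/src/index.ts exposes bench and resolves to the mocked shape:

jest.mock('whisper.rn', () => require('whisper.rn/jest/mock'))

test('bench resolves with mocked timings', async () => {
  const { initWhisper } = require('whisper.rn')
  const ctx = await initWhisper({ filePath: 'ggml-tiny.en.bin' })
  const result = await ctx.bench(4) // maxThreads
  expect(result.config).toBe('NEON')
})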
package/lib/commonjs/NativeRNWhisper.js.map CHANGED
@@ -1 +1 @@
-{"version":3,"names":["_reactNative","require","_default","TurboModuleRegistry","get","exports","default"],"sourceRoot":"../../src","sources":["NativeRNWhisper.ts"],"mappings":";;;;;;AACA,IAAAA,YAAA,GAAAC,OAAA;AAAkD,IAAAC,QAAA,GAiGnCC,gCAAmB,CAACC,GAAG,CAAO,WAAW,CAAC;AAAAC,OAAA,CAAAC,OAAA,GAAAJ,QAAA"}
+{"version":3,"names":["_reactNative","require","_default","TurboModuleRegistry","get","exports","default"],"sourceRoot":"../../src","sources":["NativeRNWhisper.ts"],"mappings":";;;;;;AACA,IAAAA,YAAA,GAAAC,OAAA;AAAkD,IAAAC,QAAA,GA0GnCC,gCAAmB,CAACC,GAAG,CAAO,WAAW,CAAC;AAAAC,OAAA,CAAAC,OAAA,GAAAJ,QAAA"}