whisper.rn 0.4.3 → 0.5.0-rc.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,681 @@
1
#include "RNWhisperJSI.h"

#include <algorithm>
#include <atomic>
#include <cstdarg>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include <jsi/jsi.h>

#include "ThreadPool.h"
11
+
12
+ #if defined(__ANDROID__)
13
+ #include <android/log.h>
14
+ #endif
15
+
16
+ using namespace facebook::jsi;
17
+
18
+ namespace rnwhisper_jsi {
19
+
20
// Severity of a message passed to log(); mapped onto the platform logger's levels.
enum class LogLevel { LOG_DEBUG, LOG_INFO, LOG_ERROR };

// printf-style logger used throughout this file.
// On Android the message goes to logcat with the tag "RNWhisperJSI".
// Elsewhere it is formatted into a 1024-byte buffer and printed to stdout as
// "RNWhisperJSI <LEVEL>: <message>". Longer messages are truncated by vsnprintf.
static void log(LogLevel level, const char* format, ...) {
    va_list args;
    va_start(args, format);

#if defined(__ANDROID__)
    int androidLevel = ANDROID_LOG_ERROR;  // anything not DEBUG/INFO is reported as an error
    switch (level) {
        case LogLevel::LOG_DEBUG:
            androidLevel = ANDROID_LOG_DEBUG;
            break;
        case LogLevel::LOG_INFO:
            androidLevel = ANDROID_LOG_INFO;
            break;
        default:
            break;
    }
    __android_log_vprint(androidLevel, "RNWhisperJSI", format, args);
#else
    const char* levelName = "ERROR";  // anything not DEBUG/INFO is reported as an error
    switch (level) {
        case LogLevel::LOG_DEBUG:
            levelName = "DEBUG";
            break;
        case LogLevel::LOG_INFO:
            levelName = "INFO";
            break;
        default:
            break;
    }
    char message[1024];
    vsnprintf(message, sizeof(message), format, args);
    printf("RNWhisperJSI %s: %s\n", levelName, message);
#endif

    va_end(args);
}

// Convenience wrappers that fix the severity argument of log().
#define logInfo(format, ...) log(LogLevel::LOG_INFO, format, ##__VA_ARGS__)
#define logError(format, ...) log(LogLevel::LOG_ERROR, format, ##__VA_ARGS__)
#define logDebug(format, ...) log(LogLevel::LOG_DEBUG, format, ##__VA_ARGS__)
45
+
46
+ static std::unique_ptr<ThreadPool> whisperThreadPool = nullptr;
47
+ static std::mutex threadPoolMutex;
48
+
49
+ // Initialize ThreadPool with optimal thread count
50
+ ThreadPool& getWhisperThreadPool() {
51
+ std::lock_guard<std::mutex> lock(threadPoolMutex);
52
+ if (!whisperThreadPool) {
53
+ int max_threads = std::thread::hardware_concurrency();
54
+ int thread_count = std::max(2, std::min(4, max_threads)); // Use 2-4 threads
55
+ whisperThreadPool = std::make_unique<ThreadPool>(thread_count);
56
+ logInfo("Initialized ThreadPool with %d threads", thread_count);
57
+ }
58
+ return *whisperThreadPool;
59
+ }
60
+
61
+ // Template-based context management
62
+ template<typename T>
63
+ class ContextManager {
64
+ private:
65
+ std::unordered_map<int, long> contextMap;
66
+ std::mutex contextMutex;
67
+ const char* contextType;
68
+
69
+ public:
70
+ ContextManager(const char* type) : contextType(type) {}
71
+
72
+ void add(int contextId, long contextPtr) {
73
+ std::lock_guard<std::mutex> lock(contextMutex);
74
+ contextMap[contextId] = contextPtr;
75
+ logDebug("%s Context added: id=%d, ptr=%ld", contextType, contextId, contextPtr);
76
+ }
77
+
78
+ void remove(int contextId) {
79
+ std::lock_guard<std::mutex> lock(contextMutex);
80
+ auto it = contextMap.find(contextId);
81
+ if (it != contextMap.end()) {
82
+ logDebug("%s Context removed: id=%d", contextType, contextId);
83
+ contextMap.erase(it);
84
+ }
85
+ }
86
+
87
+ long get(int contextId) {
88
+ std::lock_guard<std::mutex> lock(contextMutex);
89
+ auto it = contextMap.find(contextId);
90
+ return (it != contextMap.end()) ? it->second : 0;
91
+ }
92
+
93
+ T* getTyped(int contextId) {
94
+ long ptr = get(contextId);
95
+ return ptr ? reinterpret_cast<T*>(ptr) : nullptr;
96
+ }
97
+ };
98
+
99
+ static ContextManager<whisper_context> contextManager("Whisper");
100
+ static ContextManager<whisper_vad_context> vadContextManager("VAD");
101
+
102
+ // Context management functions
103
+ void addContext(int contextId, long contextPtr) {
104
+ contextManager.add(contextId, contextPtr);
105
+ }
106
+
107
+ void removeContext(int contextId) {
108
+ contextManager.remove(contextId);
109
+ }
110
+
111
+ void addVadContext(int contextId, long vadContextPtr) {
112
+ vadContextManager.add(contextId, vadContextPtr);
113
+ }
114
+
115
+ void removeVadContext(int contextId) {
116
+ vadContextManager.remove(contextId);
117
+ }
118
+
119
+ long getContextPtr(int contextId) {
120
+ return contextManager.get(contextId);
121
+ }
122
+
123
+ long getVadContextPtr(int contextId) {
124
+ return vadContextManager.get(contextId);
125
+ }
126
+
127
+ // Helper function to validate JSI function arguments
128
+ struct JSIValidationResult {
129
+ bool isValid;
130
+ std::string errorMessage;
131
+ };
132
+
133
+ JSIValidationResult validateJSIArguments(Runtime& runtime, const Value* arguments, size_t count, size_t expectedCount) {
134
+ if (count != expectedCount) {
135
+ return {false, "Expected " + std::to_string(expectedCount) + " arguments, got " + std::to_string(count)};
136
+ }
137
+
138
+ if (!arguments[0].isNumber()) {
139
+ return {false, "First argument (contextId) must be a number"};
140
+ }
141
+
142
+ if (!arguments[1].isObject()) {
143
+ return {false, "Second argument (options) must be an object"};
144
+ }
145
+
146
+ if (!arguments[2].isObject() || !arguments[2].getObject(runtime).isArrayBuffer(runtime)) {
147
+ return {false, "Third argument must be an ArrayBuffer"};
148
+ }
149
+
150
+ return {true, ""};
151
+ }
152
+
153
+ // Helper function to create error objects
154
+ Object createErrorObject(Runtime& runtime, const std::string& message, int code = -1) {
155
+ auto errorObj = Object(runtime);
156
+ errorObj.setProperty(runtime, "message", String::createFromUtf8(runtime, message));
157
+ if (code != -1) {
158
+ errorObj.setProperty(runtime, "code", Value(code));
159
+ }
160
+ return errorObj;
161
+ }
162
+
163
+ // Helper function to convert JSI object to whisper_full_params
164
+ whisper_full_params createFullParamsFromJSI(Runtime& runtime, const Object& optionsObj) {
165
+ whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
166
+
167
+ params.print_realtime = false;
168
+ params.print_progress = false;
169
+ params.print_timestamps = false;
170
+ params.print_special = false;
171
+
172
+ int max_threads = std::thread::hardware_concurrency();
173
+ int default_n_threads = max_threads == 4 ? 2 : std::min(4, max_threads);
174
+
175
+ try {
176
+ auto propNames = optionsObj.getPropertyNames(runtime);
177
+ for (size_t i = 0; i < propNames.size(runtime); i++) {
178
+ auto propNameValue = propNames.getValueAtIndex(runtime, i);
179
+ std::string propName = propNameValue.getString(runtime).utf8(runtime);
180
+ Value propValue = optionsObj.getProperty(runtime, propNameValue.getString(runtime));
181
+
182
+ if (propName == "maxThreads" && propValue.isNumber()) {
183
+ int n_threads = (int)propValue.getNumber();
184
+ params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
185
+ } else if (propName == "translate" && propValue.isBool()) {
186
+ params.translate = propValue.getBool();
187
+ } else if (propName == "tokenTimestamps" && propValue.isBool()) {
188
+ params.token_timestamps = propValue.getBool();
189
+ } else if (propName == "tdrzEnable" && propValue.isBool()) {
190
+ params.tdrz_enable = propValue.getBool();
191
+ } else if (propName == "beamSize" && propValue.isNumber()) {
192
+ int beam_size = (int)propValue.getNumber();
193
+ if (beam_size > 0) {
194
+ params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
195
+ params.beam_search.beam_size = beam_size;
196
+ }
197
+ } else if (propName == "bestOf" && propValue.isNumber()) {
198
+ params.greedy.best_of = (int)propValue.getNumber();
199
+ } else if (propName == "maxLen" && propValue.isNumber()) {
200
+ params.max_len = (int)propValue.getNumber();
201
+ } else if (propName == "maxContext" && propValue.isNumber()) {
202
+ params.n_max_text_ctx = (int)propValue.getNumber();
203
+ } else if (propName == "offset" && propValue.isNumber()) {
204
+ params.offset_ms = (int)propValue.getNumber();
205
+ } else if (propName == "duration" && propValue.isNumber()) {
206
+ params.duration_ms = (int)propValue.getNumber();
207
+ } else if (propName == "wordThold" && propValue.isNumber()) {
208
+ params.thold_pt = (int)propValue.getNumber();
209
+ } else if (propName == "temperature" && propValue.isNumber()) {
210
+ params.temperature = (float)propValue.getNumber();
211
+ } else if (propName == "temperatureInc" && propValue.isNumber()) {
212
+ params.temperature_inc = (float)propValue.getNumber();
213
+ } else if (propName == "prompt" && propValue.isString()) {
214
+ std::string prompt = propValue.getString(runtime).utf8(runtime);
215
+ params.initial_prompt = strdup(prompt.c_str());
216
+ } else if (propName == "language" && propValue.isString()) {
217
+ std::string language = propValue.getString(runtime).utf8(runtime);
218
+ params.language = strdup(language.c_str());
219
+ }
220
+ }
221
+ } catch (...) {
222
+ // Use default values if parsing fails
223
+ }
224
+
225
+ params.offset_ms = 0;
226
+ params.no_context = true;
227
+ params.single_segment = false;
228
+
229
+ return params;
230
+ }
231
+
232
+ // Helper function to convert ArrayBuffer to float32 audio data
233
+ struct AudioData {
234
+ std::vector<float> data;
235
+ int count;
236
+ };
237
+
238
+ AudioData convertArrayBufferToAudioData(Runtime& runtime, size_t arrayBufferSize, uint8_t* arrayBufferData) {
239
+ // Convert ArrayBuffer to float32 array (assuming 16-bit PCM input)
240
+ if (arrayBufferSize % 2 != 0) {
241
+ throw JSError(runtime, "ArrayBuffer size must be even for 16-bit PCM data");
242
+ }
243
+
244
+ int audioDataCount = (int)(arrayBufferSize / 2); // 16-bit samples
245
+ std::vector<float> audioData(audioDataCount);
246
+
247
+ // Convert 16-bit PCM to float32
248
+ int16_t* pcmData = (int16_t*)arrayBufferData;
249
+ for (int i = 0; i < audioDataCount; i++) {
250
+ audioData[i] = (float)pcmData[i] / 32768.0f;
251
+ }
252
+
253
+ return {std::move(audioData), audioDataCount};
254
+ }
255
+
256
+ // Common callback data structure
257
+ template<typename CallbackType>
258
+ struct CallbackData {
259
+ std::shared_ptr<facebook::react::CallInvoker> callInvoker;
260
+ std::shared_ptr<Function> callback;
261
+ std::shared_ptr<Runtime> safeRuntime;
262
+ std::atomic<int> counter{0};
263
+ };
264
+
265
+ // Helper function to extract callbacks from options
266
+ struct CallbackInfo {
267
+ std::shared_ptr<Function> onProgressCallback;
268
+ std::shared_ptr<Function> onNewSegmentsCallback;
269
+ int jobId;
270
+ };
271
+
272
+ CallbackInfo extractCallbacks(Runtime& runtime, const Object& optionsObj) {
273
+ CallbackInfo info;
274
+ info.jobId = rand(); // Default fallback jobId
275
+
276
+ try {
277
+ auto propNames = optionsObj.getPropertyNames(runtime);
278
+ for (size_t i = 0; i < propNames.size(runtime); i++) {
279
+ auto propNameValue = propNames.getValueAtIndex(runtime, i);
280
+ std::string propName = propNameValue.getString(runtime).utf8(runtime);
281
+ Value propValue = optionsObj.getProperty(runtime, propNameValue.getString(runtime));
282
+
283
+ if (propName == "onProgress" && propValue.isObject() && propValue.getObject(runtime).isFunction(runtime)) {
284
+ info.onProgressCallback = std::make_shared<Function>(propValue.getObject(runtime).getFunction(runtime));
285
+ } else if (propName == "onNewSegments" && propValue.isObject() && propValue.getObject(runtime).isFunction(runtime)) {
286
+ info.onNewSegmentsCallback = std::make_shared<Function>(propValue.getObject(runtime).getFunction(runtime));
287
+ } else if (propName == "jobId" && propValue.isNumber()) {
288
+ info.jobId = (int)propValue.getNumber();
289
+ }
290
+ }
291
+ } catch (...) {
292
+ // Ignore callback detection errors
293
+ }
294
+
295
+ return info;
296
+ }
297
+
298
+ // Helper function to create segments array
299
+ Array createSegmentsArray(Runtime& runtime, struct whisper_context* ctx, int offset) {
300
+ int n_segments = whisper_full_n_segments(ctx);
301
+ auto segmentsArray = Array(runtime, n_segments);
302
+
303
+ for (int i = offset; i < n_segments; i++) {
304
+ const char* text = whisper_full_get_segment_text(ctx, i);
305
+ auto segmentObj = Object(runtime);
306
+ segmentObj.setProperty(runtime, "text", String::createFromUtf8(runtime, text));
307
+ segmentObj.setProperty(runtime, "t0", Value((double)whisper_full_get_segment_t0(ctx, i)));
308
+ segmentObj.setProperty(runtime, "t1", Value((double)whisper_full_get_segment_t1(ctx, i)));
309
+ segmentsArray.setValueAtIndex(runtime, i, segmentObj);
310
+ }
311
+
312
+ return segmentsArray;
313
+ }
314
+
315
+ // Helper function to create full text from segments
316
+ std::string createFullTextFromSegments(struct whisper_context* ctx, int offset) {
317
+ int n_segments = whisper_full_n_segments(ctx);
318
+ std::string fullText = "";
319
+
320
+ for (int i = offset; i < n_segments; i++) {
321
+ const char* text = whisper_full_get_segment_text(ctx, i);
322
+ fullText += text;
323
+ }
324
+
325
+ return fullText;
326
+ }
327
+
328
+ // Helper function to create and execute promise-based operations
329
+ template<typename ContextType, typename TaskFunc>
330
+ Value createPromiseTask(
331
+ Runtime& runtime,
332
+ const std::string& functionName,
333
+ std::shared_ptr<facebook::react::CallInvoker> callInvoker,
334
+ const Value* arguments,
335
+ size_t count,
336
+ TaskFunc task
337
+ ) {
338
+ // Validate arguments
339
+ auto validation = validateJSIArguments(runtime, arguments, count, 3);
340
+ if (!validation.isValid) {
341
+ throw JSError(runtime, functionName + " " + validation.errorMessage);
342
+ }
343
+
344
+ int contextId = (int)arguments[0].getNumber();
345
+ auto optionsObj = arguments[1].getObject(runtime);
346
+ auto arrayBuffer = arguments[2].getObject(runtime).getArrayBuffer(runtime);
347
+
348
+ size_t arrayBufferSize = arrayBuffer.size(runtime);
349
+ uint8_t* arrayBufferData = arrayBuffer.data(runtime);
350
+
351
+ logInfo("%s called with contextId=%d, arrayBuffer size=%zu", functionName.c_str(), contextId, arrayBufferSize);
352
+
353
+ // Convert ArrayBuffer to audio data
354
+ AudioData audioResult = convertArrayBufferToAudioData(runtime, arrayBufferSize, arrayBufferData);
355
+
356
+ whisper_full_params params = {};
357
+ CallbackInfo callbackInfo = {};
358
+ if (functionName == "whisperTranscribeData") {
359
+ params = createFullParamsFromJSI(runtime, optionsObj);
360
+ // Extract data from optionsObj before lambda capture
361
+ callbackInfo = extractCallbacks(runtime, optionsObj);
362
+ }
363
+
364
+ // Create promise
365
+ auto promiseConstructor = runtime.global().getPropertyAsFunction(runtime, "Promise");
366
+
367
+ auto promiseExecutor = Function::createFromHostFunction(
368
+ runtime,
369
+ PropNameID::forAscii(runtime, ""),
370
+ 2, // resolve, reject
371
+ [contextId, audioResult, params, callbackInfo, task, callInvoker, functionName](Runtime& runtime, const Value& thisValue, const Value* arguments, size_t count) -> Value {
372
+ if (count != 2) {
373
+ throw JSError(runtime, "Promise executor expects 2 arguments (resolve, reject)");
374
+ }
375
+
376
+ auto resolvePtr = std::make_shared<Function>(arguments[0].getObject(runtime).getFunction(runtime));
377
+ auto rejectPtr = std::make_shared<Function>(arguments[1].getObject(runtime).getFunction(runtime));
378
+ auto safeRuntime = std::shared_ptr<Runtime>(&runtime, [](Runtime*){});
379
+
380
+ // Execute task in ThreadPool
381
+ auto future = getWhisperThreadPool().enqueue([
382
+ contextId, audioResult, params, callbackInfo, task, resolvePtr, rejectPtr, callInvoker, safeRuntime, functionName]() {
383
+
384
+ try {
385
+ task(contextId, audioResult, params, callbackInfo, resolvePtr, rejectPtr, callInvoker, safeRuntime);
386
+ } catch (...) {
387
+ callInvoker->invokeAsync([rejectPtr, safeRuntime, functionName]() {
388
+ auto& runtime = *safeRuntime;
389
+ auto errorObj = createErrorObject(runtime, functionName + " processing error");
390
+ rejectPtr->call(runtime, errorObj);
391
+ });
392
+ }
393
+ });
394
+
395
+ return Value::undefined();
396
+ }
397
+ );
398
+
399
+ return promiseConstructor.callAsConstructor(runtime, promiseExecutor);
400
+ }
401
+
402
+ void installJSIBindings(
403
+ facebook::jsi::Runtime& runtime,
404
+ std::shared_ptr<facebook::react::CallInvoker> callInvoker
405
+ ) {
406
+ try {
407
+ // whisperTranscribeData function
408
+ auto whisperTranscribeData = Function::createFromHostFunction(
409
+ runtime,
410
+ PropNameID::forAscii(runtime, "whisperTranscribeData"),
411
+ 3, // number of arguments
412
+ [callInvoker](Runtime& runtime, const Value& thisValue, const Value* arguments, size_t count) -> Value {
413
+ try {
414
+ return createPromiseTask<whisper_context>(
415
+ runtime, "whisperTranscribeData", callInvoker, arguments, count,
416
+ [](int contextId, const AudioData& audioResult, const whisper_full_params& params, const CallbackInfo& callbackInfo,
417
+ std::shared_ptr<Function> resolvePtr, std::shared_ptr<Function> rejectPtr,
418
+ std::shared_ptr<facebook::react::CallInvoker> callInvoker,
419
+ std::shared_ptr<Runtime> safeRuntime) {
420
+
421
+ // Get context
422
+ auto context = contextManager.getTyped(contextId);
423
+ if (!context) {
424
+ callInvoker->invokeAsync([rejectPtr, safeRuntime, contextId]() {
425
+ auto& runtime = *safeRuntime;
426
+ auto errorObj = createErrorObject(runtime, "Context not found for id: " + std::to_string(contextId));
427
+ rejectPtr->call(runtime, errorObj);
428
+ });
429
+ return;
430
+ }
431
+
432
+ // Validate audio data
433
+ if (audioResult.data.empty() || audioResult.count <= 0) {
434
+ logError("Invalid audio data: size=%zu, count=%d", audioResult.data.size(), audioResult.count);
435
+ callInvoker->invokeAsync([rejectPtr, safeRuntime]() {
436
+ auto& runtime = *safeRuntime;
437
+ auto errorObj = createErrorObject(runtime, "Invalid audio data");
438
+ rejectPtr->call(runtime, errorObj);
439
+ });
440
+ return;
441
+ }
442
+
443
+ logInfo("Starting whisper_full: context=%p, audioDataCount=%d, jobId=%d",
444
+ context, audioResult.count, callbackInfo.jobId);
445
+ whisper_reset_timings(context);
446
+
447
+ // Setup callbacks
448
+ whisper_full_params mutable_params = params;
449
+ auto progress_data = std::make_shared<CallbackData<Function>>();
450
+ progress_data->callInvoker = callInvoker;
451
+ progress_data->callback = callbackInfo.onProgressCallback;
452
+ progress_data->safeRuntime = safeRuntime;
453
+
454
+ if (callbackInfo.onProgressCallback) {
455
+ mutable_params.progress_callback = [](struct whisper_context* /*ctx*/, struct whisper_state* /*state*/, int progress, void* user_data) {
456
+ auto* data_ptr = static_cast<std::shared_ptr<CallbackData<Function>>*>(user_data);
457
+ if (data_ptr && *data_ptr) {
458
+ auto data = *data_ptr;
459
+ if (data->callInvoker && data->callback && data->safeRuntime) {
460
+ data->callInvoker->invokeAsync([progress, callback = data->callback, safeRuntime = data->safeRuntime]() {
461
+ try {
462
+ logInfo("Progress: %d%%", progress);
463
+ auto& runtime = *safeRuntime;
464
+ callback->call(runtime, Value(progress));
465
+ } catch (...) {
466
+ logError("Error in progress callback");
467
+ }
468
+ });
469
+ }
470
+ }
471
+ };
472
+ mutable_params.progress_callback_user_data = &progress_data;
473
+ }
474
+
475
+ auto segments_data = std::make_shared<CallbackData<Function>>();
476
+ segments_data->callInvoker = callInvoker;
477
+ segments_data->callback = callbackInfo.onNewSegmentsCallback;
478
+ segments_data->safeRuntime = safeRuntime;
479
+
480
+ if (callbackInfo.onNewSegmentsCallback) {
481
+ mutable_params.new_segment_callback = [](struct whisper_context* ctx, struct whisper_state* /*state*/, int n_new, void* user_data) {
482
+ auto* data_ptr = static_cast<std::shared_ptr<CallbackData<Function>>*>(user_data);
483
+ if (data_ptr && *data_ptr) {
484
+ auto data = *data_ptr;
485
+ if (data->callInvoker && data->callback && data->safeRuntime && ctx) {
486
+ int current_total = data->counter.fetch_add(n_new) + n_new;
487
+ data->callInvoker->invokeAsync([ctx, n_new, current_total, callback = data->callback, safeRuntime = data->safeRuntime]() {
488
+ try {
489
+ logInfo("New segments: %d (total: %d)", n_new, current_total);
490
+ auto& runtime = *safeRuntime;
491
+ auto resultObj = Object(runtime);
492
+ resultObj.setProperty(runtime, "nNew", Value(n_new));
493
+ resultObj.setProperty(runtime, "totalNNew", Value(current_total));
494
+ auto offset = current_total - n_new;
495
+ resultObj.setProperty(runtime, "segments", createSegmentsArray(runtime, ctx, offset));
496
+ resultObj.setProperty(runtime, "result", String::createFromUtf8(runtime, createFullTextFromSegments(ctx, offset)));
497
+ callback->call(runtime, resultObj);
498
+ } catch (...) {
499
+ logError("Error in new segments callback");
500
+ }
501
+ });
502
+ }
503
+ }
504
+ };
505
+ mutable_params.new_segment_callback_user_data = &segments_data;
506
+ }
507
+
508
+ // Execute transcription
509
+ rnwhisper::job* job = rnwhisper::job_new(callbackInfo.jobId, mutable_params);
510
+ int code = -1;
511
+
512
+ if (job == nullptr) {
513
+ logError("Failed to create job for transcription");
514
+ code = -2;
515
+ } else {
516
+ code = whisper_full(context, job->params, audioResult.data.data(), audioResult.count);
517
+ if (job->is_aborted()) {
518
+ code = -999;
519
+ }
520
+ rnwhisper::job_remove(callbackInfo.jobId);
521
+ }
522
+
523
+ // Resolve with results
524
+ callInvoker->invokeAsync([resolvePtr, rejectPtr, code, context, safeRuntime]() {
525
+ try {
526
+ auto& runtime = *safeRuntime;
527
+ if (code == 0) {
528
+ auto resultObj = Object(runtime);
529
+ resultObj.setProperty(runtime, "code", Value(code));
530
+ resultObj.setProperty(runtime, "result", String::createFromUtf8(runtime, createFullTextFromSegments(context, 0)));
531
+ resultObj.setProperty(runtime, "segments", createSegmentsArray(runtime, context, 0));
532
+ resolvePtr->call(runtime, resultObj);
533
+ } else {
534
+ std::string errorMsg = (code == -2) ? "Failed to create transcription job" :
535
+ (code == -999) ? "Transcription was aborted" :
536
+ "Transcription failed";
537
+ auto errorObj = createErrorObject(runtime, errorMsg, code);
538
+ rejectPtr->call(runtime, errorObj);
539
+ }
540
+ } catch (...) {
541
+ auto& runtime = *safeRuntime;
542
+ auto errorObj = createErrorObject(runtime, "Unknown error");
543
+ rejectPtr->call(runtime, errorObj);
544
+ }
545
+ });
546
+ }
547
+ );
548
+ } catch (const JSError& e) {
549
+ throw;
550
+ } catch (const std::exception& e) {
551
+ logError("Exception in whisperTranscribeData: %s", e.what());
552
+ throw JSError(runtime, std::string("whisperTranscribeData error: ") + e.what());
553
+ } catch (...) {
554
+ logError("Unknown exception in whisperTranscribeData");
555
+ throw JSError(runtime, "whisperTranscribeData encountered unknown error");
556
+ }
557
+ }
558
+ );
559
+
560
+ // whisperVadDetectSpeech function
561
+ auto whisperVadDetectSpeech = Function::createFromHostFunction(
562
+ runtime,
563
+ PropNameID::forAscii(runtime, "whisperVadDetectSpeech"),
564
+ 3, // number of arguments
565
+ [callInvoker](Runtime& runtime, const Value& thisValue, const Value* arguments, size_t count) -> Value {
566
+ try {
567
+ return createPromiseTask<whisper_vad_context>(
568
+ runtime, "whisperVadDetectSpeech", callInvoker, arguments, count,
569
+ [](int contextId, const AudioData& audioResult, const whisper_full_params& params, const CallbackInfo& callbackInfo,
570
+ std::shared_ptr<Function> resolvePtr, std::shared_ptr<Function> rejectPtr,
571
+ std::shared_ptr<facebook::react::CallInvoker> callInvoker,
572
+ std::shared_ptr<Runtime> safeRuntime) {
573
+
574
+ // Get VAD context
575
+ auto vadContext = vadContextManager.getTyped(contextId);
576
+ if (!vadContext) {
577
+ callInvoker->invokeAsync([rejectPtr, safeRuntime, contextId]() {
578
+ auto& runtime = *safeRuntime;
579
+ auto errorObj = createErrorObject(runtime, "VAD Context not found for id: " + std::to_string(contextId));
580
+ rejectPtr->call(runtime, errorObj);
581
+ });
582
+ return;
583
+ }
584
+
585
+ // Validate audio data
586
+ if (audioResult.data.empty() || audioResult.count <= 0) {
587
+ logError("Invalid audio data: size=%zu, count=%d", audioResult.data.size(), audioResult.count);
588
+ callInvoker->invokeAsync([rejectPtr, safeRuntime]() {
589
+ auto& runtime = *safeRuntime;
590
+ auto errorObj = createErrorObject(runtime, "Invalid audio data");
591
+ rejectPtr->call(runtime, errorObj);
592
+ });
593
+ return;
594
+ }
595
+
596
+ logInfo("Starting whisper_vad_detect_speech: vadContext=%p, audioDataCount=%d",
597
+ vadContext, audioResult.count);
598
+
599
+ // Perform VAD detection
600
+ bool isSpeech = whisper_vad_detect_speech(vadContext, audioResult.data.data(), audioResult.count);
601
+ logInfo("VAD detection result: %s", isSpeech ? "speech" : "no speech");
602
+
603
+ struct whisper_vad_params vad_params = whisper_vad_default_params();
604
+ struct whisper_vad_segments* segments = nullptr;
605
+ if (isSpeech) {
606
+ segments = whisper_vad_segments_from_probs(vadContext, vad_params);
607
+ }
608
+
609
+ // Process results on JS thread
610
+ callInvoker->invokeAsync([resolvePtr, rejectPtr, segments, safeRuntime]() {
611
+ try {
612
+ auto& runtime = *safeRuntime;
613
+ auto resultObj = Object(runtime);
614
+
615
+ if (segments) {
616
+ int n_segments = whisper_vad_segments_n_segments(segments);
617
+ resultObj.setProperty(runtime, "hasSpeech", Value(n_segments > 0));
618
+ auto segmentsArray = Array(runtime, n_segments);
619
+
620
+ for (int i = 0; i < n_segments; i++) {
621
+ auto segmentObj = Object(runtime);
622
+ segmentObj.setProperty(runtime, "t0", Value((double)whisper_vad_segments_get_segment_t0(segments, i)));
623
+ segmentObj.setProperty(runtime, "t1", Value((double)whisper_vad_segments_get_segment_t1(segments, i)));
624
+ segmentsArray.setValueAtIndex(runtime, i, segmentObj);
625
+ }
626
+
627
+ resultObj.setProperty(runtime, "segments", segmentsArray);
628
+ whisper_vad_free_segments(segments);
629
+ } else {
630
+ resultObj.setProperty(runtime, "hasSpeech", Value(false));
631
+ resultObj.setProperty(runtime, "segments", Array(runtime, 0));
632
+ }
633
+
634
+ resolvePtr->call(runtime, resultObj);
635
+ } catch (...) {
636
+ auto& runtime = *safeRuntime;
637
+ auto errorObj = createErrorObject(runtime, "VAD result processing error");
638
+ rejectPtr->call(runtime, errorObj);
639
+ }
640
+ });
641
+ }
642
+ );
643
+ } catch (const JSError& e) {
644
+ throw;
645
+ } catch (const std::exception& e) {
646
+ logError("Exception in whisperVadDetectSpeech: %s", e.what());
647
+ throw JSError(runtime, std::string("whisperVadDetectSpeech error: ") + e.what());
648
+ } catch (...) {
649
+ logError("Unknown exception in whisperVadDetectSpeech");
650
+ throw JSError(runtime, "whisperVadDetectSpeech encountered unknown error");
651
+ }
652
+ }
653
+ );
654
+
655
+ // Install the JSI functions on the global object
656
+ runtime.global().setProperty(runtime, "whisperTranscribeData", std::move(whisperTranscribeData));
657
+ runtime.global().setProperty(runtime, "whisperVadDetectSpeech", std::move(whisperVadDetectSpeech));
658
+
659
+ logInfo("JSI bindings installed successfully");
660
+ } catch (const JSError& e) {
661
+ logError("JSError installing JSI bindings: %s", e.getMessage().c_str());
662
+ throw;
663
+ } catch (const std::exception& e) {
664
+ logError("Exception installing JSI bindings: %s", e.what());
665
+ throw JSError(runtime, std::string("Failed to install JSI bindings: ") + e.what());
666
+ } catch (...) {
667
+ logError("Unknown exception installing JSI bindings");
668
+ throw JSError(runtime, "Failed to install JSI bindings: unknown error");
669
+ }
670
+ }
671
+
672
+ // Cleanup function to dispose of ThreadPool
673
+ void cleanupJSIBindings() {
674
+ std::lock_guard<std::mutex> lock(threadPoolMutex);
675
+ if (whisperThreadPool) {
676
+ logInfo("Cleaning up ThreadPool");
677
+ whisperThreadPool.reset();
678
+ }
679
+ }
680
+
681
+ } // namespace rnwhisper_jsi