whisper.rn 0.4.3 → 0.5.0-rc.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -3
- package/android/build.gradle +70 -11
- package/android/src/main/CMakeLists.txt +28 -1
- package/android/src/main/java/com/rnwhisper/JSCallInvokerResolver.java +40 -0
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +59 -11
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +21 -9
- package/android/src/main/java/com/rnwhisper/WhisperVadContext.java +1 -1
- package/android/src/main/jni.cpp +79 -2
- package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +5 -0
- package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +5 -0
- package/cpp/jsi/RNWhisperJSI.cpp +681 -0
- package/cpp/jsi/RNWhisperJSI.h +44 -0
- package/cpp/jsi/ThreadPool.h +100 -0
- package/ios/RNWhisper.h +3 -0
- package/ios/RNWhisper.mm +38 -0
- package/jest/mock.js +1 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +83 -2
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +83 -2
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +4 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +18 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +2 -3
- package/src/NativeRNWhisper.ts +2 -0
- package/src/index.ts +162 -33
- package/whisper-rn.podspec +6 -3
|
@@ -0,0 +1,681 @@
|
|
|
1
|
+
#include "RNWhisperJSI.h"
#include "ThreadPool.h"

#include <jsi/jsi.h>

#include <algorithm>
#include <atomic>
#include <cstdarg>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#if defined(__ANDROID__)
#include <android/log.h>
#endif

using namespace facebook::jsi;
|
|
17
|
+
|
|
18
|
+
namespace rnwhisper_jsi {
|
|
19
|
+
|
|
20
|
+
// Severity of a log message passed to log(); selects the Android log priority
// or the text label printed on other platforms.
enum class LogLevel : int { LOG_DEBUG = 0, LOG_INFO = 1, LOG_ERROR = 2 };
|
|
22
|
+
|
|
23
|
+
// printf-style logger tagged "RNWhisperJSI".
// On Android the message goes to logcat with the priority matching `level`;
// elsewhere it is formatted into a 1024-byte buffer (longer output is
// truncated) and written to stdout with a DEBUG/INFO/ERROR label.
static void log(LogLevel level, const char* format, ...) {
  va_list ap;
  va_start(ap, format);

#if defined(__ANDROID__)
  int priority = ANDROID_LOG_ERROR;
  if (level == LogLevel::LOG_DEBUG) {
    priority = ANDROID_LOG_DEBUG;
  } else if (level == LogLevel::LOG_INFO) {
    priority = ANDROID_LOG_INFO;
  }
  __android_log_vprint(priority, "RNWhisperJSI", format, ap);
#else
  char message[1024];
  vsnprintf(message, sizeof(message), format, ap);

  const char* label = "ERROR";
  if (level == LogLevel::LOG_DEBUG) {
    label = "DEBUG";
  } else if (level == LogLevel::LOG_INFO) {
    label = "INFO";
  }
  printf("RNWhisperJSI %s: %s\n", label, message);
#endif

  va_end(ap);
}
|
|
41
|
+
|
|
42
|
+
// Per-level shorthands for log(). `##__VA_ARGS__` (a GNU/Clang extension)
// drops the trailing comma, so these can be called with only a format string,
// e.g. logInfo("JSI bindings installed successfully").
#define logInfo(format, ...) log(LogLevel::LOG_INFO, format, ##__VA_ARGS__)
#define logError(format, ...) log(LogLevel::LOG_ERROR, format, ##__VA_ARGS__)
#define logDebug(format, ...) log(LogLevel::LOG_DEBUG, format, ##__VA_ARGS__)
|
|
45
|
+
|
|
46
|
+
// Worker pool shared by every JSI task in this file. It is created lazily by
// getWhisperThreadPool() and released by cleanupJSIBindings(). Both of those
// take threadPoolMutex before touching the pointer.
static std::unique_ptr<ThreadPool> whisperThreadPool = nullptr;
static std::mutex threadPoolMutex;
|
|
48
|
+
|
|
49
|
+
// Initialize ThreadPool with optimal thread count
|
|
50
|
+
ThreadPool& getWhisperThreadPool() {
|
|
51
|
+
std::lock_guard<std::mutex> lock(threadPoolMutex);
|
|
52
|
+
if (!whisperThreadPool) {
|
|
53
|
+
int max_threads = std::thread::hardware_concurrency();
|
|
54
|
+
int thread_count = std::max(2, std::min(4, max_threads)); // Use 2-4 threads
|
|
55
|
+
whisperThreadPool = std::make_unique<ThreadPool>(thread_count);
|
|
56
|
+
logInfo("Initialized ThreadPool with %d threads", thread_count);
|
|
57
|
+
}
|
|
58
|
+
return *whisperThreadPool;
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
// Template-based context management
|
|
62
|
+
template<typename T>
|
|
63
|
+
class ContextManager {
|
|
64
|
+
private:
|
|
65
|
+
std::unordered_map<int, long> contextMap;
|
|
66
|
+
std::mutex contextMutex;
|
|
67
|
+
const char* contextType;
|
|
68
|
+
|
|
69
|
+
public:
|
|
70
|
+
ContextManager(const char* type) : contextType(type) {}
|
|
71
|
+
|
|
72
|
+
void add(int contextId, long contextPtr) {
|
|
73
|
+
std::lock_guard<std::mutex> lock(contextMutex);
|
|
74
|
+
contextMap[contextId] = contextPtr;
|
|
75
|
+
logDebug("%s Context added: id=%d, ptr=%ld", contextType, contextId, contextPtr);
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
void remove(int contextId) {
|
|
79
|
+
std::lock_guard<std::mutex> lock(contextMutex);
|
|
80
|
+
auto it = contextMap.find(contextId);
|
|
81
|
+
if (it != contextMap.end()) {
|
|
82
|
+
logDebug("%s Context removed: id=%d", contextType, contextId);
|
|
83
|
+
contextMap.erase(it);
|
|
84
|
+
}
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
long get(int contextId) {
|
|
88
|
+
std::lock_guard<std::mutex> lock(contextMutex);
|
|
89
|
+
auto it = contextMap.find(contextId);
|
|
90
|
+
return (it != contextMap.end()) ? it->second : 0;
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
T* getTyped(int contextId) {
|
|
94
|
+
long ptr = get(contextId);
|
|
95
|
+
return ptr ? reinterpret_cast<T*>(ptr) : nullptr;
|
|
96
|
+
}
|
|
97
|
+
};
|
|
98
|
+
|
|
99
|
+
// Registries for transcription contexts and VAD contexts. They are keyed by
// the integer id that JS passes as the first argument of each JSI function.
// Presumably the platform modules register entries via addContext() and
// addVadContext() — verify against the native callers.
static ContextManager<whisper_context> contextManager("Whisper");
static ContextManager<whisper_vad_context> vadContextManager("VAD");
|
|
101
|
+
|
|
102
|
+
// Context management functions
|
|
103
|
+
void addContext(int contextId, long contextPtr) {
|
|
104
|
+
contextManager.add(contextId, contextPtr);
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
void removeContext(int contextId) {
|
|
108
|
+
contextManager.remove(contextId);
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
void addVadContext(int contextId, long vadContextPtr) {
|
|
112
|
+
vadContextManager.add(contextId, vadContextPtr);
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
void removeVadContext(int contextId) {
|
|
116
|
+
vadContextManager.remove(contextId);
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
long getContextPtr(int contextId) {
|
|
120
|
+
return contextManager.get(contextId);
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
long getVadContextPtr(int contextId) {
|
|
124
|
+
return vadContextManager.get(contextId);
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
// Result of validateJSIArguments(): isValid is true when all checks pass;
// otherwise errorMessage describes the first check that failed.
struct JSIValidationResult {
  bool isValid;
  std::string errorMessage;
};
|
|
132
|
+
|
|
133
|
+
// Checks the (contextId, options, audioBuffer) calling convention shared by
// the JSI entry points: exactly `expectedCount` arguments, a numeric context
// id, an options object and an ArrayBuffer. Reports the first failing check.
// NOTE(review): arguments[0..2] are always inspected, so this is only valid
// for expectedCount >= 3 (the only value used in this file).
JSIValidationResult validateJSIArguments(Runtime& runtime, const Value* arguments, size_t count, size_t expectedCount) {
  JSIValidationResult result{true, ""};

  if (count != expectedCount) {
    result = {false, "Expected " + std::to_string(expectedCount) + " arguments, got " + std::to_string(count)};
  } else if (!arguments[0].isNumber()) {
    result = {false, "First argument (contextId) must be a number"};
  } else if (!arguments[1].isObject()) {
    result = {false, "Second argument (options) must be an object"};
  } else {
    const Value& audio = arguments[2];
    const bool isBuffer = audio.isObject() && audio.getObject(runtime).isArrayBuffer(runtime);
    if (!isBuffer) {
      result = {false, "Third argument must be an ArrayBuffer"};
    }
  }

  return result;
}
|
|
152
|
+
|
|
153
|
+
// Helper function to create error objects
|
|
154
|
+
Object createErrorObject(Runtime& runtime, const std::string& message, int code = -1) {
|
|
155
|
+
auto errorObj = Object(runtime);
|
|
156
|
+
errorObj.setProperty(runtime, "message", String::createFromUtf8(runtime, message));
|
|
157
|
+
if (code != -1) {
|
|
158
|
+
errorObj.setProperty(runtime, "code", Value(code));
|
|
159
|
+
}
|
|
160
|
+
return errorObj;
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
// Helper function to convert JSI object to whisper_full_params
|
|
164
|
+
whisper_full_params createFullParamsFromJSI(Runtime& runtime, const Object& optionsObj) {
|
|
165
|
+
whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
|
166
|
+
|
|
167
|
+
params.print_realtime = false;
|
|
168
|
+
params.print_progress = false;
|
|
169
|
+
params.print_timestamps = false;
|
|
170
|
+
params.print_special = false;
|
|
171
|
+
|
|
172
|
+
int max_threads = std::thread::hardware_concurrency();
|
|
173
|
+
int default_n_threads = max_threads == 4 ? 2 : std::min(4, max_threads);
|
|
174
|
+
|
|
175
|
+
try {
|
|
176
|
+
auto propNames = optionsObj.getPropertyNames(runtime);
|
|
177
|
+
for (size_t i = 0; i < propNames.size(runtime); i++) {
|
|
178
|
+
auto propNameValue = propNames.getValueAtIndex(runtime, i);
|
|
179
|
+
std::string propName = propNameValue.getString(runtime).utf8(runtime);
|
|
180
|
+
Value propValue = optionsObj.getProperty(runtime, propNameValue.getString(runtime));
|
|
181
|
+
|
|
182
|
+
if (propName == "maxThreads" && propValue.isNumber()) {
|
|
183
|
+
int n_threads = (int)propValue.getNumber();
|
|
184
|
+
params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
|
|
185
|
+
} else if (propName == "translate" && propValue.isBool()) {
|
|
186
|
+
params.translate = propValue.getBool();
|
|
187
|
+
} else if (propName == "tokenTimestamps" && propValue.isBool()) {
|
|
188
|
+
params.token_timestamps = propValue.getBool();
|
|
189
|
+
} else if (propName == "tdrzEnable" && propValue.isBool()) {
|
|
190
|
+
params.tdrz_enable = propValue.getBool();
|
|
191
|
+
} else if (propName == "beamSize" && propValue.isNumber()) {
|
|
192
|
+
int beam_size = (int)propValue.getNumber();
|
|
193
|
+
if (beam_size > 0) {
|
|
194
|
+
params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
|
|
195
|
+
params.beam_search.beam_size = beam_size;
|
|
196
|
+
}
|
|
197
|
+
} else if (propName == "bestOf" && propValue.isNumber()) {
|
|
198
|
+
params.greedy.best_of = (int)propValue.getNumber();
|
|
199
|
+
} else if (propName == "maxLen" && propValue.isNumber()) {
|
|
200
|
+
params.max_len = (int)propValue.getNumber();
|
|
201
|
+
} else if (propName == "maxContext" && propValue.isNumber()) {
|
|
202
|
+
params.n_max_text_ctx = (int)propValue.getNumber();
|
|
203
|
+
} else if (propName == "offset" && propValue.isNumber()) {
|
|
204
|
+
params.offset_ms = (int)propValue.getNumber();
|
|
205
|
+
} else if (propName == "duration" && propValue.isNumber()) {
|
|
206
|
+
params.duration_ms = (int)propValue.getNumber();
|
|
207
|
+
} else if (propName == "wordThold" && propValue.isNumber()) {
|
|
208
|
+
params.thold_pt = (int)propValue.getNumber();
|
|
209
|
+
} else if (propName == "temperature" && propValue.isNumber()) {
|
|
210
|
+
params.temperature = (float)propValue.getNumber();
|
|
211
|
+
} else if (propName == "temperatureInc" && propValue.isNumber()) {
|
|
212
|
+
params.temperature_inc = (float)propValue.getNumber();
|
|
213
|
+
} else if (propName == "prompt" && propValue.isString()) {
|
|
214
|
+
std::string prompt = propValue.getString(runtime).utf8(runtime);
|
|
215
|
+
params.initial_prompt = strdup(prompt.c_str());
|
|
216
|
+
} else if (propName == "language" && propValue.isString()) {
|
|
217
|
+
std::string language = propValue.getString(runtime).utf8(runtime);
|
|
218
|
+
params.language = strdup(language.c_str());
|
|
219
|
+
}
|
|
220
|
+
}
|
|
221
|
+
} catch (...) {
|
|
222
|
+
// Use default values if parsing fails
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
params.offset_ms = 0;
|
|
226
|
+
params.no_context = true;
|
|
227
|
+
params.single_segment = false;
|
|
228
|
+
|
|
229
|
+
return params;
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
// Float samples produced by convertArrayBufferToAudioData() from 16-bit PCM
// (each value is sample / 32768). `count` is the number of samples and equals
// data.size() as produced there.
struct AudioData {
  std::vector<float> data;
  int count;
};
|
|
237
|
+
|
|
238
|
+
// Converts a buffer of 16-bit PCM samples (host byte order) into floats
// scaled by 1/32768.
//
// @param runtime          used only to raise JSError on invalid input
// @param arrayBufferSize  size of the buffer in bytes; must be even
// @param arrayBufferData  start of the buffer; may be null when the size is 0
// @throws JSError if the size is odd or the sample count does not fit in int.
//
// Samples are read with memcpy rather than through an int16_t* cast. The
// ArrayBuffer's byte pointer is not guaranteed to be 2-byte aligned, and
// type-punning it is a strict-aliasing violation.
// NOTE(review): assumes the PCM byte order matches the host (little-endian on
// current iOS/Android targets) — confirm with the JS producer.
AudioData convertArrayBufferToAudioData(Runtime& runtime, size_t arrayBufferSize, uint8_t* arrayBufferData) {
  if (arrayBufferSize % 2 != 0) {
    throw JSError(runtime, "ArrayBuffer size must be even for 16-bit PCM data");
  }

  const size_t sampleCount = arrayBufferSize / sizeof(int16_t);
  // AudioData::count and whisper's n_samples are int.
  if (sampleCount > static_cast<size_t>(std::numeric_limits<int>::max())) {
    throw JSError(runtime, "ArrayBuffer is too large for 16-bit PCM data");
  }

  std::vector<float> audioData(sampleCount);
  for (size_t i = 0; i < sampleCount; i++) {
    int16_t sample;
    std::memcpy(&sample, arrayBufferData + i * sizeof(int16_t), sizeof(int16_t));
    audioData[i] = static_cast<float>(sample) / 32768.0f;
  }

  return {std::move(audioData), static_cast<int>(sampleCount)};
}
|
|
255
|
+
|
|
256
|
+
// State handed to whisper's C-style callbacks. Each *_callback_user_data
// points at a std::shared_ptr<CallbackData>, and the callbacks use it to post
// work back to the JS thread through callInvoker->invokeAsync.
// NOTE(review): the CallbackType template parameter is not used by any member.
template<typename CallbackType>
struct CallbackData {
  std::shared_ptr<facebook::react::CallInvoker> callInvoker;  // posts work to the JS thread
  std::shared_ptr<Function> callback;  // JS callback to invoke; null when not provided
  std::shared_ptr<Runtime> safeRuntime;  // non-owning handle to the JS runtime
  std::atomic<int> counter{0};  // running total of new segments reported so far
};
|
|
264
|
+
|
|
265
|
+
// JS callbacks and job id read from a transcription options object by
// extractCallbacks(). A callback pointer is null when the option is absent or
// is not a function.
struct CallbackInfo {
  std::shared_ptr<Function> onProgressCallback;
  std::shared_ptr<Function> onNewSegmentsCallback;
  int jobId;  // options.jobId when numeric, otherwise a rand() fallback
};
|
|
271
|
+
|
|
272
|
+
// Pulls `onProgress`, `onNewSegments` and `jobId` out of a transcription
// options object.
//
// The job id starts as rand() and is replaced only by a numeric `jobId`. The
// callbacks are kept only when they are JS functions. Errors while enumerating
// the object are swallowed; whatever was found before the error is returned.
CallbackInfo extractCallbacks(Runtime& runtime, const Object& optionsObj) {
  CallbackInfo info;
  info.jobId = rand(); // Default fallback jobId

  // Stores `value` into `slot` as a shared Function when it is callable.
  auto storeIfFunction = [&runtime](const Value& value, std::shared_ptr<Function>& slot) {
    if (value.isObject() && value.getObject(runtime).isFunction(runtime)) {
      slot = std::make_shared<Function>(value.getObject(runtime).getFunction(runtime));
    }
  };

  try {
    Array keys = optionsObj.getPropertyNames(runtime);
    for (size_t index = 0; index < keys.size(runtime); ++index) {
      String key = keys.getValueAtIndex(runtime, index).getString(runtime);
      const std::string name = key.utf8(runtime);
      Value value = optionsObj.getProperty(runtime, key);

      if (name == "onProgress") {
        storeIfFunction(value, info.onProgressCallback);
      } else if (name == "onNewSegments") {
        storeIfFunction(value, info.onNewSegmentsCallback);
      } else if (name == "jobId" && value.isNumber()) {
        info.jobId = (int)value.getNumber();
      }
    }
  } catch (...) {
    // Ignore callback detection errors
  }

  return info;
}
|
|
297
|
+
|
|
298
|
+
// Helper function to create segments array
|
|
299
|
+
Array createSegmentsArray(Runtime& runtime, struct whisper_context* ctx, int offset) {
|
|
300
|
+
int n_segments = whisper_full_n_segments(ctx);
|
|
301
|
+
auto segmentsArray = Array(runtime, n_segments);
|
|
302
|
+
|
|
303
|
+
for (int i = offset; i < n_segments; i++) {
|
|
304
|
+
const char* text = whisper_full_get_segment_text(ctx, i);
|
|
305
|
+
auto segmentObj = Object(runtime);
|
|
306
|
+
segmentObj.setProperty(runtime, "text", String::createFromUtf8(runtime, text));
|
|
307
|
+
segmentObj.setProperty(runtime, "t0", Value((double)whisper_full_get_segment_t0(ctx, i)));
|
|
308
|
+
segmentObj.setProperty(runtime, "t1", Value((double)whisper_full_get_segment_t1(ctx, i)));
|
|
309
|
+
segmentsArray.setValueAtIndex(runtime, i, segmentObj);
|
|
310
|
+
}
|
|
311
|
+
|
|
312
|
+
return segmentsArray;
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
// Helper function to create full text from segments
|
|
316
|
+
std::string createFullTextFromSegments(struct whisper_context* ctx, int offset) {
|
|
317
|
+
int n_segments = whisper_full_n_segments(ctx);
|
|
318
|
+
std::string fullText = "";
|
|
319
|
+
|
|
320
|
+
for (int i = offset; i < n_segments; i++) {
|
|
321
|
+
const char* text = whisper_full_get_segment_text(ctx, i);
|
|
322
|
+
fullText += text;
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
return fullText;
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
// Shared driver for the promise-returning JSI entry points.
//
// On the JS thread it does the following:
//   - Validates (contextId, options, ArrayBuffer).
//   - Converts the buffer to float audio.
//   - For "whisperTranscribeData" only, parses whisper params and callbacks.
//   - Returns `new Promise(executor)`.
// The executor enqueues `task` on the shared ThreadPool. `task` is responsible
// for settling the promise via callInvoker->invokeAsync. If `task` throws, the
// promise is rejected with "<functionName> processing error".
//
// @throws JSError synchronously when argument validation or audio conversion fails.
// NOTE(review): the ContextType template parameter is unused; dispatch is by
// comparing functionName instead.
// NOTE(review): resolvePtr/rejectPtr (jsi::Function) are also held by the
// worker-side lambda. If that copy is the last one released, the
// jsi::Function is destroyed off the JS thread, which may be unsafe for the
// JS engine — confirm.
template<typename ContextType, typename TaskFunc>
Value createPromiseTask(
  Runtime& runtime,
  const std::string& functionName,
  std::shared_ptr<facebook::react::CallInvoker> callInvoker,
  const Value* arguments,
  size_t count,
  TaskFunc task
) {
  // Validate arguments
  auto validation = validateJSIArguments(runtime, arguments, count, 3);
  if (!validation.isValid) {
    throw JSError(runtime, functionName + " " + validation.errorMessage);
  }

  int contextId = (int)arguments[0].getNumber();
  auto optionsObj = arguments[1].getObject(runtime);
  auto arrayBuffer = arguments[2].getObject(runtime).getArrayBuffer(runtime);

  size_t arrayBufferSize = arrayBuffer.size(runtime);
  uint8_t* arrayBufferData = arrayBuffer.data(runtime);

  logInfo("%s called with contextId=%d, arrayBuffer size=%zu", functionName.c_str(), contextId, arrayBufferSize);

  // Convert ArrayBuffer to audio data
  // (copied into a std::vector here, on the JS thread, so the worker never
  // touches the JS-owned buffer)
  AudioData audioResult = convertArrayBufferToAudioData(runtime, arrayBufferSize, arrayBufferData);

  // Zero-initialised for entry points that do not use whisper params (VAD).
  whisper_full_params params = {};
  CallbackInfo callbackInfo = {};
  if (functionName == "whisperTranscribeData") {
    params = createFullParamsFromJSI(runtime, optionsObj);
    // Extract data from optionsObj before lambda capture
    callbackInfo = extractCallbacks(runtime, optionsObj);
  }

  // Create promise
  auto promiseConstructor = runtime.global().getPropertyAsFunction(runtime, "Promise");

  auto promiseExecutor = Function::createFromHostFunction(
    runtime,
    PropNameID::forAscii(runtime, ""),
    2, // resolve, reject
    [contextId, audioResult, params, callbackInfo, task, callInvoker, functionName](Runtime& runtime, const Value& thisValue, const Value* arguments, size_t count) -> Value {
      if (count != 2) {
        throw JSError(runtime, "Promise executor expects 2 arguments (resolve, reject)");
      }

      auto resolvePtr = std::make_shared<Function>(arguments[0].getObject(runtime).getFunction(runtime));
      auto rejectPtr = std::make_shared<Function>(arguments[1].getObject(runtime).getFunction(runtime));
      // Non-owning handle: the no-op deleter means the runtime is never freed
      // through this pointer; its lifetime is not tracked here.
      auto safeRuntime = std::shared_ptr<Runtime>(&runtime, [](Runtime*){});

      // Execute task in ThreadPool
      // (the returned future is not waited on; completion is signalled via
      // the promise instead)
      auto future = getWhisperThreadPool().enqueue([
        contextId, audioResult, params, callbackInfo, task, resolvePtr, rejectPtr, callInvoker, safeRuntime, functionName]() {

        try {
          task(contextId, audioResult, params, callbackInfo, resolvePtr, rejectPtr, callInvoker, safeRuntime);
        } catch (...) {
          // Never let an exception escape a worker; reject on the JS thread.
          callInvoker->invokeAsync([rejectPtr, safeRuntime, functionName]() {
            auto& runtime = *safeRuntime;
            auto errorObj = createErrorObject(runtime, functionName + " processing error");
            rejectPtr->call(runtime, errorObj);
          });
        }
      });

      return Value::undefined();
    }
  );

  return promiseConstructor.callAsConstructor(runtime, promiseExecutor);
}
|
|
401
|
+
|
|
402
|
+
void installJSIBindings(
|
|
403
|
+
facebook::jsi::Runtime& runtime,
|
|
404
|
+
std::shared_ptr<facebook::react::CallInvoker> callInvoker
|
|
405
|
+
) {
|
|
406
|
+
try {
|
|
407
|
+
// whisperTranscribeData function
|
|
408
|
+
auto whisperTranscribeData = Function::createFromHostFunction(
|
|
409
|
+
runtime,
|
|
410
|
+
PropNameID::forAscii(runtime, "whisperTranscribeData"),
|
|
411
|
+
3, // number of arguments
|
|
412
|
+
[callInvoker](Runtime& runtime, const Value& thisValue, const Value* arguments, size_t count) -> Value {
|
|
413
|
+
try {
|
|
414
|
+
return createPromiseTask<whisper_context>(
|
|
415
|
+
runtime, "whisperTranscribeData", callInvoker, arguments, count,
|
|
416
|
+
[](int contextId, const AudioData& audioResult, const whisper_full_params& params, const CallbackInfo& callbackInfo,
|
|
417
|
+
std::shared_ptr<Function> resolvePtr, std::shared_ptr<Function> rejectPtr,
|
|
418
|
+
std::shared_ptr<facebook::react::CallInvoker> callInvoker,
|
|
419
|
+
std::shared_ptr<Runtime> safeRuntime) {
|
|
420
|
+
|
|
421
|
+
// Get context
|
|
422
|
+
auto context = contextManager.getTyped(contextId);
|
|
423
|
+
if (!context) {
|
|
424
|
+
callInvoker->invokeAsync([rejectPtr, safeRuntime, contextId]() {
|
|
425
|
+
auto& runtime = *safeRuntime;
|
|
426
|
+
auto errorObj = createErrorObject(runtime, "Context not found for id: " + std::to_string(contextId));
|
|
427
|
+
rejectPtr->call(runtime, errorObj);
|
|
428
|
+
});
|
|
429
|
+
return;
|
|
430
|
+
}
|
|
431
|
+
|
|
432
|
+
// Validate audio data
|
|
433
|
+
if (audioResult.data.empty() || audioResult.count <= 0) {
|
|
434
|
+
logError("Invalid audio data: size=%zu, count=%d", audioResult.data.size(), audioResult.count);
|
|
435
|
+
callInvoker->invokeAsync([rejectPtr, safeRuntime]() {
|
|
436
|
+
auto& runtime = *safeRuntime;
|
|
437
|
+
auto errorObj = createErrorObject(runtime, "Invalid audio data");
|
|
438
|
+
rejectPtr->call(runtime, errorObj);
|
|
439
|
+
});
|
|
440
|
+
return;
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
logInfo("Starting whisper_full: context=%p, audioDataCount=%d, jobId=%d",
|
|
444
|
+
context, audioResult.count, callbackInfo.jobId);
|
|
445
|
+
whisper_reset_timings(context);
|
|
446
|
+
|
|
447
|
+
// Setup callbacks
|
|
448
|
+
whisper_full_params mutable_params = params;
|
|
449
|
+
auto progress_data = std::make_shared<CallbackData<Function>>();
|
|
450
|
+
progress_data->callInvoker = callInvoker;
|
|
451
|
+
progress_data->callback = callbackInfo.onProgressCallback;
|
|
452
|
+
progress_data->safeRuntime = safeRuntime;
|
|
453
|
+
|
|
454
|
+
if (callbackInfo.onProgressCallback) {
|
|
455
|
+
mutable_params.progress_callback = [](struct whisper_context* /*ctx*/, struct whisper_state* /*state*/, int progress, void* user_data) {
|
|
456
|
+
auto* data_ptr = static_cast<std::shared_ptr<CallbackData<Function>>*>(user_data);
|
|
457
|
+
if (data_ptr && *data_ptr) {
|
|
458
|
+
auto data = *data_ptr;
|
|
459
|
+
if (data->callInvoker && data->callback && data->safeRuntime) {
|
|
460
|
+
data->callInvoker->invokeAsync([progress, callback = data->callback, safeRuntime = data->safeRuntime]() {
|
|
461
|
+
try {
|
|
462
|
+
logInfo("Progress: %d%%", progress);
|
|
463
|
+
auto& runtime = *safeRuntime;
|
|
464
|
+
callback->call(runtime, Value(progress));
|
|
465
|
+
} catch (...) {
|
|
466
|
+
logError("Error in progress callback");
|
|
467
|
+
}
|
|
468
|
+
});
|
|
469
|
+
}
|
|
470
|
+
}
|
|
471
|
+
};
|
|
472
|
+
mutable_params.progress_callback_user_data = &progress_data;
|
|
473
|
+
}
|
|
474
|
+
|
|
475
|
+
auto segments_data = std::make_shared<CallbackData<Function>>();
|
|
476
|
+
segments_data->callInvoker = callInvoker;
|
|
477
|
+
segments_data->callback = callbackInfo.onNewSegmentsCallback;
|
|
478
|
+
segments_data->safeRuntime = safeRuntime;
|
|
479
|
+
|
|
480
|
+
if (callbackInfo.onNewSegmentsCallback) {
|
|
481
|
+
mutable_params.new_segment_callback = [](struct whisper_context* ctx, struct whisper_state* /*state*/, int n_new, void* user_data) {
|
|
482
|
+
auto* data_ptr = static_cast<std::shared_ptr<CallbackData<Function>>*>(user_data);
|
|
483
|
+
if (data_ptr && *data_ptr) {
|
|
484
|
+
auto data = *data_ptr;
|
|
485
|
+
if (data->callInvoker && data->callback && data->safeRuntime && ctx) {
|
|
486
|
+
int current_total = data->counter.fetch_add(n_new) + n_new;
|
|
487
|
+
data->callInvoker->invokeAsync([ctx, n_new, current_total, callback = data->callback, safeRuntime = data->safeRuntime]() {
|
|
488
|
+
try {
|
|
489
|
+
logInfo("New segments: %d (total: %d)", n_new, current_total);
|
|
490
|
+
auto& runtime = *safeRuntime;
|
|
491
|
+
auto resultObj = Object(runtime);
|
|
492
|
+
resultObj.setProperty(runtime, "nNew", Value(n_new));
|
|
493
|
+
resultObj.setProperty(runtime, "totalNNew", Value(current_total));
|
|
494
|
+
auto offset = current_total - n_new;
|
|
495
|
+
resultObj.setProperty(runtime, "segments", createSegmentsArray(runtime, ctx, offset));
|
|
496
|
+
resultObj.setProperty(runtime, "result", String::createFromUtf8(runtime, createFullTextFromSegments(ctx, offset)));
|
|
497
|
+
callback->call(runtime, resultObj);
|
|
498
|
+
} catch (...) {
|
|
499
|
+
logError("Error in new segments callback");
|
|
500
|
+
}
|
|
501
|
+
});
|
|
502
|
+
}
|
|
503
|
+
}
|
|
504
|
+
};
|
|
505
|
+
mutable_params.new_segment_callback_user_data = &segments_data;
|
|
506
|
+
}
|
|
507
|
+
|
|
508
|
+
// Execute transcription
|
|
509
|
+
rnwhisper::job* job = rnwhisper::job_new(callbackInfo.jobId, mutable_params);
|
|
510
|
+
int code = -1;
|
|
511
|
+
|
|
512
|
+
if (job == nullptr) {
|
|
513
|
+
logError("Failed to create job for transcription");
|
|
514
|
+
code = -2;
|
|
515
|
+
} else {
|
|
516
|
+
code = whisper_full(context, job->params, audioResult.data.data(), audioResult.count);
|
|
517
|
+
if (job->is_aborted()) {
|
|
518
|
+
code = -999;
|
|
519
|
+
}
|
|
520
|
+
rnwhisper::job_remove(callbackInfo.jobId);
|
|
521
|
+
}
|
|
522
|
+
|
|
523
|
+
// Resolve with results
|
|
524
|
+
callInvoker->invokeAsync([resolvePtr, rejectPtr, code, context, safeRuntime]() {
|
|
525
|
+
try {
|
|
526
|
+
auto& runtime = *safeRuntime;
|
|
527
|
+
if (code == 0) {
|
|
528
|
+
auto resultObj = Object(runtime);
|
|
529
|
+
resultObj.setProperty(runtime, "code", Value(code));
|
|
530
|
+
resultObj.setProperty(runtime, "result", String::createFromUtf8(runtime, createFullTextFromSegments(context, 0)));
|
|
531
|
+
resultObj.setProperty(runtime, "segments", createSegmentsArray(runtime, context, 0));
|
|
532
|
+
resolvePtr->call(runtime, resultObj);
|
|
533
|
+
} else {
|
|
534
|
+
std::string errorMsg = (code == -2) ? "Failed to create transcription job" :
|
|
535
|
+
(code == -999) ? "Transcription was aborted" :
|
|
536
|
+
"Transcription failed";
|
|
537
|
+
auto errorObj = createErrorObject(runtime, errorMsg, code);
|
|
538
|
+
rejectPtr->call(runtime, errorObj);
|
|
539
|
+
}
|
|
540
|
+
} catch (...) {
|
|
541
|
+
auto& runtime = *safeRuntime;
|
|
542
|
+
auto errorObj = createErrorObject(runtime, "Unknown error");
|
|
543
|
+
rejectPtr->call(runtime, errorObj);
|
|
544
|
+
}
|
|
545
|
+
});
|
|
546
|
+
}
|
|
547
|
+
);
|
|
548
|
+
} catch (const JSError& e) {
|
|
549
|
+
throw;
|
|
550
|
+
} catch (const std::exception& e) {
|
|
551
|
+
logError("Exception in whisperTranscribeData: %s", e.what());
|
|
552
|
+
throw JSError(runtime, std::string("whisperTranscribeData error: ") + e.what());
|
|
553
|
+
} catch (...) {
|
|
554
|
+
logError("Unknown exception in whisperTranscribeData");
|
|
555
|
+
throw JSError(runtime, "whisperTranscribeData encountered unknown error");
|
|
556
|
+
}
|
|
557
|
+
}
|
|
558
|
+
);
|
|
559
|
+
|
|
560
|
+
// whisperVadDetectSpeech function
|
|
561
|
+
auto whisperVadDetectSpeech = Function::createFromHostFunction(
|
|
562
|
+
runtime,
|
|
563
|
+
PropNameID::forAscii(runtime, "whisperVadDetectSpeech"),
|
|
564
|
+
3, // number of arguments
|
|
565
|
+
[callInvoker](Runtime& runtime, const Value& thisValue, const Value* arguments, size_t count) -> Value {
|
|
566
|
+
try {
|
|
567
|
+
return createPromiseTask<whisper_vad_context>(
|
|
568
|
+
runtime, "whisperVadDetectSpeech", callInvoker, arguments, count,
|
|
569
|
+
[](int contextId, const AudioData& audioResult, const whisper_full_params& params, const CallbackInfo& callbackInfo,
|
|
570
|
+
std::shared_ptr<Function> resolvePtr, std::shared_ptr<Function> rejectPtr,
|
|
571
|
+
std::shared_ptr<facebook::react::CallInvoker> callInvoker,
|
|
572
|
+
std::shared_ptr<Runtime> safeRuntime) {
|
|
573
|
+
|
|
574
|
+
// Get VAD context
|
|
575
|
+
auto vadContext = vadContextManager.getTyped(contextId);
|
|
576
|
+
if (!vadContext) {
|
|
577
|
+
callInvoker->invokeAsync([rejectPtr, safeRuntime, contextId]() {
|
|
578
|
+
auto& runtime = *safeRuntime;
|
|
579
|
+
auto errorObj = createErrorObject(runtime, "VAD Context not found for id: " + std::to_string(contextId));
|
|
580
|
+
rejectPtr->call(runtime, errorObj);
|
|
581
|
+
});
|
|
582
|
+
return;
|
|
583
|
+
}
|
|
584
|
+
|
|
585
|
+
// Validate audio data
|
|
586
|
+
if (audioResult.data.empty() || audioResult.count <= 0) {
|
|
587
|
+
logError("Invalid audio data: size=%zu, count=%d", audioResult.data.size(), audioResult.count);
|
|
588
|
+
callInvoker->invokeAsync([rejectPtr, safeRuntime]() {
|
|
589
|
+
auto& runtime = *safeRuntime;
|
|
590
|
+
auto errorObj = createErrorObject(runtime, "Invalid audio data");
|
|
591
|
+
rejectPtr->call(runtime, errorObj);
|
|
592
|
+
});
|
|
593
|
+
return;
|
|
594
|
+
}
|
|
595
|
+
|
|
596
|
+
logInfo("Starting whisper_vad_detect_speech: vadContext=%p, audioDataCount=%d",
|
|
597
|
+
vadContext, audioResult.count);
|
|
598
|
+
|
|
599
|
+
// Perform VAD detection
|
|
600
|
+
bool isSpeech = whisper_vad_detect_speech(vadContext, audioResult.data.data(), audioResult.count);
|
|
601
|
+
logInfo("VAD detection result: %s", isSpeech ? "speech" : "no speech");
|
|
602
|
+
|
|
603
|
+
struct whisper_vad_params vad_params = whisper_vad_default_params();
|
|
604
|
+
struct whisper_vad_segments* segments = nullptr;
|
|
605
|
+
if (isSpeech) {
|
|
606
|
+
segments = whisper_vad_segments_from_probs(vadContext, vad_params);
|
|
607
|
+
}
|
|
608
|
+
|
|
609
|
+
// Process results on JS thread
|
|
610
|
+
callInvoker->invokeAsync([resolvePtr, rejectPtr, segments, safeRuntime]() {
|
|
611
|
+
try {
|
|
612
|
+
auto& runtime = *safeRuntime;
|
|
613
|
+
auto resultObj = Object(runtime);
|
|
614
|
+
|
|
615
|
+
if (segments) {
|
|
616
|
+
int n_segments = whisper_vad_segments_n_segments(segments);
|
|
617
|
+
resultObj.setProperty(runtime, "hasSpeech", Value(n_segments > 0));
|
|
618
|
+
auto segmentsArray = Array(runtime, n_segments);
|
|
619
|
+
|
|
620
|
+
for (int i = 0; i < n_segments; i++) {
|
|
621
|
+
auto segmentObj = Object(runtime);
|
|
622
|
+
segmentObj.setProperty(runtime, "t0", Value((double)whisper_vad_segments_get_segment_t0(segments, i)));
|
|
623
|
+
segmentObj.setProperty(runtime, "t1", Value((double)whisper_vad_segments_get_segment_t1(segments, i)));
|
|
624
|
+
segmentsArray.setValueAtIndex(runtime, i, segmentObj);
|
|
625
|
+
}
|
|
626
|
+
|
|
627
|
+
resultObj.setProperty(runtime, "segments", segmentsArray);
|
|
628
|
+
whisper_vad_free_segments(segments);
|
|
629
|
+
} else {
|
|
630
|
+
resultObj.setProperty(runtime, "hasSpeech", Value(false));
|
|
631
|
+
resultObj.setProperty(runtime, "segments", Array(runtime, 0));
|
|
632
|
+
}
|
|
633
|
+
|
|
634
|
+
resolvePtr->call(runtime, resultObj);
|
|
635
|
+
} catch (...) {
|
|
636
|
+
auto& runtime = *safeRuntime;
|
|
637
|
+
auto errorObj = createErrorObject(runtime, "VAD result processing error");
|
|
638
|
+
rejectPtr->call(runtime, errorObj);
|
|
639
|
+
}
|
|
640
|
+
});
|
|
641
|
+
}
|
|
642
|
+
);
|
|
643
|
+
} catch (const JSError& e) {
|
|
644
|
+
throw;
|
|
645
|
+
} catch (const std::exception& e) {
|
|
646
|
+
logError("Exception in whisperVadDetectSpeech: %s", e.what());
|
|
647
|
+
throw JSError(runtime, std::string("whisperVadDetectSpeech error: ") + e.what());
|
|
648
|
+
} catch (...) {
|
|
649
|
+
logError("Unknown exception in whisperVadDetectSpeech");
|
|
650
|
+
throw JSError(runtime, "whisperVadDetectSpeech encountered unknown error");
|
|
651
|
+
}
|
|
652
|
+
}
|
|
653
|
+
);
|
|
654
|
+
|
|
655
|
+
// Install the JSI functions on the global object
|
|
656
|
+
runtime.global().setProperty(runtime, "whisperTranscribeData", std::move(whisperTranscribeData));
|
|
657
|
+
runtime.global().setProperty(runtime, "whisperVadDetectSpeech", std::move(whisperVadDetectSpeech));
|
|
658
|
+
|
|
659
|
+
logInfo("JSI bindings installed successfully");
|
|
660
|
+
} catch (const JSError& e) {
|
|
661
|
+
logError("JSError installing JSI bindings: %s", e.getMessage().c_str());
|
|
662
|
+
throw;
|
|
663
|
+
} catch (const std::exception& e) {
|
|
664
|
+
logError("Exception installing JSI bindings: %s", e.what());
|
|
665
|
+
throw JSError(runtime, std::string("Failed to install JSI bindings: ") + e.what());
|
|
666
|
+
} catch (...) {
|
|
667
|
+
logError("Unknown exception installing JSI bindings");
|
|
668
|
+
throw JSError(runtime, "Failed to install JSI bindings: unknown error");
|
|
669
|
+
}
|
|
670
|
+
}
|
|
671
|
+
|
|
672
|
+
// Cleanup function to dispose of ThreadPool
|
|
673
|
+
void cleanupJSIBindings() {
|
|
674
|
+
std::lock_guard<std::mutex> lock(threadPoolMutex);
|
|
675
|
+
if (whisperThreadPool) {
|
|
676
|
+
logInfo("Cleaning up ThreadPool");
|
|
677
|
+
whisperThreadPool.reset();
|
|
678
|
+
}
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
} // namespace rnwhisper_jsi
|