react-native-sherpa-onnx 0.4.0 → 0.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -0
- package/android/src/main/assets/model_licenses/alignment-models-license-status.csv +5 -0
- package/android/src/main/cpp/CMakeLists.txt +3 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-alignment-wrapper.cpp +66 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-alignment-wrapper.h +17 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-alignment.cpp +108 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +30 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-alignment.cpp +66 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-alignment.h +30 -0
- package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxAlignmentHelper.kt +555 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +76 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxTextSegmenter.kt +330 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +180 -23
- package/ios/Resources/model_licenses/alignment-models-license-status.csv +5 -0
- package/ios/SherpaOnnx+Alignment.mm +704 -0
- package/ios/SherpaOnnx+STT.mm +6 -0
- package/ios/SherpaOnnx+TTS.mm +624 -50
- package/ios/model_detect/sherpa-onnx-model-detect-alignment.mm +108 -0
- package/ios/model_detect/sherpa-onnx-model-detect.h +31 -0
- package/ios/model_detect/sherpa-onnx-validate-alignment.h +30 -0
- package/ios/model_detect/sherpa-onnx-validate-alignment.mm +66 -0
- package/ios/stt/sherpa-onnx-stt-wrapper.h +3 -1
- package/ios/stt/sherpa-onnx-stt-wrapper.mm +6 -0
- package/lib/module/NativeSherpaOnnx.js.map +1 -1
- package/lib/module/alignment/index.js +27 -0
- package/lib/module/alignment/index.js.map +1 -0
- package/lib/module/alignment/types.js +2 -0
- package/lib/module/alignment/types.js.map +1 -0
- package/lib/module/alignment/vocab.js +40 -0
- package/lib/module/alignment/vocab.js.map +1 -0
- package/lib/module/download/paths.js +9 -1
- package/lib/module/download/paths.js.map +1 -1
- package/lib/module/download/registry.js +17 -1
- package/lib/module/download/registry.js.map +1 -1
- package/lib/module/download/types.js +1 -0
- package/lib/module/download/types.js.map +1 -1
- package/lib/module/index.js +6 -4
- package/lib/module/index.js.map +1 -1
- package/lib/module/licenses.js +8 -2
- package/lib/module/licenses.js.map +1 -1
- package/lib/module/stt/types.js.map +1 -1
- package/lib/module/tts/index.js +68 -2
- package/lib/module/tts/index.js.map +1 -1
- package/lib/module/tts/subtitles.js +400 -0
- package/lib/module/tts/subtitles.js.map +1 -0
- package/lib/module/tts/tempAudio.js +17 -0
- package/lib/module/tts/tempAudio.js.map +1 -0
- package/lib/module/tts/types.js.map +1 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts +34 -3
- package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
- package/lib/typescript/src/alignment/index.d.ts +8 -0
- package/lib/typescript/src/alignment/index.d.ts.map +1 -0
- package/lib/typescript/src/alignment/types.d.ts +23 -0
- package/lib/typescript/src/alignment/types.d.ts.map +1 -0
- package/lib/typescript/src/alignment/vocab.d.ts +5 -0
- package/lib/typescript/src/alignment/vocab.d.ts.map +1 -0
- package/lib/typescript/src/download/paths.d.ts +5 -2
- package/lib/typescript/src/download/paths.d.ts.map +1 -1
- package/lib/typescript/src/download/registry.d.ts.map +1 -1
- package/lib/typescript/src/download/types.d.ts +2 -1
- package/lib/typescript/src/download/types.d.ts.map +1 -1
- package/lib/typescript/src/index.d.ts +1 -0
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/licenses.d.ts.map +1 -1
- package/lib/typescript/src/stt/types.d.ts +5 -2
- package/lib/typescript/src/stt/types.d.ts.map +1 -1
- package/lib/typescript/src/tts/index.d.ts +2 -1
- package/lib/typescript/src/tts/index.d.ts.map +1 -1
- package/lib/typescript/src/tts/subtitles.d.ts +24 -0
- package/lib/typescript/src/tts/subtitles.d.ts.map +1 -0
- package/lib/typescript/src/tts/tempAudio.d.ts +3 -0
- package/lib/typescript/src/tts/tempAudio.d.ts.map +1 -0
- package/lib/typescript/src/tts/types.d.ts +68 -2
- package/lib/typescript/src/tts/types.d.ts.map +1 -1
- package/package.json +6 -1
- package/scripts/alignment-models/README.md +90 -0
- package/scripts/alignment-models/build_and_upload.js +724 -0
- package/scripts/alignment-models/sources.csv +5 -0
- package/scripts/alignment-models/sync_alignment_license_status.js +123 -0
- package/src/NativeSherpaOnnx.ts +35 -3
- package/src/alignment/index.ts +41 -0
- package/src/alignment/types.ts +22 -0
- package/src/alignment/vocab.ts +38 -0
- package/src/download/paths.ts +18 -5
- package/src/download/registry.ts +23 -3
- package/src/download/types.ts +1 -0
- package/src/index.tsx +6 -4
- package/src/licenses.ts +12 -1
- package/src/stt/types.ts +5 -2
- package/src/tts/index.ts +110 -3
- package/src/tts/subtitles.ts +611 -0
- package/src/tts/tempAudio.ts +31 -0
- package/src/tts/types.ts +79 -2
- package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -1
|
@@ -0,0 +1,704 @@
|
|
|
1
|
+
#import "SherpaOnnx.h"
|
|
2
|
+
#import <React/RCTLog.h>
|
|
3
|
+
|
|
4
|
+
#include "sherpa-onnx/c-api/cxx-api.h"
|
|
5
|
+
#include "sherpa-onnx-model-detect.h"
|
|
6
|
+
|
|
7
|
+
#if __has_include("../third_party/onnxruntime/include/onnxruntime/core/session/onnxruntime_c_api.h")
|
|
8
|
+
#include "../third_party/onnxruntime/include/onnxruntime/core/session/onnxruntime_c_api.h"
|
|
9
|
+
#define SHERPA_ONNX_HAS_ORT_C_API 1
|
|
10
|
+
#elif __has_include("onnxruntime_c_api.h")
|
|
11
|
+
#include "onnxruntime_c_api.h"
|
|
12
|
+
#define SHERPA_ONNX_HAS_ORT_C_API 1
|
|
13
|
+
#else
|
|
14
|
+
#define SHERPA_ONNX_HAS_ORT_C_API 0
|
|
15
|
+
#endif
|
|
16
|
+
|
|
17
|
+
#include <algorithm>
|
|
18
|
+
#include <cmath>
|
|
19
|
+
#include <cstdint>
|
|
20
|
+
#include <sstream>
|
|
21
|
+
#include <stdexcept>
|
|
22
|
+
#include <string>
|
|
23
|
+
#include <unordered_map>
|
|
24
|
+
#include <utility>
|
|
25
|
+
#include <vector>
|
|
26
|
+
|
|
27
|
+
namespace {
|
|
28
|
+
|
|
29
|
+
/// One aligned unit of the transcript (a character or a whole word) together
/// with its time span in seconds.
struct AlignmentItem {
  std::string text;    // UTF-8 token or word text
  double start{0.0};   // span start, seconds
  double end{0.0};     // span end, seconds
};
|
|
34
|
+
|
|
35
|
+
/// CTC-expanded label sequence built by BuildExpandedTarget.
/// Parallel arrays indexed by CTC state: `ids` holds the vocabulary id of each
/// state, `tokenIndices` the index of the originating transcript token, or -1
/// for the interleaved blank states.
struct ExpandedTarget {
  std::vector<int32_t> ids;           // vocabulary id per state
  std::vector<int32_t> tokenIndices;  // source token index per state, -1 = blank
};
|
|
39
|
+
|
|
40
|
+
/// Maps a detected alignment model kind to the string reported to JavaScript.
/// Only wav2vec2 has a name; every other kind is reported as "unknown".
static NSString *alignmentKindToNSString(sherpaonnx::AlignmentModelKind kind) {
  if (kind == sherpaonnx::AlignmentModelKind::kWav2Vec2) {
    return @"wav2vec2";
  }
  return @"unknown";
}
|
|
49
|
+
|
|
50
|
+
/// Converts a native AlignmentDetectResult into the dictionary resolved to JS.
///
/// Always present: "success", "detectedModels" (array of {type, modelDir}) and
/// "modelType". "paths" ({model}) is added only when a model path was found,
/// and "error" only when detection failed with a non-empty message. Strings
/// that are not valid UTF-8 fall back to "" (or a generic message for "error").
static NSDictionary *alignmentDetectResultToDict(
    const sherpaonnx::AlignmentDetectResult &result) {
  // stringWithUTF8String: returns nil for invalid UTF-8; substitute a fallback.
  auto utf8OrFallback = [](const char *utf8, NSString *fallback) -> NSString * {
    NSString *converted = [NSString stringWithUTF8String:utf8];
    return converted ?: fallback;
  };

  NSMutableArray *models = [NSMutableArray array];
  for (const auto &candidate : result.detectedModels) {
    [models addObject:@{
      @"type": utf8OrFallback(candidate.type.c_str(), @""),
      @"modelDir": utf8OrFallback(candidate.modelDir.c_str(), @""),
    }];
  }

  NSMutableDictionary *dict = [NSMutableDictionary dictionary];
  dict[@"success"] = @(result.ok);
  dict[@"detectedModels"] = models;
  dict[@"modelType"] = alignmentKindToNSString(result.selectedKind);

  if (!result.paths.model.empty()) {
    dict[@"paths"] = @{@"model": utf8OrFallback(result.paths.model.c_str(), @"")};
  }
  if (!result.ok && !result.error.empty()) {
    dict[@"error"] = utf8OrFallback(result.error.c_str(), @"Alignment model detection failed");
  }
  return dict;
}
|
|
75
|
+
|
|
76
|
+
/// Parses a `{ "token": id, ... }` JSON object into a token -> id map.
///
/// Entries whose key is not a string, whose value is not a number, or whose key
/// cannot be represented as UTF-8 are skipped. Numeric values are converted
/// with -intValue.
///
/// @throws std::runtime_error if the JSON is empty, cannot be encoded or parsed,
///         is not a JSON object, or yields no usable entries.
static std::unordered_map<std::string, int32_t> ParseVocabJson(NSString *vocabJson) {
  if (vocabJson == nil || vocabJson.length == 0) {
    throw std::runtime_error("Vocabulary JSON is empty");
  }

  // -dataUsingEncoding: returns nil when the string is not representable in
  // UTF-8 (e.g. unpaired surrogates). Passing nil to NSJSONSerialization raises
  // an NSException, so surface it as the same C++ error as a parse failure.
  NSData *data = [vocabJson dataUsingEncoding:NSUTF8StringEncoding];
  if (data == nil) {
    throw std::runtime_error("Failed to parse vocabulary JSON");
  }

  NSError *error = nil;
  id parsed = [NSJSONSerialization JSONObjectWithData:data options:0 error:&error];
  // Cocoa convention: judge success by the returned object, not the error pointer.
  if (parsed == nil || ![parsed isKindOfClass:[NSDictionary class]]) {
    throw std::runtime_error("Failed to parse vocabulary JSON");
  }

  NSDictionary *dict = (NSDictionary *)parsed;
  std::unordered_map<std::string, int32_t> vocab;
  vocab.reserve(dict.count);
  for (id key in dict) {
    id value = dict[key];
    if (![key isKindOfClass:[NSString class]] || ![value isKindOfClass:[NSNumber class]]) {
      continue;
    }
    // UTF8String may return NULL; std::string(NULL) is undefined behavior.
    const char *tokenUtf8 = [(NSString *)key UTF8String];
    if (tokenUtf8 == nullptr) {
      continue;
    }
    vocab[std::string(tokenUtf8)] = static_cast<int32_t>([(NSNumber *)value intValue]);
  }

  if (vocab.empty()) {
    throw std::runtime_error("Vocabulary JSON has no valid entries");
  }
  return vocab;
}
|
|
109
|
+
|
|
110
|
+
/// Splits `text` into per-character tokens for a character-level CTC vocabulary.
///
/// The text is upper-cased with the en_US_POSIX locale. U+2019, U+0060 and
/// U+00B4 are mapped to "'". Characters not present in `vocab` are dropped, and
/// each run of whitespace becomes a single "|" boundary token. Leading and
/// trailing boundaries are trimmed. All "|" tokens are removed unless
/// vocab["|"] exists and equals `wordBoundaryId`.
///
/// Characters outside the Basic Multilingual Plane (UTF-16 surrogate pairs,
/// e.g. emoji) are looked up as one code point. Unpaired surrogates are
/// skipped. Previously they produced a NULL UTF8String, and std::string(NULL)
/// crashed.
static std::vector<std::string> BuildTokenTexts(
    const std::string &text,
    const std::unordered_map<std::string, int32_t> &vocab,
    int32_t wordBoundaryId) {
  NSString *source = [NSString stringWithUTF8String:text.c_str()];
  if (source == nil || source.length == 0) {
    return {};
  }

  NSString *uppercase =
      [source uppercaseStringWithLocale:[NSLocale localeWithLocaleIdentifier:@"en_US_POSIX"]];
  NSCharacterSet *whitespace = [NSCharacterSet whitespaceAndNewlineCharacterSet];
  const NSUInteger length = uppercase.length;

  std::vector<std::string> tokens;
  tokens.reserve(length);

  auto isHighSurrogate = [](unichar u) { return u >= 0xD800 && u <= 0xDBFF; };
  auto isLowSurrogate = [](unichar u) { return u >= 0xDC00 && u <= 0xDFFF; };

  NSUInteger i = 0;
  while (i < length) {
    const unichar first = [uppercase characterAtIndex:i];

    if ([whitespace characterIsMember:first]) {
      // Collapse a whitespace run into one boundary; never start with one.
      if (!tokens.empty() && tokens.back() != "|") {
        tokens.push_back("|");
      }
      ++i;
      continue;
    }

    unichar units[2] = {first, 0};
    NSUInteger unitCount = 1;
    if (isHighSurrogate(first) && i + 1 < length &&
        isLowSurrogate([uppercase characterAtIndex:i + 1])) {
      // Keep the full surrogate pair so it converts to valid UTF-8.
      units[1] = [uppercase characterAtIndex:i + 1];
      unitCount = 2;
    } else if (isHighSurrogate(first) || isLowSurrogate(first)) {
      // Unpaired surrogate: not a character and not encodable as UTF-8.
      ++i;
      continue;
    }

    // Normalize typographic apostrophes to the ASCII apostrophe.
    if (unitCount == 1 && (first == 0x2019 || first == 0x0060 || first == 0x00B4)) {
      units[0] = '\'';
    }

    NSString *token = [NSString stringWithCharacters:units length:unitCount];
    const char *tokenUtf8 = [token UTF8String];
    if (tokenUtf8 != nullptr) {
      std::string key(tokenUtf8);
      if (vocab.find(key) != vocab.end()) {
        tokens.push_back(std::move(key));
      }
    }
    i += unitCount;
  }

  // Trim boundary tokens at both ends.
  auto firstWordToken = std::find_if(
      tokens.begin(), tokens.end(), [](const std::string &t) { return t != "|"; });
  tokens.erase(tokens.begin(), firstWordToken);
  while (!tokens.empty() && tokens.back() == "|") {
    tokens.pop_back();
  }

  // Only keep boundaries when the vocabulary's "|" is the expected boundary id.
  auto boundaryIt = vocab.find("|");
  const bool boundaryUsable = boundaryIt != vocab.end() && boundaryIt->second == wordBoundaryId;
  if (!boundaryUsable) {
    tokens.erase(std::remove(tokens.begin(), tokens.end(), "|"), tokens.end());
  }

  return tokens;
}
|
|
159
|
+
|
|
160
|
+
/// Linearly resamples `input` from `sourceSampleRate` to `targetSampleRate`.
///
/// Returns an empty vector for empty input or non-positive rates, and `input`
/// unchanged when the rates already match. Otherwise it produces
/// max(1, floor(size * target / source)) samples, clamping reads past the last
/// input sample to that final sample.
static std::vector<float> ResampleLinear(
    const std::vector<float> &input,
    int32_t sourceSampleRate,
    int32_t targetSampleRate) {
  const bool unusable = input.empty() || sourceSampleRate <= 0 || targetSampleRate <= 0;
  if (unusable) {
    return {};
  }
  if (sourceSampleRate == targetSampleRate) {
    return input;
  }

  const size_t lastIndex = input.size() - 1;
  const double scaledLength =
      std::floor(static_cast<double>(input.size()) * targetSampleRate / sourceSampleRate);
  const size_t outCount = std::max<size_t>(1, static_cast<size_t>(scaledLength));
  const double step = static_cast<double>(sourceSampleRate) / targetSampleRate;

  std::vector<float> resampled;
  resampled.reserve(outCount);
  for (size_t n = 0; n < outCount; ++n) {
    const double position = static_cast<double>(n) * step;
    const size_t lower = static_cast<size_t>(std::floor(position));
    const size_t upper = std::min(lower + 1, lastIndex);
    const double weight = position - lower;

    const float a = input[std::min(lower, lastIndex)];
    const float b = input[upper];
    resampled.push_back(static_cast<float>(a + (b - a) * weight));
  }
  return resampled;
}
|
|
190
|
+
|
|
191
|
+
/// Returns a zero-mean, unit-variance copy of `input` using the population
/// variance. The variance is floored at 1e-12, so constant (e.g. silent) input
/// maps to zeros instead of dividing by zero. Empty input is returned as-is.
static std::vector<float> NormalizeAudio(const std::vector<float> &input) {
  if (input.empty()) {
    return input;
  }

  const double count = static_cast<double>(input.size());

  double total = 0.0;
  for (size_t i = 0; i < input.size(); ++i) {
    total += input[i];
  }
  const double mean = total / count;

  double squaredDeviation = 0.0;
  for (size_t i = 0; i < input.size(); ++i) {
    const double deviation = input[i] - mean;
    squaredDeviation += deviation * deviation;
  }
  const double stddev = std::sqrt(std::max(squaredDeviation / count, 1e-12));

  std::vector<float> normalized;
  normalized.reserve(input.size());
  for (float sample : input) {
    normalized.push_back(static_cast<float>((sample - mean) / stddev));
  }
  return normalized;
}
|
|
217
|
+
|
|
218
|
+
/// Row-wise, numerically stable log-softmax.
///
/// `logitsFlat` is read as a row-major [frames x vocabSize] matrix. The result
/// holds one vector of `vocabSize` log-probabilities per frame. The row maximum
/// is subtracted before exponentiating, and the sum of exponentials is floored
/// at 1e-12 before taking its log. Returns an empty result if either dimension
/// is non-positive.
///
/// @throws std::invalid_argument if `logitsFlat` holds fewer than
///         frames * vocabSize values. Previously this read past the buffer.
static std::vector<std::vector<float>> LogSoftmax(
    const std::vector<float> &logitsFlat,
    int32_t frames,
    int32_t vocabSize) {
  if (frames <= 0 || vocabSize <= 0) {
    return {};
  }

  // Offsets in size_t: t * vocabSize can overflow int32 for large outputs.
  const size_t rows = static_cast<size_t>(frames);
  const size_t cols = static_cast<size_t>(vocabSize);
  if (rows * cols > logitsFlat.size()) {
    throw std::invalid_argument("LogSoftmax: logits buffer is smaller than frames * vocabSize");
  }

  std::vector<std::vector<float>> out(rows, std::vector<float>(cols, 0.0f));

  for (size_t t = 0; t < rows; ++t) {
    const float *row = logitsFlat.data() + t * cols;

    float rowMax = -INFINITY;
    for (size_t v = 0; v < cols; ++v) {
      rowMax = std::max(rowMax, row[v]);
    }

    double sumExp = 0.0;
    for (size_t v = 0; v < cols; ++v) {
      sumExp += std::exp(static_cast<double>(row[v] - rowMax));
    }
    const double logDenom = rowMax + std::log(std::max(sumExp, 1e-12));

    for (size_t v = 0; v < cols; ++v) {
      out[t][v] = static_cast<float>(row[v] - logDenom);
    }
  }

  return out;
}
|
|
248
|
+
|
|
249
|
+
#if SHERPA_ONNX_HAS_ORT_C_API
|
|
250
|
+
|
|
251
|
+
/// Turns an ONNX Runtime status into a C++ exception.
///
/// A null status means success and returns immediately. A non-null status is
/// released, and std::runtime_error is thrown with "<prefix>: <ORT message>".
static void CheckOrtStatus(const OrtApi *api, OrtStatus *status, const char *prefix) {
  if (!status) {
    return;
  }

  // Copy the message out before the status (which owns it) is released.
  std::string message = std::string(prefix) + ": " + api->GetErrorMessage(status);
  api->ReleaseStatus(status);
  throw std::runtime_error(message);
}
|
|
262
|
+
|
|
263
|
+
/**
 * Returns an OrtApi table usable for alignment inference.
 *
 * GetApi(ORT_API_VERSION) can be null when the embedded runtime is older than
 * the compile-time headers. API versions are therefore probed from
 * ORT_API_VERSION down to kMinOrtApiVersion. The first table exposing both
 * CreateEnv and CreateSession is returned.
 *
 * @throws std::runtime_error if OrtGetApiBase() is null or no probed version
 *         yields a usable table.
 */
static const OrtApi *ResolveOrtApiForAlignment() {
  const OrtApiBase *base = OrtGetApiBase();
  if (!base) {
    throw std::runtime_error(
        "ONNX Runtime is not available (OrtGetApiBase returned null). "
        "Subtitle-accurate alignment requires SherpaOnnx built with ONNX Runtime.");
  }

  // ORT appends new entries to OrtApi, so the early function pointers used
  // here keep stable offsets across versions; prefer the newest version.
  constexpr uint32_t kMinOrtApiVersion = 17;
  for (uint32_t version = ORT_API_VERSION; version >= kMinOrtApiVersion; --version) {
    const OrtApi *candidate = base->GetApi(version);
    const bool usable = candidate != nullptr && candidate->CreateEnv != nullptr &&
                        candidate->CreateSession != nullptr;
    if (usable) {
      return candidate;
    }
  }

  const char *runtimeVersion =
      (base->GetVersionString != nullptr) ? base->GetVersionString() : "unknown";

  std::ostringstream details;
  details << "ONNX Runtime API mismatch: GetApi() returned null for API " << ORT_API_VERSION
          << " down to " << kMinOrtApiVersion << ". Runtime version string: " << runtimeVersion
          << ". Rebuild ios/Frameworks/sherpa_onnx (or align onnxruntime headers with embedded ORT).";
  throw std::runtime_error(details.str());
}
|
|
293
|
+
|
|
294
|
+
/**
 * Runs the alignment model on `samples` and returns per-frame log-probabilities.
 *
 * A new env/session is created for this call. The model's first input is fed a
 * [1, samples.size()] float tensor that wraps `samples` without copying, and
 * the first output is read back.
 *
 * Output shape handling:
 *   - rank >= 3: read as [batch, frames, vocab];
 *   - rank 2: read as [frames, vocab];
 *   - any other rank: treated as a single frame.
 * frames and vocab are then clamped so that frames * vocab never exceeds the
 * output element count.
 *
 * Every ORT object created here is released on both success and failure.
 *
 * @throws std::runtime_error on any ORT failure, a model without inputs or
 *         outputs, an empty output, or an output that is not float32.
 */
static std::vector<std::vector<float>> RunOrtInference(
    const std::string &modelPath,
    const std::vector<float> &samples) {
  const OrtApi *api = ResolveOrtApiForAlignment();

  OrtEnv *env = nullptr;
  OrtSessionOptions *sessionOptions = nullptr;
  OrtSession *session = nullptr;
  OrtAllocator *allocator = nullptr;
  OrtMemoryInfo *memoryInfo = nullptr;
  OrtValue *inputTensor = nullptr;
  OrtValue *outputTensor = nullptr;
  OrtTensorTypeAndShapeInfo *shapeInfo = nullptr;
  char *inputName = nullptr;
  char *outputName = nullptr;

  // Releases everything acquired so far, in reverse order of creation.
  auto cleanup = [&]() {
    if (shapeInfo != nullptr) api->ReleaseTensorTypeAndShapeInfo(shapeInfo);
    if (outputTensor != nullptr) api->ReleaseValue(outputTensor);
    if (inputTensor != nullptr) api->ReleaseValue(inputTensor);
    if (memoryInfo != nullptr) api->ReleaseMemoryInfo(memoryInfo);
    if (inputName != nullptr && allocator != nullptr) api->AllocatorFree(allocator, inputName);
    if (outputName != nullptr && allocator != nullptr) api->AllocatorFree(allocator, outputName);
    if (session != nullptr) api->ReleaseSession(session);
    if (sessionOptions != nullptr) api->ReleaseSessionOptions(sessionOptions);
    if (env != nullptr) api->ReleaseEnv(env);
  };

  try {
    CheckOrtStatus(api, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "sherpa-onnx-rn", &env), "CreateEnv failed");
    CheckOrtStatus(api, api->CreateSessionOptions(&sessionOptions), "CreateSessionOptions failed");
    CheckOrtStatus(api, api->CreateSession(env, modelPath.c_str(), sessionOptions, &session), "CreateSession failed");

    CheckOrtStatus(api, api->GetAllocatorWithDefaultOptions(&allocator), "GetAllocatorWithDefaultOptions failed");

    size_t inputCount = 0;
    size_t outputCount = 0;
    CheckOrtStatus(api, api->SessionGetInputCount(session, &inputCount), "SessionGetInputCount failed");
    CheckOrtStatus(api, api->SessionGetOutputCount(session, &outputCount), "SessionGetOutputCount failed");
    if (inputCount == 0 || outputCount == 0) {
      throw std::runtime_error("Alignment model has no inputs/outputs");
    }

    CheckOrtStatus(api, api->SessionGetInputName(session, 0, allocator, &inputName), "SessionGetInputName failed");
    CheckOrtStatus(api, api->SessionGetOutputName(session, 0, allocator, &outputName), "SessionGetOutputName failed");

    CheckOrtStatus(api, api->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memoryInfo), "CreateCpuMemoryInfo failed");

    // The tensor borrows `samples`' buffer; `samples` outlives Run below.
    int64_t inputShape[2] = {1, static_cast<int64_t>(samples.size())};
    CheckOrtStatus(
        api,
        api->CreateTensorWithDataAsOrtValue(
            memoryInfo,
            const_cast<float *>(samples.data()),
            samples.size() * sizeof(float),
            inputShape,
            2,
            ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
            &inputTensor),
        "CreateTensorWithDataAsOrtValue failed");

    const char *inputNames[] = {inputName};
    const char *outputNames[] = {outputName};
    const OrtValue *inputValues[] = {inputTensor};

    CheckOrtStatus(
        api,
        api->Run(session, nullptr, inputNames, inputValues, 1, outputNames, 1, &outputTensor),
        "Run failed");

    CheckOrtStatus(api, api->GetTensorTypeAndShape(outputTensor, &shapeInfo), "GetTensorTypeAndShape failed");

    // The data is reinterpreted as float below; any other element type (e.g.
    // float16 or double logits) would silently produce garbage probabilities.
    ONNXTensorElementDataType elementType = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
    CheckOrtStatus(api, api->GetTensorElementType(shapeInfo, &elementType), "GetTensorElementType failed");
    if (elementType != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
      throw std::runtime_error("Alignment model output must be float32 logits");
    }

    size_t dimCount = 0;
    CheckOrtStatus(api, api->GetDimensionsCount(shapeInfo, &dimCount), "GetDimensionsCount failed");

    std::vector<int64_t> dims(dimCount, 0);
    if (dimCount > 0) {
      CheckOrtStatus(api, api->GetDimensions(shapeInfo, dims.data(), dimCount), "GetDimensions failed");
    }

    size_t elementCount = 0;
    CheckOrtStatus(api, api->GetTensorShapeElementCount(shapeInfo, &elementCount), "GetTensorShapeElementCount failed");
    if (elementCount == 0) {
      throw std::runtime_error("Model output tensor is empty");
    }

    float *logitsData = nullptr;
    CheckOrtStatus(api, api->GetTensorMutableData(outputTensor, reinterpret_cast<void **>(&logitsData)), "GetTensorMutableData failed");

    // Copy out before the output tensor is released by cleanup().
    std::vector<float> logitsFlat(logitsData, logitsData + elementCount);

    int32_t frames = 1;
    int32_t vocabSize = static_cast<int32_t>(elementCount);
    if (dims.size() >= 3) {
      frames = std::max<int32_t>(1, static_cast<int32_t>(dims[1]));
      vocabSize = std::max<int32_t>(1, static_cast<int32_t>(dims[2]));
    } else if (dims.size() == 2) {
      frames = std::max<int32_t>(1, static_cast<int32_t>(dims[0]));
      vocabSize = std::max<int32_t>(1, static_cast<int32_t>(dims[1]));
    }

    // Guarantee frames * vocabSize <= elementCount before indexing the buffer.
    frames = std::max<int32_t>(1, std::min<int32_t>(frames, static_cast<int32_t>(elementCount)));
    vocabSize = std::max<int32_t>(1, std::min<int32_t>(vocabSize, static_cast<int32_t>(elementCount / frames)));

    std::vector<std::vector<float>> logProbs = LogSoftmax(logitsFlat, frames, vocabSize);
    cleanup();
    return logProbs;
  } catch (...) {
    cleanup();
    throw;
  }
}
|
|
406
|
+
|
|
407
|
+
#endif
|
|
408
|
+
|
|
409
|
+
/// Builds the CTC-expanded target for `tokenIds`:
///   [blank, t0, blank, t1, blank, ..., t(n-1), blank]   (2n + 1 states)
/// `tokenIndices` records, per state, the index into `tokenIds`, or -1 for blanks.
static ExpandedTarget BuildExpandedTarget(const std::vector<int32_t> &tokenIds, int32_t blankId) {
  const size_t stateCount = tokenIds.size() * 2 + 1;

  // Start with every state blank, then fill in the odd (token) slots.
  ExpandedTarget target;
  target.ids.assign(stateCount, blankId);
  target.tokenIndices.assign(stateCount, -1);

  for (size_t token = 0; token < tokenIds.size(); ++token) {
    const size_t slot = 2 * token + 1;
    target.ids[slot] = tokenIds[token];
    target.tokenIndices[slot] = static_cast<int32_t>(token);
  }
  return target;
}
|
|
427
|
+
|
|
428
|
+
/// Returns row[tokenId], or -1e30 (the same sentinel CtcBacktrack uses for
/// unreachable states) when tokenId falls outside the row.
static float SafeLogProb(const std::vector<float> &row, int32_t tokenId) {
  const bool inRange = tokenId >= 0 && tokenId < static_cast<int32_t>(row.size());
  return inRange ? row[tokenId] : -1.0e30f;
}
|
|
434
|
+
|
|
435
|
+
/// Viterbi (max-score) CTC alignment of `expandedTarget` against `logProbs`.
///
/// `logProbs[t][id]` is the log-probability of vocabulary id `id` at frame t.
/// `expandedTarget` is a blank-interleaved label sequence. Returns, for every
/// frame, the index of the expanded-target state on the best path, or an empty
/// vector when either input is empty.
///
/// Transitions into state s come from s (stay), s-1 (advance), or s-2 (skip).
/// A skip is only allowed into a non-blank label that differs from the label
/// two states back. Scores of -1e30 mark unreachable cells. On ties,
/// backtracking prefers stay, then advance, then skip.
static std::vector<int32_t> CtcBacktrack(
    const std::vector<std::vector<float>> &logProbs,
    const std::vector<int32_t> &expandedTarget,
    int32_t blankId) {
  const int32_t numFrames = static_cast<int32_t>(logProbs.size());
  const int32_t numStates = static_cast<int32_t>(expandedTarget.size());
  if (numFrames <= 0 || numStates <= 0) {
    return {};
  }

  const float kUnreachable = -1.0e30f;

  auto canSkipInto = [&](int32_t s) {
    return s > 1 && expandedTarget[s] != blankId && expandedTarget[s] != expandedTarget[s - 2];
  };

  // Forward pass: score[t][s] = best log-score of any path ending in s at t.
  std::vector<std::vector<float>> score(numFrames, std::vector<float>(numStates, kUnreachable));
  score[0][0] = SafeLogProb(logProbs[0], expandedTarget[0]);
  if (numStates > 1) {
    score[0][1] = SafeLogProb(logProbs[0], expandedTarget[1]);
  }

  for (int32_t t = 1; t < numFrames; ++t) {
    const std::vector<float> &prev = score[t - 1];
    for (int32_t s = 0; s < numStates; ++s) {
      float incoming = prev[s];
      if (s > 0) {
        incoming = std::max(incoming, prev[s - 1]);
      }
      if (canSkipInto(s)) {
        incoming = std::max(incoming, prev[s - 2]);
      }
      score[t][s] = (incoming <= kUnreachable / 2)
                        ? kUnreachable
                        : incoming + SafeLogProb(logProbs[t], expandedTarget[s]);
    }
  }

  // A valid path ends on the final blank or the final label.
  const std::vector<float> &last = score[numFrames - 1];
  int32_t state = numStates - 1;
  if (numStates > 1 && last[numStates - 2] > last[numStates - 1]) {
    state = numStates - 2;
  }

  // Backward pass: pick the best predecessor frame by frame.
  std::vector<int32_t> path(numFrames, 0);
  path[numFrames - 1] = state;
  for (int32_t t = numFrames - 1; t > 0; --t) {
    const std::vector<float> &prev = score[t - 1];
    int32_t chosen = state;
    float chosenScore = prev[state];
    if (state > 0 && prev[state - 1] > chosenScore) {
      chosen = state - 1;
      chosenScore = prev[state - 1];
    }
    if (canSkipInto(state) && prev[state - 2] > chosenScore) {
      chosen = state - 2;
    }
    state = chosen;
    path[t - 1] = state;
  }

  return path;
}
|
|
498
|
+
|
|
499
|
+
/// Converts alignment items into an array of {text, start, end} dictionaries
/// for the JS bridge. Text that is not valid UTF-8 is reported as "".
static NSArray *AlignmentItemsToNSArray(const std::vector<AlignmentItem> &items) {
  NSMutableArray *result = [NSMutableArray arrayWithCapacity:items.size()];
  for (size_t i = 0; i < items.size(); ++i) {
    const AlignmentItem &item = items[i];
    NSString *text = [NSString stringWithUTF8String:item.text.c_str()];
    [result addObject:@{
      @"text": text ?: @"",
      @"start": @(item.start),
      @"end": @(item.end),
    }];
  }
  return result;
}
|
|
510
|
+
|
|
511
|
+
} // namespace
|
|
512
|
+
|
|
513
|
+
namespace {

/// Sample rate (Hz) fed to the alignment model; input audio is resampled to it.
constexpr int32_t kAlignmentSampleRate = 16000;

/// Seconds covered by one model output frame (320 samples at 16 kHz).
/// NOTE(review): this assumes the wav2vec2 feature stride, the only model kind
/// this file names; confirm before supporting other model families.
constexpr double kAlignmentFrameSeconds = 0.02;

/// UTF-8 copy of `value`. A nil or non-UTF-8-convertible string yields "",
/// instead of constructing std::string from a NULL pointer (undefined behavior).
static std::string AlignmentStdString(NSString *value) {
  const char *utf8 = [value UTF8String];
  return utf8 != nullptr ? std::string(utf8) : std::string();
}

/// Converts a CTC state path into per-character timings.
///
/// `path[t]` is the expanded-target state chosen at frame t. Each non-"|"
/// token spans [first frame, last frame + 1) * kAlignmentFrameSeconds over the
/// frames where one of its non-blank states was chosen. A token with no frames
/// gets a zero-length item at the end of the latest timed token.
static std::vector<AlignmentItem> BuildCharItems(
    const std::vector<std::string> &tokenTexts,
    const std::vector<int32_t> &path,
    const ExpandedTarget &expanded,
    int32_t blankId) {
  // Frames are visited in increasing order, so the first and last hits give
  // the span of each token.
  std::vector<int32_t> firstFrame(tokenTexts.size(), -1);
  std::vector<int32_t> lastFrame(tokenTexts.size(), -1);
  for (int32_t t = 0; t < static_cast<int32_t>(path.size()); ++t) {
    const int32_t state = path[t];
    if (state < 0 || state >= static_cast<int32_t>(expanded.tokenIndices.size())) {
      continue;
    }
    const int32_t tokenIndex = expanded.tokenIndices[state];
    if (tokenIndex >= 0 && tokenIndex < static_cast<int32_t>(firstFrame.size()) &&
        expanded.ids[state] != blankId) {
      if (firstFrame[tokenIndex] < 0) {
        firstFrame[tokenIndex] = t;
      }
      lastFrame[tokenIndex] = t;
    }
  }

  std::vector<AlignmentItem> charItems;
  charItems.reserve(tokenTexts.size());
  int32_t fallbackEndFrame = 0;
  for (size_t i = 0; i < tokenTexts.size(); ++i) {
    if (tokenTexts[i] == "|") {
      continue;
    }
    int32_t startFrame = fallbackEndFrame;
    int32_t endFrameExclusive = fallbackEndFrame;
    if (firstFrame[i] >= 0) {
      startFrame = firstFrame[i];
      endFrameExclusive = lastFrame[i] + 1;
      fallbackEndFrame = std::max(fallbackEndFrame, endFrameExclusive);
    }
    const double start = startFrame * kAlignmentFrameSeconds;
    const double end = std::max(start, endFrameExclusive * kAlignmentFrameSeconds);
    charItems.push_back(AlignmentItem{tokenTexts[i], start, end});
  }
  return charItems;
}

/// Concatenates character items into words, splitting at "|" tokens.
/// `charItems` must list the non-"|" tokens of `tokenTexts` in order. A word
/// starts at its first character's start and ends at the latest character end.
static std::vector<AlignmentItem> GroupCharItemsIntoWords(
    const std::vector<std::string> &tokenTexts,
    const std::vector<AlignmentItem> &charItems) {
  std::vector<AlignmentItem> words;
  AlignmentItem current;
  size_t charCursor = 0;

  for (const auto &token : tokenTexts) {
    if (token == "|") {
      if (!current.text.empty()) {
        words.push_back(current);
        current.text.clear();
      }
      continue;
    }
    if (charCursor >= charItems.size()) {
      continue;
    }

    const AlignmentItem &ch = charItems[charCursor++];
    if (current.text.empty()) {
      current.start = ch.start;
      current.end = ch.end;
    } else {
      current.end = std::max(current.end, ch.end);
    }
    current.text += ch.text;
  }

  if (!current.text.empty()) {
    words.push_back(current);
  }
  return words;
}

}  // namespace

@implementation SherpaOnnx (Alignment)

/// Detects an alignment model in `modelDir` and resolves the result dictionary.
/// An empty or nil `modelType` is treated as "auto". Both C++ exceptions and
/// NSExceptions are reported as DETECT_ERROR; previously a C++ exception
/// escaped the method and terminated the app.
- (void)detectAlignmentModel:(NSString *)modelDir
                   modelType:(NSString *)modelType
                     resolve:(RCTPromiseResolveBlock)resolve
                      reject:(RCTPromiseRejectBlock)reject
{
  @try {
    try {
      const std::string modelDirStr = AlignmentStdString(modelDir);
      std::string modelTypeStr = AlignmentStdString(modelType);
      if (modelTypeStr.empty()) {
        modelTypeStr = "auto";
      }
      auto result = sherpaonnx::DetectAlignmentModel(modelDirStr, modelTypeStr);
      resolve(alignmentDetectResultToDict(result));
    } catch (const std::exception &e) {
      NSString *reason = [NSString stringWithUTF8String:e.what()] ?: @"unknown error";
      reject(@"DETECT_ERROR",
             [NSString stringWithFormat:@"Alignment detect failed: %@", reason],
             nil);
    }
  } @catch (NSException *exception) {
    reject(@"DETECT_ERROR",
           [NSString stringWithFormat:@"Alignment detect failed: %@",
                                      exception.reason],
           nil);
  }
}

/// Runs CTC forced alignment of `text` against the WAV file at `audioPath`.
///
/// Processing steps:
///   1. The transcript is tokenized per character using the vocabulary in
///      `vocabJson`. "<pad>" is the blank (default id 0) and "|" the word
///      boundary (default id 4).
///   2. The audio is resampled to 16 kHz and normalized.
///   3. The model runs through ONNX Runtime.
///   4. Frame timings come from a Viterbi path.
///
/// Resolves {words, chars}, each an array of {text, start, end} in seconds.
/// Work runs on a global dispatch queue; errors reject with ALIGNMENT_ERROR.
- (void)runCTCForcedAlignment:(NSString *)modelPath
                    audioPath:(NSString *)audioPath
                         text:(NSString *)text
                    vocabJson:(NSString *)vocabJson
                      resolve:(RCTPromiseResolveBlock)resolve
                       reject:(RCTPromiseRejectBlock)reject
{
  // Messaging nil returns 0, so `.length == 0` also covers nil arguments.
  if (modelPath.length == 0) {
    reject(@"ALIGNMENT_ERROR", @"modelPath is required", nil);
    return;
  }
  if (audioPath.length == 0) {
    reject(@"ALIGNMENT_ERROR", @"audioPath is required", nil);
    return;
  }
  if (text.length == 0) {
    reject(@"ALIGNMENT_ERROR", @"text is required", nil);
    return;
  }

  dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
    try {
      const std::string modelPathStr = AlignmentStdString(modelPath);
      const std::string audioPathStr = AlignmentStdString(audioPath);
      const std::string textStr = AlignmentStdString(text);

      const auto vocab = ParseVocabJson(vocabJson);
      int32_t blankId = 0;
      auto blankIt = vocab.find("<pad>");
      if (blankIt != vocab.end()) {
        blankId = blankIt->second;
      }
      int32_t wordBoundaryId = 4;
      auto boundaryIt = vocab.find("|");
      if (boundaryIt != vocab.end()) {
        wordBoundaryId = boundaryIt->second;
      }

      const std::vector<std::string> tokenTexts = BuildTokenTexts(textStr, vocab, wordBoundaryId);
      if (tokenTexts.empty()) {
        reject(@"ALIGNMENT_ERROR", @"Transcript has no alignable tokens for provided vocabulary", nil);
        return;
      }

      std::vector<int32_t> tokenIds;
      tokenIds.reserve(tokenTexts.size());
      for (const auto &token : tokenTexts) {
        auto it = vocab.find(token);
        tokenIds.push_back(it != vocab.end() ? it->second : blankId);
      }

      sherpa_onnx::cxx::Wave wave = sherpa_onnx::cxx::ReadWave(audioPathStr);
      if (wave.samples.empty() || wave.sample_rate <= 0) {
        reject(@"ALIGNMENT_ERROR", @"Failed to read WAV audio for alignment", nil);
        return;
      }

      const std::vector<float> resampled =
          wave.sample_rate == kAlignmentSampleRate
              ? wave.samples
              : ResampleLinear(wave.samples, wave.sample_rate, kAlignmentSampleRate);
      const std::vector<float> normalized = NormalizeAudio(resampled);
      if (normalized.empty()) {
        reject(@"ALIGNMENT_ERROR", @"Audio is empty after preprocessing", nil);
        return;
      }

      std::vector<std::vector<float>> logProbs;
#if SHERPA_ONNX_HAS_ORT_C_API
      logProbs = RunOrtInference(modelPathStr, normalized);
#else
      reject(@"ALIGNMENT_ERROR",
             @"Accurate alignment requires ONNX Runtime, which is not available in this build.",
             nil);
      return;
#endif

      if (logProbs.empty()) {
        reject(@"ALIGNMENT_ERROR", @"Alignment model produced empty probabilities", nil);
        return;
      }

      const ExpandedTarget expanded = BuildExpandedTarget(tokenIds, blankId);
      const std::vector<int32_t> path = CtcBacktrack(logProbs, expanded.ids, blankId);

      const std::vector<AlignmentItem> charItems = BuildCharItems(tokenTexts, path, expanded, blankId);
      const std::vector<AlignmentItem> wordItems = GroupCharItemsIntoWords(tokenTexts, charItems);

      resolve(@{
        @"words": AlignmentItemsToNSArray(wordItems),
        @"chars": AlignmentItemsToNSArray(charItems),
      });
    } catch (const std::exception &e) {
      NSString *errorMsg = [NSString stringWithUTF8String:e.what()] ?: @"CTC alignment failed";
      reject(@"ALIGNMENT_ERROR", errorMsg, nil);
    } catch (...) {
      // In Objective-C++ this also catches NSExceptions raised by Foundation.
      reject(@"ALIGNMENT_ERROR", @"CTC alignment failed", nil);
    }
  });
}

@end
|