react-native-executorch 0.5.1 → 0.5.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/lib/Error.d.ts +30 -0
- package/lib/Error.js +50 -0
- package/lib/common/Logger.d.ts +8 -0
- package/lib/common/Logger.js +19 -0
- package/lib/constants/directories.d.ts +1 -0
- package/lib/constants/directories.js +2 -0
- package/lib/constants/llmDefaults.d.ts +6 -0
- package/lib/constants/llmDefaults.js +16 -0
- package/lib/constants/modelUrls.d.ts +223 -0
- package/lib/constants/modelUrls.js +322 -0
- package/lib/constants/ocr/models.d.ts +882 -0
- package/lib/constants/ocr/models.js +182 -0
- package/lib/constants/ocr/symbols.d.ts +75 -0
- package/lib/constants/ocr/symbols.js +139 -0
- package/lib/constants/sttDefaults.d.ts +28 -0
- package/lib/constants/sttDefaults.js +68 -0
- package/lib/controllers/LLMController.d.ts +47 -0
- package/lib/controllers/LLMController.js +213 -0
- package/lib/controllers/OCRController.d.ts +23 -0
- package/lib/controllers/OCRController.js +72 -0
- package/lib/controllers/SpeechToTextController.d.ts +56 -0
- package/lib/controllers/SpeechToTextController.js +349 -0
- package/lib/controllers/VerticalOCRController.d.ts +25 -0
- package/lib/controllers/VerticalOCRController.js +75 -0
- package/lib/hooks/computer_vision/useClassification.d.ts +15 -0
- package/lib/hooks/computer_vision/useClassification.js +7 -0
- package/lib/hooks/computer_vision/useImageEmbeddings.d.ts +15 -0
- package/lib/hooks/computer_vision/useImageEmbeddings.js +7 -0
- package/lib/hooks/computer_vision/useImageSegmentation.d.ts +38 -0
- package/lib/hooks/computer_vision/useImageSegmentation.js +7 -0
- package/lib/hooks/computer_vision/useOCR.d.ts +20 -0
- package/lib/hooks/computer_vision/useOCR.js +42 -0
- package/lib/hooks/computer_vision/useObjectDetection.d.ts +15 -0
- package/lib/hooks/computer_vision/useObjectDetection.js +7 -0
- package/lib/hooks/computer_vision/useStyleTransfer.d.ts +15 -0
- package/lib/hooks/computer_vision/useStyleTransfer.js +7 -0
- package/lib/hooks/computer_vision/useVerticalOCR.d.ts +21 -0
- package/lib/hooks/computer_vision/useVerticalOCR.js +45 -0
- package/lib/hooks/general/useExecutorchModule.d.ts +13 -0
- package/lib/hooks/general/useExecutorchModule.js +7 -0
- package/lib/hooks/natural_language_processing/useLLM.d.ts +10 -0
- package/lib/hooks/natural_language_processing/useLLM.js +78 -0
- package/lib/hooks/natural_language_processing/useSpeechToText.d.ts +27 -0
- package/lib/hooks/natural_language_processing/useSpeechToText.js +49 -0
- package/lib/hooks/natural_language_processing/useTextEmbeddings.d.ts +16 -0
- package/lib/hooks/natural_language_processing/useTextEmbeddings.js +7 -0
- package/lib/hooks/natural_language_processing/useTokenizer.d.ts +17 -0
- package/lib/hooks/natural_language_processing/useTokenizer.js +52 -0
- package/lib/hooks/useModule.d.ts +17 -0
- package/lib/hooks/useModule.js +45 -0
- package/lib/hooks/useNonStaticModule.d.ts +20 -0
- package/lib/hooks/useNonStaticModule.js +49 -0
- package/lib/index.d.ts +50 -0
- package/lib/index.js +60 -0
- package/lib/module/utils/ResourceFetcher.js +6 -8
- package/lib/module/utils/ResourceFetcher.js.map +1 -1
- package/lib/module/utils/ResourceFetcherUtils.js +20 -20
- package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
- package/lib/modules/BaseModule.d.ts +8 -0
- package/lib/modules/BaseModule.js +25 -0
- package/lib/modules/BaseNonStaticModule.d.ts +9 -0
- package/lib/modules/BaseNonStaticModule.js +14 -0
- package/lib/modules/computer_vision/ClassificationModule.d.ts +8 -0
- package/lib/modules/computer_vision/ClassificationModule.js +17 -0
- package/lib/modules/computer_vision/ImageEmbeddingsModule.d.ts +8 -0
- package/lib/modules/computer_vision/ImageEmbeddingsModule.js +17 -0
- package/lib/modules/computer_vision/ImageSegmentationModule.d.ts +11 -0
- package/lib/modules/computer_vision/ImageSegmentationModule.js +27 -0
- package/lib/modules/computer_vision/OCRModule.d.ts +15 -0
- package/lib/modules/computer_vision/OCRModule.js +20 -0
- package/lib/modules/computer_vision/ObjectDetectionModule.d.ts +9 -0
- package/lib/modules/computer_vision/ObjectDetectionModule.js +17 -0
- package/lib/modules/computer_vision/StyleTransferModule.d.ts +8 -0
- package/lib/modules/computer_vision/StyleTransferModule.js +17 -0
- package/lib/modules/computer_vision/VerticalOCRModule.d.ts +15 -0
- package/lib/modules/computer_vision/VerticalOCRModule.js +22 -0
- package/lib/modules/general/ExecutorchModule.d.ts +7 -0
- package/lib/modules/general/ExecutorchModule.js +14 -0
- package/lib/modules/natural_language_processing/LLMModule.d.ts +28 -0
- package/lib/modules/natural_language_processing/LLMModule.js +45 -0
- package/lib/modules/natural_language_processing/SpeechToTextModule.d.ts +24 -0
- package/lib/modules/natural_language_processing/SpeechToTextModule.js +36 -0
- package/lib/modules/natural_language_processing/TextEmbeddingsModule.d.ts +9 -0
- package/lib/modules/natural_language_processing/TextEmbeddingsModule.js +21 -0
- package/lib/modules/natural_language_processing/TokenizerModule.d.ts +12 -0
- package/lib/modules/natural_language_processing/TokenizerModule.js +30 -0
- package/lib/native/NativeETInstaller.d.ts +6 -0
- package/lib/native/NativeETInstaller.js +2 -0
- package/lib/native/NativeOCR.d.ts +8 -0
- package/lib/native/NativeOCR.js +2 -0
- package/lib/native/NativeVerticalOCR.d.ts +8 -0
- package/lib/native/NativeVerticalOCR.js +2 -0
- package/lib/native/RnExecutorchModules.d.ts +3 -0
- package/lib/native/RnExecutorchModules.js +16 -0
- package/lib/types/common.d.ts +31 -0
- package/lib/types/common.js +25 -0
- package/lib/types/imageSegmentation.d.ts +24 -0
- package/lib/types/imageSegmentation.js +26 -0
- package/lib/types/llm.d.ts +46 -0
- package/lib/types/llm.js +9 -0
- package/lib/types/objectDetection.d.ts +104 -0
- package/lib/types/objectDetection.js +94 -0
- package/lib/types/ocr.d.ts +11 -0
- package/lib/types/ocr.js +1 -0
- package/lib/types/stt.d.ts +94 -0
- package/lib/types/stt.js +85 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
- package/lib/utils/ResourceFetcher.d.ts +24 -0
- package/lib/utils/ResourceFetcher.js +305 -0
- package/lib/utils/ResourceFetcherUtils.d.ts +54 -0
- package/lib/utils/ResourceFetcherUtils.js +128 -0
- package/lib/utils/llm.d.ts +6 -0
- package/lib/utils/llm.js +73 -0
- package/lib/utils/stt.d.ts +1 -0
- package/lib/utils/stt.js +21 -0
- package/package.json +1 -1
- package/src/utils/ResourceFetcher.ts +9 -7
- package/src/utils/ResourceFetcherUtils.ts +15 -17
- package/ios/RnExecutorch.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -7
- package/ios/RnExecutorch.xcodeproj/project.xcworkspace/xcuserdata/jakubchmura.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
- package/ios/RnExecutorch.xcodeproj/xcuserdata/jakubchmura.xcuserdatad/xcschemes/xcschememanagement.plist +0 -14
- package/lib/tsconfig.tsbuildinfo +0 -1
- package/third-party/ios/ExecutorchLib/ExecutorchLib.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -7
- package/third-party/ios/ExecutorchLib/ExecutorchLib.xcodeproj/project.xcworkspace/xcuserdata/jakubchmura.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
- package/third-party/ios/ExecutorchLib/ExecutorchLib.xcodeproj/xcuserdata/jakubchmura.xcuserdatad/xcschemes/xcschememanagement.plist +0 -14
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
import { HAMMING_DIST_THRESHOLD, MODEL_CONFIGS, SECOND, MODES, NUM_TOKENS_TO_TRIM, STREAMING_ACTION, } from '../constants/sttDefaults';
|
|
2
|
+
import { AvailableModels } from '../types/stt';
|
|
3
|
+
import { TokenizerModule } from '../modules/natural_language_processing/TokenizerModule';
|
|
4
|
+
import { ResourceFetcher } from '../utils/ResourceFetcher';
|
|
5
|
+
import { longCommonInfPref } from '../utils/stt';
|
|
6
|
+
import { ETError, getError } from '../Error';
|
|
7
|
+
import { Logger } from '../common/Logger';
|
|
8
|
+
export class SpeechToTextController {
|
|
9
|
+
speechToTextNativeModule;
|
|
10
|
+
sequence = [];
|
|
11
|
+
isReady = false;
|
|
12
|
+
isGenerating = false;
|
|
13
|
+
tokenizerModule;
|
|
14
|
+
overlapSeconds;
|
|
15
|
+
windowSize;
|
|
16
|
+
chunks = [];
|
|
17
|
+
seqs = [];
|
|
18
|
+
prevSeq = [];
|
|
19
|
+
waveform = [];
|
|
20
|
+
numOfChunks = 0;
|
|
21
|
+
streaming = false;
|
|
22
|
+
// User callbacks
|
|
23
|
+
decodedTranscribeCallback;
|
|
24
|
+
isReadyCallback;
|
|
25
|
+
isGeneratingCallback;
|
|
26
|
+
onErrorCallback;
|
|
27
|
+
config;
|
|
28
|
+
constructor({ transcribeCallback, isReadyCallback, isGeneratingCallback, onErrorCallback, overlapSeconds, windowSize, streamingConfig, }) {
|
|
29
|
+
this.tokenizerModule = new TokenizerModule();
|
|
30
|
+
this.decodedTranscribeCallback = async (seq) => transcribeCallback(await this.tokenIdsToText(seq));
|
|
31
|
+
this.isReadyCallback = (isReady) => {
|
|
32
|
+
this.isReady = isReady;
|
|
33
|
+
isReadyCallback?.(isReady);
|
|
34
|
+
};
|
|
35
|
+
this.isGeneratingCallback = (isGenerating) => {
|
|
36
|
+
this.isGenerating = isGenerating;
|
|
37
|
+
isGeneratingCallback?.(isGenerating);
|
|
38
|
+
};
|
|
39
|
+
this.onErrorCallback = (error) => {
|
|
40
|
+
if (onErrorCallback) {
|
|
41
|
+
onErrorCallback(error ? new Error(getError(error)) : undefined);
|
|
42
|
+
return;
|
|
43
|
+
}
|
|
44
|
+
else {
|
|
45
|
+
throw new Error(getError(error));
|
|
46
|
+
}
|
|
47
|
+
};
|
|
48
|
+
this.configureStreaming(overlapSeconds, windowSize, streamingConfig || 'balanced');
|
|
49
|
+
}
|
|
50
|
+
async load({ modelName, encoderSource, decoderSource, tokenizerSource, onDownloadProgressCallback, }) {
|
|
51
|
+
this.onErrorCallback(undefined);
|
|
52
|
+
this.isReadyCallback(false);
|
|
53
|
+
this.config = MODEL_CONFIGS[modelName];
|
|
54
|
+
try {
|
|
55
|
+
const tokenizerLoadPromise = this.tokenizerModule.load({
|
|
56
|
+
tokenizerSource: tokenizerSource || this.config.tokenizer.source,
|
|
57
|
+
});
|
|
58
|
+
const pathsPromise = ResourceFetcher.fetch(onDownloadProgressCallback, encoderSource || this.config.sources.encoder, decoderSource || this.config.sources.decoder);
|
|
59
|
+
const [_, encoderDecoderResults] = await Promise.all([
|
|
60
|
+
tokenizerLoadPromise,
|
|
61
|
+
pathsPromise,
|
|
62
|
+
]);
|
|
63
|
+
encoderSource = encoderDecoderResults?.[0];
|
|
64
|
+
decoderSource = encoderDecoderResults?.[1];
|
|
65
|
+
if (!encoderSource || !decoderSource) {
|
|
66
|
+
throw new Error('Download interrupted.');
|
|
67
|
+
}
|
|
68
|
+
}
|
|
69
|
+
catch (e) {
|
|
70
|
+
this.onErrorCallback(e);
|
|
71
|
+
return;
|
|
72
|
+
}
|
|
73
|
+
if (modelName === 'whisperMultilingual') {
|
|
74
|
+
// The underlying native class is instantiated based on the name of the model. There is no need to
|
|
75
|
+
// create a separate class for multilingual version of Whisper, since it is the same. We just need
|
|
76
|
+
// the distinction here, in TS, for start tokens and such. If we introduce
|
|
77
|
+
// more versions of Whisper, such as the small one, this should be refactored.
|
|
78
|
+
modelName = AvailableModels.WHISPER;
|
|
79
|
+
}
|
|
80
|
+
try {
|
|
81
|
+
const nativeSpeechToText = await global.loadSpeechToText(encoderSource, decoderSource, modelName);
|
|
82
|
+
this.speechToTextNativeModule = nativeSpeechToText;
|
|
83
|
+
this.isReadyCallback(true);
|
|
84
|
+
}
|
|
85
|
+
catch (e) {
|
|
86
|
+
this.onErrorCallback(e);
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
configureStreaming(overlapSeconds, windowSize, streamingConfig) {
|
|
90
|
+
if (streamingConfig) {
|
|
91
|
+
this.windowSize = MODES[streamingConfig].windowSize * SECOND;
|
|
92
|
+
this.overlapSeconds = MODES[streamingConfig].overlapSeconds * SECOND;
|
|
93
|
+
}
|
|
94
|
+
if (streamingConfig && (windowSize || overlapSeconds)) {
|
|
95
|
+
Logger.warn(`windowSize and overlapSeconds overrides values from streamingConfig ${streamingConfig}.`);
|
|
96
|
+
}
|
|
97
|
+
this.windowSize = (windowSize || 0) * SECOND || this.windowSize;
|
|
98
|
+
this.overlapSeconds = (overlapSeconds || 0) * SECOND || this.overlapSeconds;
|
|
99
|
+
if (2 * this.overlapSeconds + this.windowSize >= 30 * SECOND) {
|
|
100
|
+
Logger.warn(`Invalid values for overlapSeconds and/or windowSize provided. Expected windowSize + 2 * overlapSeconds (== ${this.windowSize + 2 * this.overlapSeconds}) <= 30. Setting windowSize to ${30 * SECOND - 2 * this.overlapSeconds}.`);
|
|
101
|
+
this.windowSize = 30 * SECOND - 2 * this.overlapSeconds;
|
|
102
|
+
}
|
|
103
|
+
}
|
|
104
|
+
chunkWaveform() {
|
|
105
|
+
this.numOfChunks = Math.ceil(this.waveform.length / this.windowSize);
|
|
106
|
+
for (let i = 0; i < this.numOfChunks; i++) {
|
|
107
|
+
let chunk = [];
|
|
108
|
+
const left = Math.max(this.windowSize * i - this.overlapSeconds, 0);
|
|
109
|
+
const right = Math.min(this.windowSize * (i + 1) + this.overlapSeconds, this.waveform.length);
|
|
110
|
+
chunk = this.waveform.slice(left, right);
|
|
111
|
+
this.chunks.push(chunk);
|
|
112
|
+
}
|
|
113
|
+
}
|
|
114
|
+
resetState() {
|
|
115
|
+
this.sequence = [];
|
|
116
|
+
this.seqs = [];
|
|
117
|
+
this.waveform = [];
|
|
118
|
+
this.prevSeq = [];
|
|
119
|
+
this.chunks = [];
|
|
120
|
+
this.decodedTranscribeCallback([]);
|
|
121
|
+
this.onErrorCallback(undefined);
|
|
122
|
+
}
|
|
123
|
+
expectedChunkLength() {
|
|
124
|
+
//only first chunk can be of shorter length, for first chunk there are no seqs decoded
|
|
125
|
+
return this.seqs.length
|
|
126
|
+
? this.windowSize + 2 * this.overlapSeconds
|
|
127
|
+
: this.windowSize + this.overlapSeconds;
|
|
128
|
+
}
|
|
129
|
+
async getStartingTokenIds(audioLanguage) {
|
|
130
|
+
// We need different starting token ids based on the multilingualism of the model.
|
|
131
|
+
// The eng version only needs BOS token, while the multilingual one needs:
|
|
132
|
+
// [BOS, LANG, TRANSCRIBE]. Optionally we should also set notimestamps token, as timestamps
|
|
133
|
+
// is not yet supported.
|
|
134
|
+
if (!audioLanguage) {
|
|
135
|
+
return [this.config.tokenizer.bos];
|
|
136
|
+
}
|
|
137
|
+
// FIXME: I should use .getTokenId for the BOS as well, should remove it from config
|
|
138
|
+
const langTokenId = await this.tokenizerModule.tokenToId(`<|${audioLanguage}|>`);
|
|
139
|
+
const transcribeTokenId = await this.tokenizerModule.tokenToId('<|transcribe|>');
|
|
140
|
+
const noTimestampsTokenId = await this.tokenizerModule.tokenToId('<|notimestamps|>');
|
|
141
|
+
const startingTokenIds = [
|
|
142
|
+
this.config.tokenizer.bos,
|
|
143
|
+
langTokenId,
|
|
144
|
+
transcribeTokenId,
|
|
145
|
+
noTimestampsTokenId,
|
|
146
|
+
];
|
|
147
|
+
return startingTokenIds;
|
|
148
|
+
}
|
|
149
|
+
async decodeChunk(chunk, audioLanguage) {
|
|
150
|
+
const seq = await this.getStartingTokenIds(audioLanguage);
|
|
151
|
+
let prevSeqTokenIdx = 0;
|
|
152
|
+
this.prevSeq = this.sequence.slice();
|
|
153
|
+
try {
|
|
154
|
+
await this.encode(new Float32Array(chunk));
|
|
155
|
+
}
|
|
156
|
+
catch (error) {
|
|
157
|
+
this.onErrorCallback(new Error(getError(error) + ' encoding error'));
|
|
158
|
+
return [];
|
|
159
|
+
}
|
|
160
|
+
let lastToken = seq.at(-1);
|
|
161
|
+
while (lastToken !== this.config.tokenizer.eos) {
|
|
162
|
+
try {
|
|
163
|
+
lastToken = await this.decode(seq);
|
|
164
|
+
}
|
|
165
|
+
catch (error) {
|
|
166
|
+
this.onErrorCallback(new Error(getError(error) + ' decoding error'));
|
|
167
|
+
return [...seq, this.config.tokenizer.eos];
|
|
168
|
+
}
|
|
169
|
+
seq.push(lastToken);
|
|
170
|
+
if (this.seqs.length > 0 &&
|
|
171
|
+
seq.length < this.seqs.at(-1).length &&
|
|
172
|
+
seq.length % 3 !== 0) {
|
|
173
|
+
this.prevSeq.push(this.seqs.at(-1)[prevSeqTokenIdx++]);
|
|
174
|
+
this.decodedTranscribeCallback(this.prevSeq);
|
|
175
|
+
}
|
|
176
|
+
}
|
|
177
|
+
return seq;
|
|
178
|
+
}
|
|
179
|
+
async handleOverlaps(seqs) {
|
|
180
|
+
const maxInd = longCommonInfPref(seqs.at(-2), seqs.at(-1), HAMMING_DIST_THRESHOLD);
|
|
181
|
+
this.sequence = [...this.sequence, ...seqs.at(-2).slice(0, maxInd)];
|
|
182
|
+
this.decodedTranscribeCallback(this.sequence);
|
|
183
|
+
return this.sequence.slice();
|
|
184
|
+
}
|
|
185
|
+
trimLeft(numOfTokensToTrim) {
|
|
186
|
+
const idx = this.seqs.length - 1;
|
|
187
|
+
if (this.seqs[idx][0] === this.config.tokenizer.bos) {
|
|
188
|
+
this.seqs[idx] = this.seqs[idx].slice(numOfTokensToTrim);
|
|
189
|
+
}
|
|
190
|
+
}
|
|
191
|
+
trimRight(numOfTokensToTrim) {
|
|
192
|
+
const idx = this.seqs.length - 2;
|
|
193
|
+
if (this.seqs[idx].at(-1) === this.config.tokenizer.eos) {
|
|
194
|
+
this.seqs[idx] = this.seqs[idx].slice(0, -numOfTokensToTrim);
|
|
195
|
+
}
|
|
196
|
+
}
|
|
197
|
+
// since we are calling this every time (except first) after a new seq is pushed to this.seqs
|
|
198
|
+
// we can only trim left the last seq and trim right the second to last seq
|
|
199
|
+
async trimSequences(audioLanguage) {
|
|
200
|
+
const numSpecialTokens = (await this.getStartingTokenIds(audioLanguage))
|
|
201
|
+
.length;
|
|
202
|
+
this.trimLeft(numSpecialTokens + NUM_TOKENS_TO_TRIM);
|
|
203
|
+
this.trimRight(numSpecialTokens + NUM_TOKENS_TO_TRIM);
|
|
204
|
+
}
|
|
205
|
+
// if last chunk is too short combine it with second to last to improve quality
|
|
206
|
+
validateAndFixLastChunk() {
|
|
207
|
+
if (this.chunks.length < 2)
|
|
208
|
+
return;
|
|
209
|
+
const lastChunkLength = this.chunks.at(-1).length / SECOND;
|
|
210
|
+
const secondToLastChunkLength = this.chunks.at(-2).length / SECOND;
|
|
211
|
+
if (lastChunkLength < 5 && secondToLastChunkLength + lastChunkLength < 30) {
|
|
212
|
+
this.chunks[this.chunks.length - 2] = [
|
|
213
|
+
...this.chunks.at(-2).slice(0, -this.overlapSeconds * 2),
|
|
214
|
+
...this.chunks.at(-1),
|
|
215
|
+
];
|
|
216
|
+
this.chunks = this.chunks.slice(0, -1);
|
|
217
|
+
}
|
|
218
|
+
}
|
|
219
|
+
async tokenIdsToText(tokenIds) {
|
|
220
|
+
try {
|
|
221
|
+
return await this.tokenizerModule.decode(tokenIds, true);
|
|
222
|
+
}
|
|
223
|
+
catch (e) {
|
|
224
|
+
this.onErrorCallback(new Error(`An error has occurred when decoding the token ids: ${e}`));
|
|
225
|
+
return '';
|
|
226
|
+
}
|
|
227
|
+
}
|
|
228
|
+
async transcribe(waveform, audioLanguage) {
|
|
229
|
+
try {
|
|
230
|
+
if (!this.isReady)
|
|
231
|
+
throw Error(getError(ETError.ModuleNotLoaded));
|
|
232
|
+
if (this.isGenerating || this.streaming)
|
|
233
|
+
throw Error(getError(ETError.ModelGenerating));
|
|
234
|
+
if (!!audioLanguage !== this.config.isMultilingual)
|
|
235
|
+
throw new Error(getError(ETError.MultilingualConfiguration));
|
|
236
|
+
}
|
|
237
|
+
catch (e) {
|
|
238
|
+
this.onErrorCallback(e);
|
|
239
|
+
return '';
|
|
240
|
+
}
|
|
241
|
+
// Making sure that the error is not set when we get there
|
|
242
|
+
this.isGeneratingCallback(true);
|
|
243
|
+
this.resetState();
|
|
244
|
+
this.waveform = waveform;
|
|
245
|
+
this.chunkWaveform();
|
|
246
|
+
this.validateAndFixLastChunk();
|
|
247
|
+
for (let chunkId = 0; chunkId < this.chunks.length; chunkId++) {
|
|
248
|
+
const seq = await this.decodeChunk(this.chunks.at(chunkId), audioLanguage);
|
|
249
|
+
// whole audio is inside one chunk, no processing required
|
|
250
|
+
if (this.chunks.length === 1) {
|
|
251
|
+
this.sequence = seq;
|
|
252
|
+
this.decodedTranscribeCallback(seq);
|
|
253
|
+
break;
|
|
254
|
+
}
|
|
255
|
+
this.seqs.push(seq);
|
|
256
|
+
if (this.seqs.length < 2)
|
|
257
|
+
continue;
|
|
258
|
+
// Remove starting tokenIds and some additional ones
|
|
259
|
+
await this.trimSequences(audioLanguage);
|
|
260
|
+
this.prevSeq = await this.handleOverlaps(this.seqs);
|
|
261
|
+
// last sequence processed
|
|
262
|
+
// overlaps are already handled, so just append the last seq
|
|
263
|
+
if (this.seqs.length === this.chunks.length) {
|
|
264
|
+
this.sequence = [...this.sequence, ...this.seqs.at(-1)];
|
|
265
|
+
this.decodedTranscribeCallback(this.sequence);
|
|
266
|
+
this.prevSeq = this.sequence;
|
|
267
|
+
}
|
|
268
|
+
}
|
|
269
|
+
const decodedText = await this.tokenIdsToText(this.sequence);
|
|
270
|
+
this.isGeneratingCallback(false);
|
|
271
|
+
return decodedText;
|
|
272
|
+
}
|
|
273
|
+
async streamingTranscribe(streamAction, waveform, audioLanguage) {
|
|
274
|
+
try {
|
|
275
|
+
if (!this.isReady)
|
|
276
|
+
throw Error(getError(ETError.ModuleNotLoaded));
|
|
277
|
+
if (!!audioLanguage !== this.config.isMultilingual)
|
|
278
|
+
throw new Error(getError(ETError.MultilingualConfiguration));
|
|
279
|
+
if (streamAction === STREAMING_ACTION.START &&
|
|
280
|
+
!this.streaming &&
|
|
281
|
+
this.isGenerating)
|
|
282
|
+
throw Error(getError(ETError.ModelGenerating));
|
|
283
|
+
if (streamAction === STREAMING_ACTION.START && this.streaming)
|
|
284
|
+
throw Error(getError(ETError.ModelGenerating));
|
|
285
|
+
if (streamAction === STREAMING_ACTION.DATA && !this.streaming)
|
|
286
|
+
throw Error(getError(ETError.StreamingNotStarted));
|
|
287
|
+
if (streamAction === STREAMING_ACTION.STOP && !this.streaming)
|
|
288
|
+
throw Error(getError(ETError.StreamingNotStarted));
|
|
289
|
+
if (streamAction === STREAMING_ACTION.DATA && !waveform)
|
|
290
|
+
throw new Error(getError(ETError.MissingDataChunk));
|
|
291
|
+
}
|
|
292
|
+
catch (e) {
|
|
293
|
+
this.onErrorCallback(e);
|
|
294
|
+
return '';
|
|
295
|
+
}
|
|
296
|
+
if (streamAction === STREAMING_ACTION.START) {
|
|
297
|
+
this.resetState();
|
|
298
|
+
this.streaming = true;
|
|
299
|
+
this.isGeneratingCallback(true);
|
|
300
|
+
}
|
|
301
|
+
this.waveform = [...this.waveform, ...(waveform || [])];
|
|
302
|
+
// while buffer has at least required size get chunk and decode
|
|
303
|
+
while (this.waveform.length >= this.expectedChunkLength()) {
|
|
304
|
+
const chunk = this.waveform.slice(0, this.windowSize +
|
|
305
|
+
this.overlapSeconds * (1 + Number(this.seqs.length > 0)));
|
|
306
|
+
this.chunks = [chunk]; //save last chunk for STREAMING_ACTION.STOP
|
|
307
|
+
this.waveform = this.waveform.slice(this.windowSize - this.overlapSeconds * Number(this.seqs.length === 0));
|
|
308
|
+
const seq = await this.decodeChunk(chunk, audioLanguage);
|
|
309
|
+
this.seqs.push(seq);
|
|
310
|
+
if (this.seqs.length < 2)
|
|
311
|
+
continue;
|
|
312
|
+
await this.trimSequences(audioLanguage);
|
|
313
|
+
await this.handleOverlaps(this.seqs);
|
|
314
|
+
}
|
|
315
|
+
// got final package, process all remaining waveform data
|
|
316
|
+
// since we run the loop above the waveform has at most one chunk in it
|
|
317
|
+
if (streamAction === STREAMING_ACTION.STOP) {
|
|
318
|
+
// pad remaining waveform data with previous chunk to this.windowSize + 2 * this.overlapSeconds
|
|
319
|
+
const chunk = this.chunks.length
|
|
320
|
+
? [
|
|
321
|
+
...this.chunks[0].slice(0, this.windowSize),
|
|
322
|
+
...this.waveform,
|
|
323
|
+
].slice(-this.windowSize - 2 * this.overlapSeconds)
|
|
324
|
+
: this.waveform;
|
|
325
|
+
this.waveform = [];
|
|
326
|
+
const seq = await this.decodeChunk(chunk, audioLanguage);
|
|
327
|
+
this.seqs.push(seq);
|
|
328
|
+
if (this.seqs.length === 1) {
|
|
329
|
+
this.sequence = this.seqs[0];
|
|
330
|
+
}
|
|
331
|
+
else {
|
|
332
|
+
await this.trimSequences(audioLanguage);
|
|
333
|
+
await this.handleOverlaps(this.seqs);
|
|
334
|
+
this.sequence = [...this.sequence, ...this.seqs.at(-1)];
|
|
335
|
+
}
|
|
336
|
+
this.decodedTranscribeCallback(this.sequence);
|
|
337
|
+
this.isGeneratingCallback(false);
|
|
338
|
+
this.streaming = false;
|
|
339
|
+
}
|
|
340
|
+
const decodedText = await this.tokenIdsToText(this.sequence);
|
|
341
|
+
return decodedText;
|
|
342
|
+
}
|
|
343
|
+
async encode(waveform) {
|
|
344
|
+
return await this.speechToTextNativeModule.encode(waveform);
|
|
345
|
+
}
|
|
346
|
+
async decode(seq) {
|
|
347
|
+
return await this.speechToTextNativeModule.decode(seq);
|
|
348
|
+
}
|
|
349
|
+
}
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import { ResourceSource } from '../types/common';
|
|
2
|
+
import { OCRLanguage } from '../types/ocr';
|
|
3
|
+
/**
 * Type declaration of the vertical-OCR controller: tracks load/generation
 * state and exposes `load`, `forward` and `delete`.
 */
export declare class VerticalOCRController {
    private ocrNativeModule;
    /** Whether a model is currently loaded. */
    isReady: boolean;
    /** Whether a `forward` call is in flight. */
    isGenerating: boolean;
    error: string | null;
    private isReadyCallback;
    private isGeneratingCallback;
    private errorCallback;
    /** All callbacks are optional; the options object itself may be omitted. */
    constructor(
        {
            isReadyCallback,
            isGeneratingCallback,
            errorCallback,
        }?: {
            isReadyCallback?: ((_isReady: boolean) => void) | undefined;
            isGeneratingCallback?: ((_isGenerating: boolean) => void) | undefined;
            errorCallback?: ((_error: string) => void) | undefined;
        }
    );
    /** Fetches the detector and recognizer resources and loads the model. */
    load: (
        detectorSources: {
            detectorLarge: ResourceSource;
            detectorNarrow: ResourceSource;
        },
        recognizerSources: {
            recognizerLarge: ResourceSource;
            recognizerSmall: ResourceSource;
        },
        language: OCRLanguage,
        independentCharacters: boolean,
        onDownloadProgressCallback: (downloadProgress: number) => void
    ) => Promise<void>;
    /** Runs recognition on `input`. */
    forward: (input: string) => Promise<any>;
    /** Releases the loaded model. */
    delete(): void;
}
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import { symbols } from '../constants/ocr/symbols';
|
|
2
|
+
import { ETError, getError } from '../Error';
|
|
3
|
+
import { ResourceFetcher } from '../utils/ResourceFetcher';
|
|
4
|
+
export class VerticalOCRController {
|
|
5
|
+
ocrNativeModule;
|
|
6
|
+
isReady = false;
|
|
7
|
+
isGenerating = false;
|
|
8
|
+
error = null;
|
|
9
|
+
isReadyCallback;
|
|
10
|
+
isGeneratingCallback;
|
|
11
|
+
errorCallback;
|
|
12
|
+
constructor({ isReadyCallback = (_isReady) => { }, isGeneratingCallback = (_isGenerating) => { }, errorCallback = (_error) => { }, } = {}) {
|
|
13
|
+
this.isReadyCallback = isReadyCallback;
|
|
14
|
+
this.isGeneratingCallback = isGeneratingCallback;
|
|
15
|
+
this.errorCallback = errorCallback;
|
|
16
|
+
}
|
|
17
|
+
load = async (detectorSources, recognizerSources, language, independentCharacters, onDownloadProgressCallback) => {
|
|
18
|
+
try {
|
|
19
|
+
if (Object.keys(detectorSources).length !== 2 ||
|
|
20
|
+
Object.keys(recognizerSources).length !== 2)
|
|
21
|
+
return;
|
|
22
|
+
if (!symbols[language]) {
|
|
23
|
+
throw new Error(getError(ETError.LanguageNotSupported));
|
|
24
|
+
}
|
|
25
|
+
this.isReady = false;
|
|
26
|
+
this.isReadyCallback(this.isReady);
|
|
27
|
+
const paths = await ResourceFetcher.fetch(onDownloadProgressCallback, detectorSources.detectorLarge, detectorSources.detectorNarrow, independentCharacters
|
|
28
|
+
? recognizerSources.recognizerSmall
|
|
29
|
+
: recognizerSources.recognizerLarge);
|
|
30
|
+
if (paths === null || paths.length < 3) {
|
|
31
|
+
throw new Error('Download interrupted');
|
|
32
|
+
}
|
|
33
|
+
this.ocrNativeModule = global.loadVerticalOCR(paths[0], paths[1], paths[2], symbols[language], independentCharacters);
|
|
34
|
+
this.isReady = true;
|
|
35
|
+
this.isReadyCallback(this.isReady);
|
|
36
|
+
}
|
|
37
|
+
catch (e) {
|
|
38
|
+
if (this.errorCallback) {
|
|
39
|
+
this.errorCallback(getError(e));
|
|
40
|
+
}
|
|
41
|
+
else {
|
|
42
|
+
throw new Error(getError(e));
|
|
43
|
+
}
|
|
44
|
+
}
|
|
45
|
+
};
|
|
46
|
+
forward = async (input) => {
|
|
47
|
+
if (!this.isReady) {
|
|
48
|
+
throw new Error(getError(ETError.ModuleNotLoaded));
|
|
49
|
+
}
|
|
50
|
+
if (this.isGenerating) {
|
|
51
|
+
throw new Error(getError(ETError.ModelGenerating));
|
|
52
|
+
}
|
|
53
|
+
try {
|
|
54
|
+
this.isGenerating = true;
|
|
55
|
+
this.isGeneratingCallback(this.isGenerating);
|
|
56
|
+
return await this.ocrNativeModule.generate(input);
|
|
57
|
+
}
|
|
58
|
+
catch (e) {
|
|
59
|
+
throw new Error(getError(e));
|
|
60
|
+
}
|
|
61
|
+
finally {
|
|
62
|
+
this.isGenerating = false;
|
|
63
|
+
this.isGeneratingCallback(this.isGenerating);
|
|
64
|
+
}
|
|
65
|
+
};
|
|
66
|
+
delete() {
|
|
67
|
+
if (this.isGenerating) {
|
|
68
|
+
throw new Error(getError(ETError.ModelGenerating) +
|
|
69
|
+
'You cannot delete the model. You must wait until the generating is finished.');
|
|
70
|
+
}
|
|
71
|
+
this.ocrNativeModule.unload();
|
|
72
|
+
this.isReadyCallback(false);
|
|
73
|
+
this.isGeneratingCallback(false);
|
|
74
|
+
}
|
|
75
|
+
}
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import { ResourceSource } from '../../types/common';
|
|
2
|
+
/** Options accepted by `useClassification`. */
interface Props {
    /** Where the model file comes from. */
    model: { modelSource: ResourceSource };
    /** Optional flag; see the hook implementation for how it is applied. */
    preventLoad?: boolean;
}
/**
 * Declared shape of the classification hook: takes `Props` and returns the
 * load/generation state together with a `forward` function.
 */
export declare const useClassification: (props: Props) => {
    error: string | null;
    isReady: boolean;
    isGenerating: boolean;
    downloadProgress: number;
    /** Runs the model on the image identified by `imageSource`. */
    forward: (imageSource: string) => Promise<any>;
};
|
|
15
|
+
export {};
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
import { useNonStaticModule } from '../useNonStaticModule';
|
|
2
|
+
import { ClassificationModule } from '../../modules/computer_vision/ClassificationModule';
|
|
3
|
+
export const useClassification = ({ model, preventLoad = false }) => useNonStaticModule({
|
|
4
|
+
module: ClassificationModule,
|
|
5
|
+
model,
|
|
6
|
+
preventLoad: preventLoad,
|
|
7
|
+
});
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import { ResourceSource } from '../../types/common';
|
|
2
|
+
/** Options accepted by `useImageEmbeddings`. */
interface Props {
    /** Where the model file comes from. */
    model: { modelSource: ResourceSource };
    /** Optional flag; see the hook implementation for how it is applied. */
    preventLoad?: boolean;
}
/**
 * Declared shape of the image-embeddings hook: takes `Props` and returns the
 * load/generation state together with a `forward` function.
 */
export declare const useImageEmbeddings: (props: Props) => {
    error: string | null;
    isReady: boolean;
    isGenerating: boolean;
    downloadProgress: number;
    /** Resolves to a `Float32Array` for the image identified by `imageSource`. */
    forward: (imageSource: string) => Promise<Float32Array<ArrayBufferLike>>;
};
|
|
15
|
+
export {};
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
import { ImageEmbeddingsModule } from '../../modules/computer_vision/ImageEmbeddingsModule';
|
|
2
|
+
import { useNonStaticModule } from '../useNonStaticModule';
|
|
3
|
+
export const useImageEmbeddings = ({ model, preventLoad = false }) => useNonStaticModule({
|
|
4
|
+
module: ImageEmbeddingsModule,
|
|
5
|
+
model,
|
|
6
|
+
preventLoad,
|
|
7
|
+
});
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import { ResourceSource } from '../../types/common';
|
|
2
|
+
/** Options accepted by `useImageSegmentation`. */
interface Props {
    /** Where the model file comes from. */
    model: { modelSource: ResourceSource };
    /** Optional flag; see the hook implementation for how it is applied. */
    preventLoad?: boolean;
}
/**
 * Declared shape of the image-segmentation hook: takes `Props` and returns the
 * load/generation state together with a `forward` function.
 */
export declare const useImageSegmentation: (props: Props) => {
    error: string | null;
    isReady: boolean;
    isGenerating: boolean;
    downloadProgress: number;
    /**
     * Resolves to an object with optional numeric keys 0 through 21, each
     * holding a `number[]`.
     */
    forward: (
        imageSource: string,
        classesOfInterest?: import("../..").DeeplabLabel[] | undefined,
        resize?: boolean | undefined
    ) => Promise<{
        [K in
            | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10
            | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21]?: number[] | undefined;
    }>;
};
|
|
38
|
+
export {};
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
import { useNonStaticModule } from '../useNonStaticModule';
|
|
2
|
+
import { ImageSegmentationModule } from '../../modules/computer_vision/ImageSegmentationModule';
|
|
3
|
+
export const useImageSegmentation = ({ model, preventLoad = false }) => useNonStaticModule({
|
|
4
|
+
module: ImageSegmentationModule,
|
|
5
|
+
model,
|
|
6
|
+
preventLoad,
|
|
7
|
+
});
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import { ResourceSource } from '../../types/common';
|
|
2
|
+
import { OCRDetection, OCRLanguage } from '../../types/ocr';
|
|
3
|
+
/** Shape of the object returned by `useOCR`. */
interface OCRModule {
    /** Error message, or `null` when no error is recorded. */
    error: string | null;
    /** Readiness flag. */
    isReady: boolean;
    /** Flag indicating that generation is in progress. */
    isGenerating: boolean;
    /** Takes a string input and resolves to a list of `OCRDetection` items. */
    forward: (input: string) => Promise<OCRDetection[]>;
    /** Download progress value. NOTE(review): range is not visible here — confirm. */
    downloadProgress: number;
}
|
|
10
|
+
/**
 * Type declaration of the `useOCR` hook.
 *
 * Accepts a `model` object naming the detector source, three recognizer
 * sources (large, medium, small) and the OCR language, plus an optional
 * `preventLoad` flag. Returns an `OCRModule`.
 */
export declare const useOCR: ({ model, preventLoad }: {
    model: {
        detectorSource: ResourceSource;
        recognizerLarge: ResourceSource;
        recognizerMedium: ResourceSource;
        recognizerSmall: ResourceSource;
        language: OCRLanguage;
    };
    preventLoad?: boolean;
}) => OCRModule;
|
|
20
|
+
export {};
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import { useEffect, useMemo, useState } from 'react';
|
|
2
|
+
import { OCRController } from '../../controllers/OCRController';
|
|
3
|
+
export const useOCR = ({ model, preventLoad = false, }) => {
|
|
4
|
+
const [error, setError] = useState(null);
|
|
5
|
+
const [isReady, setIsReady] = useState(false);
|
|
6
|
+
const [isGenerating, setIsGenerating] = useState(false);
|
|
7
|
+
const [downloadProgress, setDownloadProgress] = useState(0);
|
|
8
|
+
const controllerInstance = useMemo(() => new OCRController({
|
|
9
|
+
isReadyCallback: setIsReady,
|
|
10
|
+
isGeneratingCallback: setIsGenerating,
|
|
11
|
+
errorCallback: setError,
|
|
12
|
+
}), []);
|
|
13
|
+
useEffect(() => {
|
|
14
|
+
if (preventLoad)
|
|
15
|
+
return;
|
|
16
|
+
(async () => {
|
|
17
|
+
await controllerInstance.load(model.detectorSource, {
|
|
18
|
+
recognizerLarge: model.recognizerLarge,
|
|
19
|
+
recognizerMedium: model.recognizerMedium,
|
|
20
|
+
recognizerSmall: model.recognizerSmall,
|
|
21
|
+
}, model.language, setDownloadProgress);
|
|
22
|
+
})();
|
|
23
|
+
return () => {
|
|
24
|
+
controllerInstance.delete();
|
|
25
|
+
};
|
|
26
|
+
}, [
|
|
27
|
+
controllerInstance,
|
|
28
|
+
model.detectorSource,
|
|
29
|
+
model.recognizerLarge,
|
|
30
|
+
model.recognizerMedium,
|
|
31
|
+
model.recognizerSmall,
|
|
32
|
+
model.language,
|
|
33
|
+
preventLoad,
|
|
34
|
+
]);
|
|
35
|
+
return {
|
|
36
|
+
error,
|
|
37
|
+
isReady,
|
|
38
|
+
isGenerating,
|
|
39
|
+
forward: controllerInstance.forward,
|
|
40
|
+
downloadProgress,
|
|
41
|
+
};
|
|
42
|
+
};
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import { ResourceSource } from '../../types/common';
|
|
2
|
+
/** Argument object accepted by `useObjectDetection`. */
interface Props {
    /** Model configuration; `modelSource` is typed as `ResourceSource`. */
    model: { modelSource: ResourceSource };
    /**
     * Optional boolean flag.
     * NOTE(review): presumably suppresses automatic model loading — confirm
     * against the hook implementation.
     */
    preventLoad?: boolean;
}
|
|
8
|
+
/**
 * Type declaration of the `useObjectDetection` hook.
 *
 * The hook takes a `Props` object and returns status fields (`error`,
 * `isReady`, `isGenerating`, `downloadProgress`) together with a `forward`
 * function. `forward` accepts an image source string and an optional numeric
 * `detectionThreshold`, and resolves to an array of `Detection` items.
 */
export declare const useObjectDetection: ({ model, preventLoad }: Props) => {
    error: string | null;
    isReady: boolean;
    isGenerating: boolean;
    downloadProgress: number;
    forward: (
        imageSource: string,
        detectionThreshold?: number | undefined
    ) => Promise<import("../..").Detection[]>;
};
|
|
15
|
+
export {};
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
import { useNonStaticModule } from '../useNonStaticModule';
|
|
2
|
+
import { ObjectDetectionModule } from '../../modules/computer_vision/ObjectDetectionModule';
|
|
3
|
+
export const useObjectDetection = ({ model, preventLoad = false }) => useNonStaticModule({
|
|
4
|
+
module: ObjectDetectionModule,
|
|
5
|
+
model,
|
|
6
|
+
preventLoad: preventLoad,
|
|
7
|
+
});
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import { ResourceSource } from '../../types/common';
|
|
2
|
+
/** Argument object accepted by `useStyleTransfer`. */
interface Props {
    /** Model configuration; `modelSource` is typed as `ResourceSource`. */
    model: { modelSource: ResourceSource };
    /**
     * Optional boolean flag.
     * NOTE(review): presumably suppresses automatic model loading — confirm
     * against the hook implementation.
     */
    preventLoad?: boolean;
}
|
|
8
|
+
/**
 * Type declaration of the `useStyleTransfer` hook.
 *
 * The hook takes a `Props` object and returns status fields (`error`,
 * `isReady`, `isGenerating`, `downloadProgress`) together with a `forward`
 * function that accepts an image source string and resolves to a string.
 * NOTE(review): the meaning of the resolved string (e.g. a URI of the
 * stylized image) is not visible here — confirm against the module.
 */
export declare const useStyleTransfer: ({ model, preventLoad }: Props) => {
    error: string | null;
    isReady: boolean;
    isGenerating: boolean;
    downloadProgress: number;
    forward: (
        imageSource: string
    ) => Promise<string>;
};
|
|
15
|
+
export {};
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
import { useNonStaticModule } from '../useNonStaticModule';
|
|
2
|
+
import { StyleTransferModule } from '../../modules/computer_vision/StyleTransferModule';
|
|
3
|
+
export const useStyleTransfer = ({ model, preventLoad = false }) => useNonStaticModule({
|
|
4
|
+
module: StyleTransferModule,
|
|
5
|
+
model,
|
|
6
|
+
preventLoad: preventLoad,
|
|
7
|
+
});
|