react-native-executorch 0.5.8 → 0.5.10

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (103)
  1. package/README.md +24 -8
  2. package/lib/module/hooks/general/useExecutorchModule.js +1 -3
  3. package/lib/module/hooks/general/useExecutorchModule.js.map +1 -1
  4. package/package.json +1 -1
  5. package/src/hooks/general/useExecutorchModule.ts +1 -1
  6. package/lib/Error.js +0 -53
  7. package/lib/ThreadPool.d.ts +0 -10
  8. package/lib/ThreadPool.js +0 -28
  9. package/lib/common/Logger.d.ts +0 -8
  10. package/lib/common/Logger.js +0 -19
  11. package/lib/constants/directories.js +0 -2
  12. package/lib/constants/llmDefaults.d.ts +0 -6
  13. package/lib/constants/llmDefaults.js +0 -16
  14. package/lib/constants/modelUrls.d.ts +0 -223
  15. package/lib/constants/modelUrls.js +0 -322
  16. package/lib/constants/ocr/models.d.ts +0 -882
  17. package/lib/constants/ocr/models.js +0 -182
  18. package/lib/constants/ocr/symbols.js +0 -139
  19. package/lib/constants/sttDefaults.d.ts +0 -28
  20. package/lib/constants/sttDefaults.js +0 -68
  21. package/lib/controllers/LLMController.d.ts +0 -47
  22. package/lib/controllers/LLMController.js +0 -213
  23. package/lib/controllers/OCRController.js +0 -67
  24. package/lib/controllers/SpeechToTextController.d.ts +0 -56
  25. package/lib/controllers/SpeechToTextController.js +0 -349
  26. package/lib/controllers/VerticalOCRController.js +0 -70
  27. package/lib/hooks/computer_vision/useClassification.d.ts +0 -15
  28. package/lib/hooks/computer_vision/useClassification.js +0 -7
  29. package/lib/hooks/computer_vision/useImageEmbeddings.d.ts +0 -15
  30. package/lib/hooks/computer_vision/useImageEmbeddings.js +0 -7
  31. package/lib/hooks/computer_vision/useImageSegmentation.d.ts +0 -38
  32. package/lib/hooks/computer_vision/useImageSegmentation.js +0 -7
  33. package/lib/hooks/computer_vision/useOCR.d.ts +0 -20
  34. package/lib/hooks/computer_vision/useOCR.js +0 -41
  35. package/lib/hooks/computer_vision/useObjectDetection.d.ts +0 -15
  36. package/lib/hooks/computer_vision/useObjectDetection.js +0 -7
  37. package/lib/hooks/computer_vision/useStyleTransfer.d.ts +0 -15
  38. package/lib/hooks/computer_vision/useStyleTransfer.js +0 -7
  39. package/lib/hooks/computer_vision/useVerticalOCR.d.ts +0 -21
  40. package/lib/hooks/computer_vision/useVerticalOCR.js +0 -43
  41. package/lib/hooks/general/useExecutorchModule.d.ts +0 -13
  42. package/lib/hooks/general/useExecutorchModule.js +0 -7
  43. package/lib/hooks/natural_language_processing/useLLM.d.ts +0 -10
  44. package/lib/hooks/natural_language_processing/useLLM.js +0 -78
  45. package/lib/hooks/natural_language_processing/useSpeechToText.d.ts +0 -27
  46. package/lib/hooks/natural_language_processing/useSpeechToText.js +0 -49
  47. package/lib/hooks/natural_language_processing/useTextEmbeddings.d.ts +0 -16
  48. package/lib/hooks/natural_language_processing/useTextEmbeddings.js +0 -7
  49. package/lib/hooks/natural_language_processing/useTokenizer.d.ts +0 -17
  50. package/lib/hooks/natural_language_processing/useTokenizer.js +0 -52
  51. package/lib/hooks/useModule.js +0 -45
  52. package/lib/hooks/useNonStaticModule.d.ts +0 -20
  53. package/lib/hooks/useNonStaticModule.js +0 -49
  54. package/lib/index.d.ts +0 -48
  55. package/lib/index.js +0 -58
  56. package/lib/modules/BaseModule.js +0 -25
  57. package/lib/modules/BaseNonStaticModule.js +0 -14
  58. package/lib/modules/computer_vision/ClassificationModule.d.ts +0 -8
  59. package/lib/modules/computer_vision/ClassificationModule.js +0 -17
  60. package/lib/modules/computer_vision/ImageEmbeddingsModule.d.ts +0 -8
  61. package/lib/modules/computer_vision/ImageEmbeddingsModule.js +0 -17
  62. package/lib/modules/computer_vision/ImageSegmentationModule.d.ts +0 -11
  63. package/lib/modules/computer_vision/ImageSegmentationModule.js +0 -27
  64. package/lib/modules/computer_vision/OCRModule.d.ts +0 -14
  65. package/lib/modules/computer_vision/OCRModule.js +0 -17
  66. package/lib/modules/computer_vision/ObjectDetectionModule.d.ts +0 -9
  67. package/lib/modules/computer_vision/ObjectDetectionModule.js +0 -17
  68. package/lib/modules/computer_vision/StyleTransferModule.d.ts +0 -8
  69. package/lib/modules/computer_vision/StyleTransferModule.js +0 -17
  70. package/lib/modules/computer_vision/VerticalOCRModule.d.ts +0 -14
  71. package/lib/modules/computer_vision/VerticalOCRModule.js +0 -19
  72. package/lib/modules/general/ExecutorchModule.d.ts +0 -7
  73. package/lib/modules/general/ExecutorchModule.js +0 -14
  74. package/lib/modules/natural_language_processing/LLMModule.d.ts +0 -28
  75. package/lib/modules/natural_language_processing/LLMModule.js +0 -45
  76. package/lib/modules/natural_language_processing/SpeechToTextModule.d.ts +0 -24
  77. package/lib/modules/natural_language_processing/SpeechToTextModule.js +0 -36
  78. package/lib/modules/natural_language_processing/TextEmbeddingsModule.d.ts +0 -9
  79. package/lib/modules/natural_language_processing/TextEmbeddingsModule.js +0 -21
  80. package/lib/modules/natural_language_processing/TokenizerModule.d.ts +0 -12
  81. package/lib/modules/natural_language_processing/TokenizerModule.js +0 -30
  82. package/lib/native/NativeETInstaller.js +0 -2
  83. package/lib/native/NativeOCR.js +0 -2
  84. package/lib/native/NativeVerticalOCR.js +0 -2
  85. package/lib/native/RnExecutorchModules.d.ts +0 -7
  86. package/lib/native/RnExecutorchModules.js +0 -18
  87. package/lib/tsconfig.tsbuildinfo +0 -1
  88. package/lib/types/common.d.ts +0 -32
  89. package/lib/types/common.js +0 -25
  90. package/lib/types/imageSegmentation.js +0 -26
  91. package/lib/types/llm.d.ts +0 -46
  92. package/lib/types/llm.js +0 -9
  93. package/lib/types/objectDetection.js +0 -94
  94. package/lib/types/ocr.js +0 -1
  95. package/lib/types/stt.d.ts +0 -94
  96. package/lib/types/stt.js +0 -85
  97. package/lib/utils/ResourceFetcher.d.ts +0 -24
  98. package/lib/utils/ResourceFetcher.js +0 -305
  99. package/lib/utils/ResourceFetcherUtils.d.ts +0 -54
  100. package/lib/utils/ResourceFetcherUtils.js +0 -127
  101. package/lib/utils/llm.d.ts +0 -6
  102. package/lib/utils/llm.js +0 -72
  103. package/lib/utils/stt.js +0 -21
package/lib/controllers/LLMController.js
@@ -1,213 +0,0 @@
- import { ResourceFetcher } from '../utils/ResourceFetcher';
- import { ETError, getError } from '../Error';
- import { Template } from '@huggingface/jinja';
- import { DEFAULT_CHAT_CONFIG } from '../constants/llmDefaults';
- import { readAsStringAsync } from 'expo-file-system';
- import { SPECIAL_TOKENS, } from '../types/llm';
- import { parseToolCall } from '../utils/llm';
- import { Logger } from '../common/Logger';
- export class LLMController {
- nativeModule;
- chatConfig = DEFAULT_CHAT_CONFIG;
- toolsConfig;
- tokenizerConfig;
- onToken;
- _response = '';
- _isReady = false;
- _isGenerating = false;
- _messageHistory = [];
- // User callbacks
- tokenCallback;
- responseCallback;
- messageHistoryCallback;
- isReadyCallback;
- isGeneratingCallback;
- constructor({ tokenCallback, responseCallback, messageHistoryCallback, isReadyCallback, isGeneratingCallback, }) {
- if (responseCallback !== undefined) {
- Logger.warn('Passing response callback is deprecated and will be removed in 0.6.0');
- }
- this.tokenCallback = (token) => {
- tokenCallback?.(token);
- };
- this.responseCallback = (response) => {
- this._response = response;
- responseCallback?.(response);
- };
- this.messageHistoryCallback = (messageHistory) => {
- this._messageHistory = messageHistory;
- messageHistoryCallback?.(messageHistory);
- };
- this.isReadyCallback = (isReady) => {
- this._isReady = isReady;
- isReadyCallback?.(isReady);
- };
- this.isGeneratingCallback = (isGenerating) => {
- this._isGenerating = isGenerating;
- isGeneratingCallback?.(isGenerating);
- };
- }
- get response() {
- return this._response;
- }
- get isReady() {
- return this._isReady;
- }
- get isGenerating() {
- return this._isGenerating;
- }
- get messageHistory() {
- return this._messageHistory;
- }
- async load({ modelSource, tokenizerSource, tokenizerConfigSource, onDownloadProgressCallback, }) {
- // reset inner state when loading new model
- this.responseCallback('');
- this.messageHistoryCallback(this.chatConfig.initialMessageHistory);
- this.isGeneratingCallback(false);
- this.isReadyCallback(false);
- try {
- const tokenizersPromise = ResourceFetcher.fetch(undefined, tokenizerSource, tokenizerConfigSource);
- const modelPromise = ResourceFetcher.fetch(onDownloadProgressCallback, modelSource);
- const [tokenizersResults, modelResult] = await Promise.all([
- tokenizersPromise,
- modelPromise,
- ]);
- const tokenizerPath = tokenizersResults?.[0];
- const tokenizerConfigPath = tokenizersResults?.[1];
- const modelPath = modelResult?.[0];
- if (!tokenizerPath || !tokenizerConfigPath || !modelPath) {
- throw new Error('Download interrupted!');
- }
- this.tokenizerConfig = JSON.parse(await readAsStringAsync('file://' + tokenizerConfigPath));
- this.nativeModule = global.loadLLM(modelPath, tokenizerPath);
- this.isReadyCallback(true);
- this.onToken = (data) => {
- if (!data ||
- (SPECIAL_TOKENS.EOS_TOKEN in this.tokenizerConfig &&
- data === this.tokenizerConfig.eos_token) ||
- (SPECIAL_TOKENS.PAD_TOKEN in this.tokenizerConfig &&
- data === this.tokenizerConfig.pad_token)) {
- return;
- }
- this.tokenCallback(data);
- this.responseCallback(this._response + data);
- };
- }
- catch (e) {
- this.isReadyCallback(false);
- throw new Error(getError(e));
- }
- }
- setTokenCallback(tokenCallback) {
- this.tokenCallback = tokenCallback;
- }
- configure({ chatConfig, toolsConfig, }) {
- this.chatConfig = { ...DEFAULT_CHAT_CONFIG, ...chatConfig };
- this.toolsConfig = toolsConfig;
- // reset inner state when loading new configuration
- this.responseCallback('');
- this.messageHistoryCallback(this.chatConfig.initialMessageHistory);
- this.isGeneratingCallback(false);
- }
- delete() {
- if (this._isGenerating) {
- throw new Error(getError(ETError.ModelGenerating) +
- 'You cannot delete the model now. You need to interrupt first.');
- }
- this.onToken = () => { };
- this.nativeModule.unload();
- this.isReadyCallback(false);
- this.isGeneratingCallback(false);
- }
- async forward(input) {
- if (!this._isReady) {
- throw new Error(getError(ETError.ModuleNotLoaded));
- }
- if (this._isGenerating) {
- throw new Error(getError(ETError.ModelGenerating));
- }
- try {
- this.responseCallback('');
- this.isGeneratingCallback(true);
- await this.nativeModule.generate(input, this.onToken);
- }
- catch (e) {
- throw new Error(getError(e));
- }
- finally {
- this.isGeneratingCallback(false);
- }
- }
- interrupt() {
- this.nativeModule.interrupt();
- }
- async generate(messages, tools) {
- if (!this._isReady) {
- throw new Error(getError(ETError.ModuleNotLoaded));
- }
- if (messages.length === 0) {
- throw new Error(`Empty 'messages' array!`);
- }
- if (messages[0] && messages[0].role !== 'system') {
- Logger.warn(`You are not providing system prompt. You can pass it in the first message using { role: 'system', content: YOUR_PROMPT }. Otherwise prompt from your model's chat template will be used.`);
- }
- const renderedChat = this.applyChatTemplate(messages, this.tokenizerConfig, tools,
- // eslint-disable-next-line camelcase
- { tools_in_user_message: false, add_generation_prompt: true });
- await this.forward(renderedChat);
- }
- async sendMessage(message) {
- this.messageHistoryCallback([
- ...this._messageHistory,
- { content: message, role: 'user' },
- ]);
- const messageHistoryWithPrompt = [
- { content: this.chatConfig.systemPrompt, role: 'system' },
- ...this._messageHistory.slice(-this.chatConfig.contextWindowLength),
- ];
- await this.generate(messageHistoryWithPrompt, this.toolsConfig?.tools);
- if (!this.toolsConfig || this.toolsConfig.displayToolCalls) {
- this.messageHistoryCallback([
- ...this._messageHistory,
- { content: this._response, role: 'assistant' },
- ]);
- }
- if (!this.toolsConfig) {
- return;
- }
- const toolCalls = parseToolCall(this._response);
- for (const toolCall of toolCalls) {
- this.toolsConfig
- .executeToolCallback(toolCall)
- .then((toolResponse) => {
- if (toolResponse) {
- this.messageHistoryCallback([
- ...this._messageHistory,
- { content: toolResponse, role: 'assistant' },
- ]);
- }
- });
- }
- }
- deleteMessage(index) {
- // we delete referenced message and all messages after it
- // so the model responses that used them are deleted as well
- const newMessageHistory = this._messageHistory.slice(0, index);
- this.messageHistoryCallback(newMessageHistory);
- }
- applyChatTemplate(messages, tokenizerConfig, tools, templateFlags) {
- if (!tokenizerConfig.chat_template) {
- throw Error("Tokenizer config doesn't include chat_template");
- }
- const template = new Template(tokenizerConfig.chat_template);
- const specialTokens = Object.fromEntries(Object.keys(SPECIAL_TOKENS)
- .filter((key) => key in tokenizerConfig)
- .map((key) => [key, tokenizerConfig[key]]));
- const result = template.render({
- messages,
- tools,
- ...templateFlags,
- ...specialTokens,
- });
- return result;
- }
- }
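
For orientation, here is a minimal usage sketch of the `LLMController` whose compiled copy is removed above, based only on the methods visible in this hunk (`load`, `configure`, `sendMessage`, `interrupt`, `delete`). The import path and every URL below are assumptions for illustration, not assets referenced in this diff.

```ts
// Sketch only: assumes LLMController is exported from the package root and that
// the placeholder URLs point at a valid ExecuTorch model and its tokenizer files.
import { LLMController } from 'react-native-executorch';

const llm = new LLMController({
  tokenCallback: (token) => console.log('token:', token),
  messageHistoryCallback: (history) => console.log('messages:', history.length),
  isReadyCallback: (ready) => console.log('ready:', ready),
  isGeneratingCallback: (generating) => console.log('generating:', generating),
});

export async function chat() {
  // Fetches the resources, parses the tokenizer config and loads the native module.
  await llm.load({
    modelSource: 'https://example.com/model.pte', // placeholder
    tokenizerSource: 'https://example.com/tokenizer.json', // placeholder
    tokenizerConfigSource: 'https://example.com/tokenizer_config.json', // placeholder
    onDownloadProgressCallback: (progress) => console.log('download:', progress),
  });

  // Optionally override parts of the default chat config before chatting.
  llm.configure({ chatConfig: { systemPrompt: 'You are a helpful assistant.' } });

  // Appends the user message, renders the chat template and streams tokens
  // through tokenCallback; the reply is appended to llm.messageHistory.
  await llm.sendMessage('Hello!');

  llm.interrupt(); // stop an in-flight generation
  llm.delete();    // unload the native module (throws while generating)
}
```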
package/lib/controllers/OCRController.js
@@ -1,67 +0,0 @@
- import { symbols } from '../constants/ocr/symbols';
- import { ETError, getError } from '../Error';
- import { OCRNativeModule } from '../native/RnExecutorchModules';
- import { ResourceFetcher } from '../utils/ResourceFetcher';
- export class OCRController {
- nativeModule;
- isReady = false;
- isGenerating = false;
- error = null;
- modelDownloadProgressCallback;
- isReadyCallback;
- isGeneratingCallback;
- errorCallback;
- constructor({ modelDownloadProgressCallback = (_downloadProgress) => { }, isReadyCallback = (_isReady) => { }, isGeneratingCallback = (_isGenerating) => { }, errorCallback = (_error) => { }, }) {
- this.nativeModule = OCRNativeModule;
- this.modelDownloadProgressCallback = modelDownloadProgressCallback;
- this.isReadyCallback = isReadyCallback;
- this.isGeneratingCallback = isGeneratingCallback;
- this.errorCallback = errorCallback;
- }
- loadModel = async (detectorSource, recognizerSources, language) => {
- try {
- if (!detectorSource || Object.keys(recognizerSources).length !== 3)
- return;
- if (!symbols[language]) {
- throw new Error(getError(ETError.LanguageNotSupported));
- }
- this.isReady = false;
- this.isReadyCallback(false);
- const paths = await ResourceFetcher.fetch(this.modelDownloadProgressCallback, detectorSource, recognizerSources.recognizerLarge, recognizerSources.recognizerMedium, recognizerSources.recognizerSmall);
- if (paths === null || paths?.length < 4) {
- throw new Error('Download interrupted!');
- }
- await this.nativeModule.loadModule(paths[0], paths[1], paths[2], paths[3], symbols[language]);
- this.isReady = true;
- this.isReadyCallback(this.isReady);
- }
- catch (e) {
- if (this.errorCallback) {
- this.errorCallback(getError(e));
- }
- else {
- throw new Error(getError(e));
- }
- }
- };
- forward = async (input) => {
- if (!this.isReady) {
- throw new Error(getError(ETError.ModuleNotLoaded));
- }
- if (this.isGenerating) {
- throw new Error(getError(ETError.ModelGenerating));
- }
- try {
- this.isGenerating = true;
- this.isGeneratingCallback(this.isGenerating);
- return await this.nativeModule.forward(input);
- }
- catch (e) {
- throw new Error(getError(e));
- }
- finally {
- this.isGenerating = false;
- this.isGeneratingCallback(this.isGenerating);
- }
- };
- }
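
Similarly, a minimal sketch of how the removed `OCRController` is driven, inferred from the hunk above. The model sources are placeholders, and the `'en'` language code is an assumption about the bundled symbols map.

```ts
// Sketch only: the controller expects one detector source plus exactly three
// recognizer sources, and a language that exists in the bundled symbols map.
import { OCRController } from 'react-native-executorch';

const ocr = new OCRController({
  modelDownloadProgressCallback: (progress) => console.log('download:', progress),
  isReadyCallback: (ready) => console.log('ready:', ready),
  isGeneratingCallback: (generating) => console.log('generating:', generating),
  errorCallback: (error) => console.error('OCR error:', error),
});

export async function readText(imageUri: string) {
  await ocr.loadModel(
    'https://example.com/detector.pte', // placeholder
    {
      recognizerLarge: 'https://example.com/recognizer_large.pte',   // placeholder
      recognizerMedium: 'https://example.com/recognizer_medium.pte', // placeholder
      recognizerSmall: 'https://example.com/recognizer_small.pte',   // placeholder
    },
    'en' // assumed to be a key of constants/ocr/symbols
  );
  // forward() guards on isReady/isGenerating and delegates to the native module.
  return ocr.forward(imageUri);
}
```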
@@ -1,56 +0,0 @@
1
- import { MODES, STREAMING_ACTION } from '../constants/sttDefaults';
2
- import { AvailableModels } from '../types/stt';
3
- import { ResourceSource } from '../types/common';
4
- import { SpeechToTextLanguage } from '../types/stt';
5
- export declare class SpeechToTextController {
6
- private speechToTextNativeModule;
7
- sequence: number[];
8
- isReady: boolean;
9
- isGenerating: boolean;
10
- private tokenizerModule;
11
- private overlapSeconds;
12
- private windowSize;
13
- private chunks;
14
- private seqs;
15
- private prevSeq;
16
- private waveform;
17
- private numOfChunks;
18
- private streaming;
19
- private decodedTranscribeCallback;
20
- private isReadyCallback;
21
- private isGeneratingCallback;
22
- private onErrorCallback;
23
- private config;
24
- constructor({ transcribeCallback, isReadyCallback, isGeneratingCallback, onErrorCallback, overlapSeconds, windowSize, streamingConfig, }: {
25
- transcribeCallback: (sequence: string) => void;
26
- isReadyCallback?: (isReady: boolean) => void;
27
- isGeneratingCallback?: (isGenerating: boolean) => void;
28
- onErrorCallback?: (error: Error | undefined) => void;
29
- overlapSeconds?: number;
30
- windowSize?: number;
31
- streamingConfig?: keyof typeof MODES;
32
- });
33
- load({ modelName, encoderSource, decoderSource, tokenizerSource, onDownloadProgressCallback, }: {
34
- modelName: AvailableModels;
35
- encoderSource?: ResourceSource;
36
- decoderSource?: ResourceSource;
37
- tokenizerSource?: ResourceSource;
38
- onDownloadProgressCallback?: (downloadProgress: number) => void;
39
- }): Promise<void>;
40
- configureStreaming(overlapSeconds?: number, windowSize?: number, streamingConfig?: keyof typeof MODES): void;
41
- private chunkWaveform;
42
- private resetState;
43
- private expectedChunkLength;
44
- private getStartingTokenIds;
45
- private decodeChunk;
46
- private handleOverlaps;
47
- private trimLeft;
48
- private trimRight;
49
- private trimSequences;
50
- private validateAndFixLastChunk;
51
- private tokenIdsToText;
52
- transcribe(waveform: number[], audioLanguage?: SpeechToTextLanguage): Promise<string>;
53
- streamingTranscribe(streamAction: STREAMING_ACTION, waveform?: number[], audioLanguage?: SpeechToTextLanguage): Promise<string>;
54
- encode(waveform: Float32Array): Promise<null>;
55
- decode(seq: number[]): Promise<number>;
56
- }
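
The declaration above is enough to sketch typical usage of the controller. Whether `SpeechToTextController` and `AvailableModels` are re-exported from the package root is an assumption here, as is the exact enum member name.

```ts
// Sketch only: assumes the controller and AvailableModels are exported from the
// package root, and that the waveform is mono 16 kHz samples (Whisper's input format).
import { SpeechToTextController, AvailableModels } from 'react-native-executorch';

const stt = new SpeechToTextController({
  transcribeCallback: (text) => console.log('partial transcript:', text),
  isReadyCallback: (ready) => console.log('ready:', ready),
  isGeneratingCallback: (generating) => console.log('generating:', generating),
  onErrorCallback: (error) => error && console.error(error),
  streamingConfig: 'balanced', // one of the MODES presets
});

export async function transcribe(waveform: number[]) {
  // Uses the bundled MODEL_CONFIGS entry for the chosen model; encoder, decoder
  // and tokenizer sources can be overridden explicitly if needed.
  await stt.load({
    modelName: AvailableModels.WHISPER, // assumed enum member
    onDownloadProgressCallback: (progress) => console.log('download:', progress),
  });
  // The English-only model takes no audioLanguage; the multilingual one requires it.
  return stt.transcribe(waveform);
}
```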
package/lib/controllers/SpeechToTextController.js
@@ -1,349 +0,0 @@
- import { HAMMING_DIST_THRESHOLD, MODEL_CONFIGS, SECOND, MODES, NUM_TOKENS_TO_TRIM, STREAMING_ACTION, } from '../constants/sttDefaults';
- import { AvailableModels } from '../types/stt';
- import { TokenizerModule } from '../modules/natural_language_processing/TokenizerModule';
- import { ResourceFetcher } from '../utils/ResourceFetcher';
- import { longCommonInfPref } from '../utils/stt';
- import { ETError, getError } from '../Error';
- import { Logger } from '../common/Logger';
- export class SpeechToTextController {
- speechToTextNativeModule;
- sequence = [];
- isReady = false;
- isGenerating = false;
- tokenizerModule;
- overlapSeconds;
- windowSize;
- chunks = [];
- seqs = [];
- prevSeq = [];
- waveform = [];
- numOfChunks = 0;
- streaming = false;
- // User callbacks
- decodedTranscribeCallback;
- isReadyCallback;
- isGeneratingCallback;
- onErrorCallback;
- config;
- constructor({ transcribeCallback, isReadyCallback, isGeneratingCallback, onErrorCallback, overlapSeconds, windowSize, streamingConfig, }) {
- this.tokenizerModule = new TokenizerModule();
- this.decodedTranscribeCallback = async (seq) => transcribeCallback(await this.tokenIdsToText(seq));
- this.isReadyCallback = (isReady) => {
- this.isReady = isReady;
- isReadyCallback?.(isReady);
- };
- this.isGeneratingCallback = (isGenerating) => {
- this.isGenerating = isGenerating;
- isGeneratingCallback?.(isGenerating);
- };
- this.onErrorCallback = (error) => {
- if (onErrorCallback) {
- onErrorCallback(error ? new Error(getError(error)) : undefined);
- return;
- }
- else {
- throw new Error(getError(error));
- }
- };
- this.configureStreaming(overlapSeconds, windowSize, streamingConfig || 'balanced');
- }
- async load({ modelName, encoderSource, decoderSource, tokenizerSource, onDownloadProgressCallback, }) {
- this.onErrorCallback(undefined);
- this.isReadyCallback(false);
- this.config = MODEL_CONFIGS[modelName];
- try {
- const tokenizerLoadPromise = this.tokenizerModule.load({
- tokenizerSource: tokenizerSource || this.config.tokenizer.source,
- });
- const pathsPromise = ResourceFetcher.fetch(onDownloadProgressCallback, encoderSource || this.config.sources.encoder, decoderSource || this.config.sources.decoder);
- const [_, encoderDecoderResults] = await Promise.all([
- tokenizerLoadPromise,
- pathsPromise,
- ]);
- encoderSource = encoderDecoderResults?.[0];
- decoderSource = encoderDecoderResults?.[1];
- if (!encoderSource || !decoderSource) {
- throw new Error('Download interrupted.');
- }
- }
- catch (e) {
- this.onErrorCallback(e);
- return;
- }
- if (modelName === 'whisperMultilingual') {
- // The underlying native class is instantiated based on the name of the model. There is no need to
- // create a separate class for multilingual version of Whisper, since it is the same. We just need
- // the distinction here, in TS, for start tokens and such. If we introduce
- // more versions of Whisper, such as the small one, this should be refactored.
- modelName = AvailableModels.WHISPER;
- }
- try {
- const nativeSpeechToText = await global.loadSpeechToText(encoderSource, decoderSource, modelName);
- this.speechToTextNativeModule = nativeSpeechToText;
- this.isReadyCallback(true);
- }
- catch (e) {
- this.onErrorCallback(e);
- }
- }
- configureStreaming(overlapSeconds, windowSize, streamingConfig) {
- if (streamingConfig) {
- this.windowSize = MODES[streamingConfig].windowSize * SECOND;
- this.overlapSeconds = MODES[streamingConfig].overlapSeconds * SECOND;
- }
- if (streamingConfig && (windowSize || overlapSeconds)) {
- Logger.warn(`windowSize and overlapSeconds overrides values from streamingConfig ${streamingConfig}.`);
- }
- this.windowSize = (windowSize || 0) * SECOND || this.windowSize;
- this.overlapSeconds = (overlapSeconds || 0) * SECOND || this.overlapSeconds;
- if (2 * this.overlapSeconds + this.windowSize >= 30 * SECOND) {
- Logger.warn(`Invalid values for overlapSeconds and/or windowSize provided. Expected windowSize + 2 * overlapSeconds (== ${this.windowSize + 2 * this.overlapSeconds}) <= 30. Setting windowSize to ${30 * SECOND - 2 * this.overlapSeconds}.`);
- this.windowSize = 30 * SECOND - 2 * this.overlapSeconds;
- }
- }
- chunkWaveform() {
- this.numOfChunks = Math.ceil(this.waveform.length / this.windowSize);
- for (let i = 0; i < this.numOfChunks; i++) {
- let chunk = [];
- const left = Math.max(this.windowSize * i - this.overlapSeconds, 0);
- const right = Math.min(this.windowSize * (i + 1) + this.overlapSeconds, this.waveform.length);
- chunk = this.waveform.slice(left, right);
- this.chunks.push(chunk);
- }
- }
- resetState() {
- this.sequence = [];
- this.seqs = [];
- this.waveform = [];
- this.prevSeq = [];
- this.chunks = [];
- this.decodedTranscribeCallback([]);
- this.onErrorCallback(undefined);
- }
- expectedChunkLength() {
- //only first chunk can be of shorter length, for first chunk there are no seqs decoded
- return this.seqs.length
- ? this.windowSize + 2 * this.overlapSeconds
- : this.windowSize + this.overlapSeconds;
- }
- async getStartingTokenIds(audioLanguage) {
- // We need different starting token ids based on the multilingualism of the model.
- // The eng version only needs BOS token, while the multilingual one needs:
- // [BOS, LANG, TRANSCRIBE]. Optionally we should also set notimestamps token, as timestamps
- // is not yet supported.
- if (!audioLanguage) {
- return [this.config.tokenizer.bos];
- }
- // FIXME: I should use .getTokenId for the BOS as well, should remove it from config
- const langTokenId = await this.tokenizerModule.tokenToId(`<|${audioLanguage}|>`);
- const transcribeTokenId = await this.tokenizerModule.tokenToId('<|transcribe|>');
- const noTimestampsTokenId = await this.tokenizerModule.tokenToId('<|notimestamps|>');
- const startingTokenIds = [
- this.config.tokenizer.bos,
- langTokenId,
- transcribeTokenId,
- noTimestampsTokenId,
- ];
- return startingTokenIds;
- }
- async decodeChunk(chunk, audioLanguage) {
- const seq = await this.getStartingTokenIds(audioLanguage);
- let prevSeqTokenIdx = 0;
- this.prevSeq = this.sequence.slice();
- try {
- await this.encode(new Float32Array(chunk));
- }
- catch (error) {
- this.onErrorCallback(new Error(getError(error) + ' encoding error'));
- return [];
- }
- let lastToken = seq.at(-1);
- while (lastToken !== this.config.tokenizer.eos) {
- try {
- lastToken = await this.decode(seq);
- }
- catch (error) {
- this.onErrorCallback(new Error(getError(error) + ' decoding error'));
- return [...seq, this.config.tokenizer.eos];
- }
- seq.push(lastToken);
- if (this.seqs.length > 0 &&
- seq.length < this.seqs.at(-1).length &&
- seq.length % 3 !== 0) {
- this.prevSeq.push(this.seqs.at(-1)[prevSeqTokenIdx++]);
- this.decodedTranscribeCallback(this.prevSeq);
- }
- }
- return seq;
- }
- async handleOverlaps(seqs) {
- const maxInd = longCommonInfPref(seqs.at(-2), seqs.at(-1), HAMMING_DIST_THRESHOLD);
- this.sequence = [...this.sequence, ...seqs.at(-2).slice(0, maxInd)];
- this.decodedTranscribeCallback(this.sequence);
- return this.sequence.slice();
- }
- trimLeft(numOfTokensToTrim) {
- const idx = this.seqs.length - 1;
- if (this.seqs[idx][0] === this.config.tokenizer.bos) {
- this.seqs[idx] = this.seqs[idx].slice(numOfTokensToTrim);
- }
- }
- trimRight(numOfTokensToTrim) {
- const idx = this.seqs.length - 2;
- if (this.seqs[idx].at(-1) === this.config.tokenizer.eos) {
- this.seqs[idx] = this.seqs[idx].slice(0, -numOfTokensToTrim);
- }
- }
- // since we are calling this every time (except first) after a new seq is pushed to this.seqs
- // we can only trim left the last seq and trim right the second to last seq
- async trimSequences(audioLanguage) {
- const numSpecialTokens = (await this.getStartingTokenIds(audioLanguage))
- .length;
- this.trimLeft(numSpecialTokens + NUM_TOKENS_TO_TRIM);
- this.trimRight(numSpecialTokens + NUM_TOKENS_TO_TRIM);
- }
- // if last chunk is too short combine it with second to last to improve quality
- validateAndFixLastChunk() {
- if (this.chunks.length < 2)
- return;
- const lastChunkLength = this.chunks.at(-1).length / SECOND;
- const secondToLastChunkLength = this.chunks.at(-2).length / SECOND;
- if (lastChunkLength < 5 && secondToLastChunkLength + lastChunkLength < 30) {
- this.chunks[this.chunks.length - 2] = [
- ...this.chunks.at(-2).slice(0, -this.overlapSeconds * 2),
- ...this.chunks.at(-1),
- ];
- this.chunks = this.chunks.slice(0, -1);
- }
- }
- async tokenIdsToText(tokenIds) {
- try {
- return await this.tokenizerModule.decode(tokenIds, true);
- }
- catch (e) {
- this.onErrorCallback(new Error(`An error has occurred when decoding the token ids: ${e}`));
- return '';
- }
- }
- async transcribe(waveform, audioLanguage) {
- try {
- if (!this.isReady)
- throw Error(getError(ETError.ModuleNotLoaded));
- if (this.isGenerating || this.streaming)
- throw Error(getError(ETError.ModelGenerating));
- if (!!audioLanguage !== this.config.isMultilingual)
- throw new Error(getError(ETError.MultilingualConfiguration));
- }
- catch (e) {
- this.onErrorCallback(e);
- return '';
- }
- // Making sure that the error is not set when we get there
- this.isGeneratingCallback(true);
- this.resetState();
- this.waveform = waveform;
- this.chunkWaveform();
- this.validateAndFixLastChunk();
- for (let chunkId = 0; chunkId < this.chunks.length; chunkId++) {
- const seq = await this.decodeChunk(this.chunks.at(chunkId), audioLanguage);
- // whole audio is inside one chunk, no processing required
- if (this.chunks.length === 1) {
- this.sequence = seq;
- this.decodedTranscribeCallback(seq);
- break;
- }
- this.seqs.push(seq);
- if (this.seqs.length < 2)
- continue;
- // Remove starting tokenIds and some additional ones
- await this.trimSequences(audioLanguage);
- this.prevSeq = await this.handleOverlaps(this.seqs);
- // last sequence processed
- // overlaps are already handled, so just append the last seq
- if (this.seqs.length === this.chunks.length) {
- this.sequence = [...this.sequence, ...this.seqs.at(-1)];
- this.decodedTranscribeCallback(this.sequence);
- this.prevSeq = this.sequence;
- }
- }
- const decodedText = await this.tokenIdsToText(this.sequence);
- this.isGeneratingCallback(false);
- return decodedText;
- }
- async streamingTranscribe(streamAction, waveform, audioLanguage) {
- try {
- if (!this.isReady)
- throw Error(getError(ETError.ModuleNotLoaded));
- if (!!audioLanguage !== this.config.isMultilingual)
- throw new Error(getError(ETError.MultilingualConfiguration));
- if (streamAction === STREAMING_ACTION.START &&
- !this.streaming &&
- this.isGenerating)
- throw Error(getError(ETError.ModelGenerating));
- if (streamAction === STREAMING_ACTION.START && this.streaming)
- throw Error(getError(ETError.ModelGenerating));
- if (streamAction === STREAMING_ACTION.DATA && !this.streaming)
- throw Error(getError(ETError.StreamingNotStarted));
- if (streamAction === STREAMING_ACTION.STOP && !this.streaming)
- throw Error(getError(ETError.StreamingNotStarted));
- if (streamAction === STREAMING_ACTION.DATA && !waveform)
- throw new Error(getError(ETError.MissingDataChunk));
- }
- catch (e) {
- this.onErrorCallback(e);
- return '';
- }
- if (streamAction === STREAMING_ACTION.START) {
- this.resetState();
- this.streaming = true;
- this.isGeneratingCallback(true);
- }
- this.waveform = [...this.waveform, ...(waveform || [])];
- // while buffer has at least required size get chunk and decode
- while (this.waveform.length >= this.expectedChunkLength()) {
- const chunk = this.waveform.slice(0, this.windowSize +
- this.overlapSeconds * (1 + Number(this.seqs.length > 0)));
- this.chunks = [chunk]; //save last chunk for STREAMING_ACTION.STOP
- this.waveform = this.waveform.slice(this.windowSize - this.overlapSeconds * Number(this.seqs.length === 0));
- const seq = await this.decodeChunk(chunk, audioLanguage);
- this.seqs.push(seq);
- if (this.seqs.length < 2)
- continue;
- await this.trimSequences(audioLanguage);
- await this.handleOverlaps(this.seqs);
- }
- // got final package, process all remaining waveform data
- // since we run the loop above the waveform has at most one chunk in it
- if (streamAction === STREAMING_ACTION.STOP) {
- // pad remaining waveform data with previous chunk to this.windowSize + 2 * this.overlapSeconds
- const chunk = this.chunks.length
- ? [
- ...this.chunks[0].slice(0, this.windowSize),
- ...this.waveform,
- ].slice(-this.windowSize - 2 * this.overlapSeconds)
- : this.waveform;
- this.waveform = [];
- const seq = await this.decodeChunk(chunk, audioLanguage);
- this.seqs.push(seq);
- if (this.seqs.length === 1) {
- this.sequence = this.seqs[0];
- }
- else {
- await this.trimSequences(audioLanguage);
- await this.handleOverlaps(this.seqs);
- this.sequence = [...this.sequence, ...this.seqs.at(-1)];
- }
- this.decodedTranscribeCallback(this.sequence);
- this.isGeneratingCallback(false);
- this.streaming = false;
- }
- const decodedText = await this.tokenIdsToText(this.sequence);
- return decodedText;
- }
- async encode(waveform) {
- return await this.speechToTextNativeModule.encode(waveform);
- }
- async decode(seq) {
- return await this.speechToTextNativeModule.decode(seq);
- }
- }
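
The transcription path above decodes the waveform in overlapping windows. A standalone sketch of the boundary arithmetic in `chunkWaveform` follows; treating `SECOND` as the sample rate (16 kHz is assumed here, matching Whisper's input) it shows how neighbouring chunks share `overlap` samples on each side.

```ts
// Mirrors the chunkWaveform logic above: chunk i spans
// [windowSize * i - overlap, windowSize * (i + 1) + overlap], clamped to the waveform.
function chunkBoundaries(
  waveformLength: number, // total samples
  windowSize: number,     // samples per window
  overlap: number         // overlap samples added on each side
): Array<[number, number]> {
  const numOfChunks = Math.ceil(waveformLength / windowSize);
  const boundaries: Array<[number, number]> = [];
  for (let i = 0; i < numOfChunks; i++) {
    const left = Math.max(windowSize * i - overlap, 0);
    const right = Math.min(windowSize * (i + 1) + overlap, waveformLength);
    boundaries.push([left, right]);
  }
  return boundaries;
}

// Assumed 16 kHz sample rate; configureStreaming enforces
// windowSize + 2 * overlap < 30 seconds so every padded chunk fits the model's window.
const SECOND = 16_000;
console.log(chunkBoundaries(70 * SECOND, 12 * SECOND, 3 * SECOND));
// -> [[0, 240000], [144000, 432000], ...]; interior chunks are windowSize + 2 * overlap long.
```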