react-native-executorch 0.5.3 → 0.5.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. package/android/src/main/cpp/CMakeLists.txt +2 -1
  2. package/common/rnexecutorch/data_processing/Numerical.cpp +27 -19
  3. package/common/rnexecutorch/data_processing/Numerical.h +53 -4
  4. package/common/rnexecutorch/data_processing/dsp.cpp +1 -1
  5. package/common/rnexecutorch/data_processing/dsp.h +1 -1
  6. package/common/rnexecutorch/data_processing/gzip.cpp +47 -0
  7. package/common/rnexecutorch/data_processing/gzip.h +7 -0
  8. package/common/rnexecutorch/host_objects/ModelHostObject.h +24 -0
  9. package/common/rnexecutorch/metaprogramming/TypeConcepts.h +21 -1
  10. package/common/rnexecutorch/models/BaseModel.cpp +3 -2
  11. package/common/rnexecutorch/models/BaseModel.h +3 -2
  12. package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +103 -39
  13. package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +39 -21
  14. package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +310 -0
  15. package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +62 -0
  16. package/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp +82 -0
  17. package/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h +25 -0
  18. package/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp +99 -0
  19. package/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h +33 -0
  20. package/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h +15 -0
  21. package/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h +12 -0
  22. package/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h +12 -0
  23. package/common/rnexecutorch/models/speech_to_text/types/Segment.h +14 -0
  24. package/common/rnexecutorch/models/speech_to_text/types/Word.h +13 -0
  25. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +75 -53
  26. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
  27. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +5 -5
  28. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +7 -12
  29. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
  30. package/lib/typescript/types/stt.d.ts +0 -9
  31. package/lib/typescript/types/stt.d.ts.map +1 -1
  32. package/package.json +1 -1
  33. package/react-native-executorch.podspec +2 -0
  34. package/src/modules/natural_language_processing/SpeechToTextModule.ts +118 -54
  35. package/src/types/stt.ts +0 -12
  36. package/common/rnexecutorch/models/EncoderDecoderBase.cpp +0 -21
  37. package/common/rnexecutorch/models/EncoderDecoderBase.h +0 -31
  38. package/common/rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h +0 -27
  39. package/common/rnexecutorch/models/speech_to_text/WhisperStrategy.cpp +0 -50
  40. package/common/rnexecutorch/models/speech_to_text/WhisperStrategy.h +0 -25
  41. package/lib/Error.js +0 -53
  42. package/lib/ThreadPool.d.ts +0 -10
  43. package/lib/ThreadPool.js +0 -28
  44. package/lib/common/Logger.d.ts +0 -8
  45. package/lib/common/Logger.js +0 -19
  46. package/lib/constants/directories.js +0 -2
  47. package/lib/constants/llmDefaults.d.ts +0 -6
  48. package/lib/constants/llmDefaults.js +0 -16
  49. package/lib/constants/modelUrls.d.ts +0 -223
  50. package/lib/constants/modelUrls.js +0 -322
  51. package/lib/constants/ocr/models.d.ts +0 -882
  52. package/lib/constants/ocr/models.js +0 -182
  53. package/lib/constants/ocr/symbols.js +0 -139
  54. package/lib/constants/sttDefaults.d.ts +0 -28
  55. package/lib/constants/sttDefaults.js +0 -68
  56. package/lib/controllers/LLMController.d.ts +0 -47
  57. package/lib/controllers/LLMController.js +0 -213
  58. package/lib/controllers/OCRController.js +0 -67
  59. package/lib/controllers/SpeechToTextController.d.ts +0 -56
  60. package/lib/controllers/SpeechToTextController.js +0 -349
  61. package/lib/controllers/VerticalOCRController.js +0 -70
  62. package/lib/hooks/computer_vision/useClassification.d.ts +0 -15
  63. package/lib/hooks/computer_vision/useClassification.js +0 -7
  64. package/lib/hooks/computer_vision/useImageEmbeddings.d.ts +0 -15
  65. package/lib/hooks/computer_vision/useImageEmbeddings.js +0 -7
  66. package/lib/hooks/computer_vision/useImageSegmentation.d.ts +0 -38
  67. package/lib/hooks/computer_vision/useImageSegmentation.js +0 -7
  68. package/lib/hooks/computer_vision/useOCR.d.ts +0 -20
  69. package/lib/hooks/computer_vision/useOCR.js +0 -41
  70. package/lib/hooks/computer_vision/useObjectDetection.d.ts +0 -15
  71. package/lib/hooks/computer_vision/useObjectDetection.js +0 -7
  72. package/lib/hooks/computer_vision/useStyleTransfer.d.ts +0 -15
  73. package/lib/hooks/computer_vision/useStyleTransfer.js +0 -7
  74. package/lib/hooks/computer_vision/useVerticalOCR.d.ts +0 -21
  75. package/lib/hooks/computer_vision/useVerticalOCR.js +0 -43
  76. package/lib/hooks/general/useExecutorchModule.d.ts +0 -13
  77. package/lib/hooks/general/useExecutorchModule.js +0 -7
  78. package/lib/hooks/natural_language_processing/useLLM.d.ts +0 -10
  79. package/lib/hooks/natural_language_processing/useLLM.js +0 -78
  80. package/lib/hooks/natural_language_processing/useSpeechToText.d.ts +0 -27
  81. package/lib/hooks/natural_language_processing/useSpeechToText.js +0 -49
  82. package/lib/hooks/natural_language_processing/useTextEmbeddings.d.ts +0 -16
  83. package/lib/hooks/natural_language_processing/useTextEmbeddings.js +0 -7
  84. package/lib/hooks/natural_language_processing/useTokenizer.d.ts +0 -17
  85. package/lib/hooks/natural_language_processing/useTokenizer.js +0 -52
  86. package/lib/hooks/useModule.js +0 -45
  87. package/lib/hooks/useNonStaticModule.d.ts +0 -20
  88. package/lib/hooks/useNonStaticModule.js +0 -49
  89. package/lib/index.d.ts +0 -48
  90. package/lib/index.js +0 -58
  91. package/lib/module/utils/SpeechToTextModule/ASR.js +0 -191
  92. package/lib/module/utils/SpeechToTextModule/ASR.js.map +0 -1
  93. package/lib/module/utils/SpeechToTextModule/OnlineProcessor.js +0 -73
  94. package/lib/module/utils/SpeechToTextModule/OnlineProcessor.js.map +0 -1
  95. package/lib/module/utils/SpeechToTextModule/hypothesisBuffer.js +0 -56
  96. package/lib/module/utils/SpeechToTextModule/hypothesisBuffer.js.map +0 -1
  97. package/lib/module/utils/stt.js +0 -22
  98. package/lib/module/utils/stt.js.map +0 -1
  99. package/lib/modules/BaseModule.js +0 -25
  100. package/lib/modules/BaseNonStaticModule.js +0 -14
  101. package/lib/modules/computer_vision/ClassificationModule.d.ts +0 -8
  102. package/lib/modules/computer_vision/ClassificationModule.js +0 -17
  103. package/lib/modules/computer_vision/ImageEmbeddingsModule.d.ts +0 -8
  104. package/lib/modules/computer_vision/ImageEmbeddingsModule.js +0 -17
  105. package/lib/modules/computer_vision/ImageSegmentationModule.d.ts +0 -11
  106. package/lib/modules/computer_vision/ImageSegmentationModule.js +0 -27
  107. package/lib/modules/computer_vision/OCRModule.d.ts +0 -14
  108. package/lib/modules/computer_vision/OCRModule.js +0 -17
  109. package/lib/modules/computer_vision/ObjectDetectionModule.d.ts +0 -9
  110. package/lib/modules/computer_vision/ObjectDetectionModule.js +0 -17
  111. package/lib/modules/computer_vision/StyleTransferModule.d.ts +0 -8
  112. package/lib/modules/computer_vision/StyleTransferModule.js +0 -17
  113. package/lib/modules/computer_vision/VerticalOCRModule.d.ts +0 -14
  114. package/lib/modules/computer_vision/VerticalOCRModule.js +0 -19
  115. package/lib/modules/general/ExecutorchModule.d.ts +0 -7
  116. package/lib/modules/general/ExecutorchModule.js +0 -14
  117. package/lib/modules/natural_language_processing/LLMModule.d.ts +0 -28
  118. package/lib/modules/natural_language_processing/LLMModule.js +0 -45
  119. package/lib/modules/natural_language_processing/SpeechToTextModule.d.ts +0 -24
  120. package/lib/modules/natural_language_processing/SpeechToTextModule.js +0 -36
  121. package/lib/modules/natural_language_processing/TextEmbeddingsModule.d.ts +0 -9
  122. package/lib/modules/natural_language_processing/TextEmbeddingsModule.js +0 -21
  123. package/lib/modules/natural_language_processing/TokenizerModule.d.ts +0 -12
  124. package/lib/modules/natural_language_processing/TokenizerModule.js +0 -30
  125. package/lib/native/NativeETInstaller.js +0 -2
  126. package/lib/native/NativeOCR.js +0 -2
  127. package/lib/native/NativeVerticalOCR.js +0 -2
  128. package/lib/native/RnExecutorchModules.d.ts +0 -7
  129. package/lib/native/RnExecutorchModules.js +0 -18
  130. package/lib/tsconfig.tsbuildinfo +0 -1
  131. package/lib/types/common.d.ts +0 -32
  132. package/lib/types/common.js +0 -25
  133. package/lib/types/imageSegmentation.js +0 -26
  134. package/lib/types/llm.d.ts +0 -46
  135. package/lib/types/llm.js +0 -9
  136. package/lib/types/objectDetection.js +0 -94
  137. package/lib/types/ocr.js +0 -1
  138. package/lib/types/stt.d.ts +0 -94
  139. package/lib/types/stt.js +0 -85
  140. package/lib/typescript/utils/SpeechToTextModule/ASR.d.ts +0 -27
  141. package/lib/typescript/utils/SpeechToTextModule/ASR.d.ts.map +0 -1
  142. package/lib/typescript/utils/SpeechToTextModule/OnlineProcessor.d.ts +0 -23
  143. package/lib/typescript/utils/SpeechToTextModule/OnlineProcessor.d.ts.map +0 -1
  144. package/lib/typescript/utils/SpeechToTextModule/hypothesisBuffer.d.ts +0 -13
  145. package/lib/typescript/utils/SpeechToTextModule/hypothesisBuffer.d.ts.map +0 -1
  146. package/lib/typescript/utils/stt.d.ts +0 -2
  147. package/lib/typescript/utils/stt.d.ts.map +0 -1
  148. package/lib/utils/ResourceFetcher.d.ts +0 -24
  149. package/lib/utils/ResourceFetcher.js +0 -305
  150. package/lib/utils/ResourceFetcherUtils.d.ts +0 -54
  151. package/lib/utils/ResourceFetcherUtils.js +0 -127
  152. package/lib/utils/llm.d.ts +0 -6
  153. package/lib/utils/llm.js +0 -72
  154. package/lib/utils/stt.js +0 -21
  155. package/src/utils/SpeechToTextModule/ASR.ts +0 -303
  156. package/src/utils/SpeechToTextModule/OnlineProcessor.ts +0 -87
  157. package/src/utils/SpeechToTextModule/hypothesisBuffer.ts +0 -79
  158. package/src/utils/stt.ts +0 -28
package/lib/utils/ResourceFetcherUtils.js DELETED
@@ -1,127 +0,0 @@
1
- /**
2
- * @internal
3
- */
4
- import { getInfoAsync, makeDirectoryAsync, } from 'expo-file-system';
5
- import { RNEDirectory } from '../constants/directories';
6
- import { Asset } from 'expo-asset';
7
- import { Logger } from '../common/Logger';
8
/**
 * Helpers used by the resource fetcher: classifying model sources, probing remote file
 * sizes, scaling per-file download progress, and small filesystem utilities.
 *
 * This is compiled TypeScript namespace output; the numeric literals annotated with
 * `SourceType.*` comments are inlined `const enum` members.
 */
export var ResourceFetcherUtils;
(function (ResourceFetcherUtils) {
  /**
   * Classifies a resource source.
   *
   * Objects are OBJECT (0). Numbers (bundled assets) are RELEASE_MODE_FILE (2) when their
   * resolved asset URI has no scheme, DEV_MODE_FILE (3) otherwise. Strings are
   * LOCAL_FILE (1) when they start with `file://`, REMOTE_FILE (4) otherwise.
   */
  function getType(source) {
    if (typeof source === 'object') {
      return 0 /* SourceType.OBJECT */;
    }
    else if (typeof source === 'number') {
      const uri = Asset.fromModule(source).uri;
      if (!uri.includes('://')) {
        return 2 /* SourceType.RELEASE_MODE_FILE */;
      }
      return 3 /* SourceType.DEV_MODE_FILE */;
    }
    else {
      // typeof source == 'string'
      if (source.startsWith('file://')) {
        return 1 /* SourceType.LOCAL_FILE */;
      }
      return 4 /* SourceType.REMOTE_FILE */;
    }
  }
  ResourceFetcherUtils.getType = getType;

  /**
   * Sends a HEAD request and returns the advertised `content-length`.
   * Returns 0 (after logging a warning) when the request throws, the response is not OK,
   * or the header is missing or not a number.
   */
  async function fetchContentLength(source) {
    try {
      const response = await fetch(source, { method: 'HEAD' });
      if (!response.ok) {
        Logger.warn(`Failed to fetch HEAD for ${source}: ${response.status}`);
        return 0;
      }
      const contentLength = response.headers.get('content-length');
      if (!contentLength) {
        Logger.warn(`No content-length header for ${source}`);
        return 0;
      }
      const parsed = parseInt(contentLength, 10);
      // A malformed header must not poison the running total with NaN.
      return Number.isNaN(parsed) ? 0 : parsed;
    }
    catch (error) {
      Logger.warn(`Error fetching HEAD for ${source}:`, error);
      return 0;
    }
  }

  /**
   * Determines the size of every source and its byte offset within the combined download.
   *
   * Only remote string sources are probed (via HEAD); every other source contributes
   * length 0. Each source always gets an entry, so a failed size probe no longer drops it
   * from `results`. `previousFilesTotalLength` is the summed length of all preceding
   * sources, including when the current source is not remote.
   *
   * @returns `{ results, totalLength }`, where each result is
   *   `{ source, type, length, previousFilesTotalLength }`.
   */
  async function getFilesSizes(sources) {
    const results = [];
    let totalLength = 0;
    for (const source of sources) {
      const type = await ResourceFetcherUtils.getType(source);
      const previousFilesTotalLength = totalLength;
      let length = 0;
      if (type === 4 /* SourceType.REMOTE_FILE */ && typeof source === 'string') {
        length = await fetchContentLength(source);
      }
      totalLength += length;
      results.push({ source, type, length, previousFilesTotalLength });
    }
    return { results, totalLength };
  }
  ResourceFetcherUtils.getFilesSizes = getFilesSizes;

  /** Strips a leading `file://` scheme, leaving other URIs untouched. */
  function removeFilePrefix(uri) {
    return uri.startsWith('file://') ? uri.slice(7) : uri;
  }
  ResourceFetcherUtils.removeFilePrefix = removeFilePrefix;

  /**
   * 32-bit rolling string hash (`h = h * 31 + charCode`, truncated to int32), returned as
   * an unsigned decimal string.
   */
  function hashObject(jsonString) {
    let hash = 0;
    for (let i = 0; i < jsonString.length; i++) {
      // eslint-disable-next-line no-bitwise
      hash = (hash << 5) - hash + jsonString.charCodeAt(i);
      // eslint-disable-next-line no-bitwise
      hash |= 0;
    }
    // eslint-disable-next-line no-bitwise
    return (hash >>> 0).toString();
  }
  ResourceFetcherUtils.hashObject = hashObject;

  /**
   * Builds a per-file progress callback that reports overall progress across all files.
   *
   * A file's 0..1 progress is scaled by its share of `totalLength` and offset by the bytes of
   * the files before it. Completion of the last file reports exactly 1, and a zero
   * `totalLength` reports 0.
   */
  function calculateDownloadProgress(totalLength, previousFilesTotalLength, currentFileLength, setProgress) {
    return (progress) => {
      if (progress === 1 &&
        previousFilesTotalLength === totalLength - currentFileLength) {
        setProgress(1);
        return;
      }
      // Avoid division by zero
      if (totalLength === 0) {
        setProgress(0);
        return;
      }
      const baseProgress = previousFilesTotalLength / totalLength;
      const scaledProgress = progress * (currentFileLength / totalLength);
      setProgress(baseProgress + scaledProgress);
    };
  }
  ResourceFetcherUtils.calculateDownloadProgress = calculateDownloadProgress;

  /*
   * Increments the Hugging Face download counter if the URI points to a Software Mansion Hugging Face repo.
   * More information: https://huggingface.co/docs/hub/models-download-stats
   */
  async function triggerHuggingFaceDownloadCounter(uri) {
    const url = new URL(uri);
    if (url.host === 'huggingface.co' &&
      url.pathname.startsWith('/software-mansion/')) {
      const baseUrl = `${url.protocol}//${url.host}${url.pathname.split('resolve')[0]}`;
      // Fire-and-forget and best-effort: swallow network failures instead of leaking an
      // unhandled promise rejection.
      fetch(`${baseUrl}resolve/main/config.json`, { method: 'HEAD' }).catch(() => { });
    }
  }
  ResourceFetcherUtils.triggerHuggingFaceDownloadCounter = triggerHuggingFaceDownloadCounter;

  /** Creates `RNEDirectory` (with intermediate directories) when it does not exist yet. */
  async function createDirectoryIfNoExists() {
    if (!(await checkFileExists(RNEDirectory))) {
      await makeDirectoryAsync(RNEDirectory, { intermediates: true });
    }
  }
  ResourceFetcherUtils.createDirectoryIfNoExists = createDirectoryIfNoExists;

  /** Resolves to `true` when `getInfoAsync` reports that `fileUri` exists. */
  async function checkFileExists(fileUri) {
    const fileInfo = await getInfoAsync(fileUri);
    return fileInfo.exists;
  }
  ResourceFetcherUtils.checkFileExists = checkFileExists;

  /**
   * Turns a URI into a flat file name.
   * Drops the http(s) scheme and any `#fragment`, then replaces every character outside
   * `[a-zA-Z0-9._-]` with `_`.
   */
  function getFilenameFromUri(uri) {
    let cleanUri = uri.replace(/^https?:\/\//, '');
    cleanUri = cleanUri.split('#')?.[0] ?? cleanUri;
    return cleanUri.replace(/[^a-zA-Z0-9._-]/g, '_');
  }
  ResourceFetcherUtils.getFilenameFromUri = getFilenameFromUri;
})(ResourceFetcherUtils || (ResourceFetcherUtils = {}));
package/lib/utils/llm.d.ts DELETED
@@ -1,6 +0,0 @@
1
- import { ToolCall } from '../types/llm';
2
- import { Schema } from 'jsonschema';
3
- import * as zCore from 'zod/v4/core';
4
/**
 * Extracts tool calls from raw model output.
 * Parses the span from the first `[` to the last `]` as JSON and keeps entries with a
 * string `name` and an object `arguments`. Returns `[]` when nothing can be parsed.
 */
export declare const parseToolCall: (message: string) => Array<ToolCall>;
/**
 * Builds the structured-output prompt from a Zod schema (converted to JSON Schema) or a
 * plain JSON Schema object.
 */
export declare const getStructuredOutputPrompt: <TSchema extends zCore.$ZodType>(responseSchema: TSchema | Schema) => string;
/**
 * Extracts the JSON-looking span of `output`, repairs it, parses it and validates it
 * against `responseSchema`. The declared return type is only exact for Zod schemas; with a
 * JSON Schema the parsed value is returned as-is.
 */
export declare const fixAndValidateStructuredOutput: <TSchema extends zCore.$ZodType>(output: string, responseSchema: TSchema | Schema) => zCore.output<TSchema>;
package/lib/utils/llm.js DELETED
@@ -1,72 +0,0 @@
1
- import * as z from 'zod/v4';
2
- import { Validator } from 'jsonschema';
3
- import { jsonrepair } from 'jsonrepair';
4
- import { DEFAULT_STRUCTURED_OUTPUT_PROMPT } from '../constants/llmDefaults';
5
- import * as zCore from 'zod/v4/core';
6
- import { Logger } from '../common/Logger';
7
/**
 * Extracts tool calls from raw model output.
 *
 * Takes the span from the first `[` to the last `]`, parses it as JSON and keeps every
 * element that is an object with a string `name` and a non-null object `arguments`.
 *
 * @param message Raw LLM output.
 * @returns `{ toolName, arguments }` entries. Returns an empty array, after logging the
 *   error, when no array can be found or parsed.
 */
export const parseToolCall = (message) => {
  try {
    // `[\s\S]*` matches exactly what `(.|\s)*` did, but `.` and `\s` both match a space,
    // so the old alternation backtracked exponentially on long whitespace runs.
    const unparsedToolCalls = message.match(/\[[\s\S]*\]/);
    if (!unparsedToolCalls) {
      throw Error('Regex did not match array.');
    }
    const parsedMessage = JSON.parse(unparsedToolCalls[0]);
    const results = [];
    for (const tool of parsedMessage) {
      // `'name' in x` throws for primitives/null, which would otherwise discard every
      // valid tool call in the array; skip such elements instead.
      if (tool === null || typeof tool !== 'object') {
        continue;
      }
      if ('name' in tool &&
        typeof tool.name === 'string' &&
        'arguments' in tool &&
        tool.arguments !== null &&
        typeof tool.arguments === 'object') {
        results.push({
          toolName: tool.name,
          arguments: tool.arguments,
        });
      }
    }
    return results;
  }
  catch (e) {
    Logger.error(e);
    return [];
  }
};
34
/**
 * Shallow copy of `obj` without the listed own enumerable keys.
 */
const filterObjectKeys = (obj, keysToRemove) =>
  Object.fromEntries(
    Object.entries(obj).filter(([name]) => !keysToRemove.includes(name))
  );
39
/**
 * Builds the structured-output prompt for `responseSchema`.
 *
 * A Zod schema is converted with `z.toJSONSchema`, and its `$schema` and
 * `additionalProperties` keys are removed. Any other value is treated as a JSON Schema
 * and used verbatim. The result is serialized and passed to
 * `DEFAULT_STRUCTURED_OUTPUT_PROMPT`.
 */
export const getStructuredOutputPrompt = (responseSchema) => {
  let schemaObject = responseSchema;
  if (responseSchema instanceof zCore.$ZodType) {
    const jsonSchema = z.toJSONSchema(responseSchema);
    schemaObject = filterObjectKeys(jsonSchema, ['$schema', 'additionalProperties']);
  }
  return DEFAULT_STRUCTURED_OUTPUT_PROMPT(JSON.stringify(schemaObject));
};
49
/**
 * Returns the part of `text` that most likely holds a JSON payload.
 *
 * The span runs from the first `{` or `[` to the last occurrence of the corresponding
 * closing character (`}` or `]`). If that closing character does not appear after the
 * opening one, the rest of the text is returned, so that `jsonrepair` can still try to
 * close a truncated payload.
 *
 * @throws Error when `text` contains neither `{` nor `[`.
 */
const extractBetweenBrackets = (text) => {
  // Only `{` and `[` belong in this class; the previous `/[\\{\\[]/` also matched `\`.
  const startIndex = text.search(/[{\[]/);
  if (startIndex === -1)
    throw Error("Couldn't find JSON in text");
  const closingBracket = text[startIndex] === '{' ? '}' : ']';
  const endIndex = text.lastIndexOf(closingBracket);
  return endIndex > startIndex
    ? text.slice(startIndex, endIndex + 1)
    : text.slice(startIndex);
};
57
/**
 * Extracts the JSON-looking span of `output`, repairs it with `jsonrepair`, parses it and
 * validates it against `responseSchema`.
 *
 * Zod schemas are applied with `z.parse`, and that call's result is returned. Any other
 * schema is checked with a jsonschema `Validator` (`throwAll: true`), and the parsed value
 * is returned unchanged. Nothing here catches errors: extraction, repair, parse and
 * validation failures all propagate to the caller.
 *
 * Note: the published typings declare the return as the Zod output type even for JSON
 * Schema inputs.
 */
export const fixAndValidateStructuredOutput = (output, responseSchema) => {
  const outputJSON = JSON.parse(jsonrepair(extractBetweenBrackets(output)));
  if (!(responseSchema instanceof zCore.$ZodType)) {
    new Validator().validate(outputJSON, responseSchema, {
      throwAll: true,
    });
    return outputJSON;
  }
  return z.parse(responseSchema, outputJSON);
};
package/lib/utils/stt.js DELETED
@@ -1,21 +0,0 @@
1
/**
 * Finds the offset in `seq1` at which `seq2` matches as a prefix for the longest run.
 *
 * At each start offset, elements are compared pairwise until either sequence runs out or a
 * mismatch occurs after `hammingDistThreshold` mismatches have already been tolerated.
 * On ties the later offset wins.
 *
 * @returns The best start index in `seq1` (0 when `seq1` is empty).
 */
export const longCommonInfPref = (seq1, seq2, hammingDistThreshold) => {
  let bestStart = 0;
  let bestLength = 0;
  for (let start = 0; start < seq1.length; start++) {
    let matched = 0;
    let mismatches = 0;
    for (; matched < seq2.length && start + matched < seq1.length; matched++) {
      if (seq1[start + matched] !== seq2[matched]) {
        if (mismatches >= hammingDistThreshold) {
          break;
        }
        mismatches++;
      }
    }
    if (matched >= bestLength) {
      bestLength = matched;
      bestStart = start;
    }
  }
  return bestStart;
};
package/src/utils/SpeechToTextModule/ASR.ts DELETED
@@ -1,303 +0,0 @@
1
- // NOTE: This will be implemented in C++
2
-
3
- import { TokenizerModule } from '../../modules/natural_language_processing/TokenizerModule';
4
- import {
5
- DecodingOptions,
6
- Segment,
7
- SpeechToTextModelConfig,
8
- WordObject,
9
- WordTuple,
10
- } from '../../types/stt';
11
- import { ResourceFetcher } from '../ResourceFetcher';
12
-
13
/**
 * Whisper speech-to-text driven from JavaScript.
 *
 * The encoder/decoder run in the native module created by `global.loadSpeechToText`. This
 * class builds the decoder prompt, picks tokens (greedy or sampled), retries at higher
 * temperatures when confidence is low, splits output into segments at timestamp tokens,
 * and estimates word timestamps by spreading each segment's duration over its characters.
 */
export class ASR {
  private nativeModule: any;
  private tokenizerModule: TokenizerModule = new TokenizerModule();

  private timePrecision: number = 0.02; // Whisper timestamp precision
  private maxDecodeLength: number = 128;
  private chunkSize: number = 30; // 30 seconds
  private minChunkSamples: number = 1 * 16000; // 1 second
  private samplingRate: number = 16000;

  private startOfTranscriptToken!: number;
  private endOfTextToken!: number;
  private timestampBeginToken!: number;

  /**
   * Loads the tokenizer and downloads the encoder/decoder in parallel. Then creates the
   * native model and caches the special-token ids used while decoding.
   *
   * @throws Error('Download interrupted.') when either model file was not fetched.
   */
  public async load(
    model: SpeechToTextModelConfig,
    onDownloadProgressCallback: (progress: number) => void
  ) {
    const [, encoderDecoderResults] = await Promise.all([
      this.tokenizerModule.load(model),
      ResourceFetcher.fetch(
        onDownloadProgressCallback,
        model.encoderSource,
        model.decoderSource
      ),
    ]);
    const encoderSource = encoderDecoderResults?.[0];
    const decoderSource = encoderDecoderResults?.[1];
    if (!encoderSource || !decoderSource) {
      throw new Error('Download interrupted.');
    }
    this.nativeModule = await global.loadSpeechToText(
      encoderSource,
      decoderSource,
      'whisper'
    );

    this.startOfTranscriptToken = await this.tokenizerModule.tokenToId(
      '<|startoftranscript|>'
    );
    this.endOfTextToken = await this.tokenizerModule.tokenToId('<|endoftext|>');
    this.timestampBeginToken = await this.tokenizerModule.tokenToId('<|0.00|>');
  }

  /**
   * Decoder prompt: start-of-transcript, then language and `<|transcribe|>` tokens when a
   * language is given, then the first timestamp token.
   */
  private async getInitialSequence(
    options: DecodingOptions
  ): Promise<number[]> {
    const initialSequence: number[] = [this.startOfTranscriptToken];
    if (options.language) {
      const languageToken = await this.tokenizerModule.tokenToId(
        `<|${options.language}|>`
      );
      const taskToken = await this.tokenizerModule.tokenToId('<|transcribe|>');
      initialSequence.push(languageToken);
      initialSequence.push(taskToken);
    }
    initialSequence.push(this.timestampBeginToken);
    return initialSequence;
  }

  /**
   * Encodes `audio`, then decodes token by token until `<|endoftext|>` or
   * `maxDecodeLength` total tokens.
   *
   * @returns The generated tokens (prompt excluded) and, index-aligned with them, the
   *   probability each token had when it was chosen.
   */
  private async generate(
    audio: number[],
    temperature: number,
    options: DecodingOptions
  ): Promise<{
    sequencesIds: number[];
    scores: number[];
  }> {
    await this.encode(new Float32Array(audio));
    const initialSequence = await this.getInitialSequence(options);
    const sequencesIds = [...initialSequence];
    // One entry per generated token; prompt tokens are never scored.
    const scores: number[] = [];

    while (sequencesIds.length <= this.maxDecodeLength) {
      const probs = this.softmaxWithTemperature(
        Array.from(await this.decode(sequencesIds)),
        temperature === 0 ? 1 : temperature
      );
      const nextTokenId =
        temperature === 0
          ? this.argmax(probs)
          : this.sampleFromDistribution(probs);
      sequencesIds.push(nextTokenId);
      scores.push(probs[nextTokenId]!);
      if (nextTokenId === this.endOfTextToken) {
        break;
      }
    }

    return {
      sequencesIds: sequencesIds.slice(initialSequence.length),
      // Already aligned with the generated tokens: there is no prompt offset to strip.
      scores,
    };
  }

  /**
   * Index of the largest value (first one on ties). A loop is used instead of spreading a
   * vocabulary-sized array into `Math.max`.
   */
  private argmax(values: number[]): number {
    let best = 0;
    for (let i = 1; i < values.length; i++) {
      if (values[i]! > values[best]!) {
        best = i;
      }
    }
    return best;
  }

  /** Numerically stable softmax (max-subtracted) with temperature scaling. */
  private softmaxWithTemperature(logits: number[], temperature = 1.0) {
    const max = logits.reduce((acc, logit) => Math.max(acc, logit), -Infinity);
    const exps = logits.map((logit) => Math.exp((logit - max) / temperature));
    const sum = exps.reduce((a, b) => a + b, 0);
    return exps.map((exp) => exp / sum);
  }

  /** Inverse-CDF sampling; returns the last index if rounding leaves `r` unreached. */
  private sampleFromDistribution(probs: number[]): number {
    const r = Math.random();
    let cumulative = 0;
    for (let i = 0; i < probs.length; i++) {
      cumulative += probs[i]!;
      if (r < cumulative) {
        return i;
      }
    }
    return probs.length - 1;
  }

  /**
   * Decodes at increasing temperatures until the mean log-probability of the generated
   * tokens is at least -1. If no temperature reaches that, the last (highest-temperature)
   * attempt is used rather than an empty result.
   */
  private async generateWithFallback(
    audio: number[],
    options: DecodingOptions
  ) {
    const temperatures = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
    let generatedTokens: number[] = [];

    for (const temperature of temperatures) {
      const { sequencesIds: tokens, scores } = await this.generate(
        audio,
        temperature,
        options
      );
      generatedTokens = tokens;

      const cumLogProb = scores.reduce(
        (acc, score) => acc + Math.log(score),
        0
      );
      const avgLogProb = cumLogProb / tokens.length;
      if (avgLogProb >= -1.0) {
        break;
      }
    }

    return this.calculateWordLevelTimestamps(generatedTokens, audio);
  }

  /**
   * Splits generated tokens into segments and assigns word timestamps (chunk-relative
   * seconds).
   *
   * Two consecutive timestamp tokens close a segment (end = the first, next start = the
   * second). The trailing segment ends at the last timestamp the model emitted.
   * `<|endoftext|>` is excluded from the decoded text. When the model's end timestamp
   * overshoots the audio length, all word times are scaled down to fit.
   */
  private async calculateWordLevelTimestamps(
    generatedTokens: number[],
    audio: number[]
  ): Promise<Segment[]> {
    const isTimestamp = (token: number) => token >= this.timestampBeginToken;
    const segments: Segment[] = [];

    let tokens: number[] = [];
    let prevTimestamp = this.timestampBeginToken;
    let lastTimestamp = this.timestampBeginToken;
    for (let i = 0; i < generatedTokens.length; i++) {
      const token = generatedTokens[i]!;
      if (isTimestamp(token)) {
        lastTimestamp = token;
      } else if (token !== this.endOfTextToken) {
        tokens.push(token);
      }

      if (i > 0 && isTimestamp(generatedTokens[i - 1]!) && isTimestamp(token)) {
        const wordObjects = await this.estimateWordTimestampsLinear(
          tokens,
          prevTimestamp,
          generatedTokens[i - 1]!
        );
        segments.push({
          words: wordObjects,
        });
        tokens = [];
        prevTimestamp = token;
      }
    }

    // Does not assume the closing timestamp sits right before <|endoftext|>; it does not
    // when decoding stops at maxDecodeLength.
    const end = lastTimestamp;
    const wordObjects = await this.estimateWordTimestampsLinear(
      tokens,
      prevTimestamp,
      end
    );
    segments.push({
      words: wordObjects,
    });

    const scalingFactor =
      audio.length /
      this.samplingRate /
      ((end - this.timestampBeginToken) * this.timePrecision);
    if (scalingFactor < 1) {
      for (const segment of segments) {
        for (const word of segment.words) {
          word.start *= scalingFactor;
          word.end *= scalingFactor;
        }
      }
    }

    return segments;
  }

  /**
   * Spreads the span between two timestamp tokens evenly over the characters of the
   * decoded text. Words are split on spaces and each keeps a leading space.
   */
  private async estimateWordTimestampsLinear(
    tokens: number[],
    timestampStart: number,
    timestampEnd: number
  ): Promise<WordObject[]> {
    const duration = (timestampEnd - timestampStart) * this.timePrecision;
    const segmentText = (
      (await this.tokenizerModule.decode(tokens)) as string
    ).trim();

    const words = segmentText.split(' ').map((w) => ` ${w}`);
    const numOfCharacters = words.reduce(
      (acc: number, word: string) => acc + word.length,
      0
    );

    const timePerCharacter = duration / numOfCharacters;

    const wordObjects: WordObject[] = [];
    const startTimeOffset =
      (timestampStart - this.timestampBeginToken) * this.timePrecision;

    let prevCharNum = 0;
    for (const word of words) {
      const start = startTimeOffset + prevCharNum * timePerCharacter;
      const end = start + timePerCharacter * word.length;
      wordObjects.push({ word, start, end });
      prevCharNum += word.length;
    }

    return wordObjects;
  }

  /**
   * Transcribes `audio` in windows of `chunkSize` seconds. Each window starts where the
   * previous transcription ended, and word times are shifted to absolute seconds.
   * A trailing window shorter than `minChunkSamples` is ignored.
   */
  public async transcribe(
    audio: number[],
    options: DecodingOptions
  ): Promise<Segment[]> {
    let seek = 0;
    const allSegments: Segment[] = [];

    while (seek * this.samplingRate < audio.length) {
      const chunk = audio.slice(
        seek * this.samplingRate,
        (seek + this.chunkSize) * this.samplingRate
      );
      if (chunk.length < this.minChunkSamples) {
        return allSegments;
      }
      const segments = await this.generateWithFallback(chunk, options);
      for (const segment of segments) {
        for (const word of segment.words) {
          word.start += seek;
          word.end += seek;
        }
      }
      allSegments.push(...segments);
      const lastTimeStamp = segments.at(-1)!.words.at(-1)!.end;
      // If the transcription did not move forward (zero-length output, or NaN from a
      // degenerate decode), skip a whole window. Otherwise the same audio would be
      // re-decoded forever.
      seek =
        Number.isFinite(lastTimeStamp) && lastTimeStamp > seek
          ? lastTimeStamp
          : seek + this.chunkSize;
    }

    return allSegments;
  }

  /** Flattens segments into `[start, end, word]` tuples. */
  public tsWords(segments: Segment[]): WordTuple[] {
    const o: WordTuple[] = [];
    for (const segment of segments) {
      for (const word of segment.words) {
        o.push([word.start, word.end, word.word]);
      }
    }
    return o;
  }

  /** End time of the last word of each segment. */
  public segmentsEndTs(res: Segment[]) {
    return res.map((segment) => segment.words.at(-1)!.end);
  }

  public async encode(waveform: Float32Array): Promise<void> {
    await this.nativeModule.encode(waveform);
  }

  public async decode(tokens: number[]): Promise<Float32Array> {
    return new Float32Array(await this.nativeModule.decode(tokens));
  }
}
package/src/utils/SpeechToTextModule/OnlineProcessor.ts DELETED
@@ -1,87 +0,0 @@
1
- // NOTE: This will be implemented in C++
2
-
3
- import { WordTuple, DecodingOptions, Segment } from '../../types/stt';
4
- import { ASR } from './ASR';
5
- import { HypothesisBuffer } from './hypothesisBuffer';
6
-
7
/**
 * Streaming wrapper around `ASR`.
 *
 * Accumulates audio and re-transcribes the whole buffer on every `processIter`. Words
 * confirmed by the hypothesis buffer are committed. Once the buffer holds more than 15 s
 * of audio, it may be trimmed at a completed segment boundary.
 */
export class OnlineASRProcessor {
  private asr: ASR;

  private samplingRate: number = 16000;
  public audioBuffer: number[] = [];
  private transcriptBuffer: HypothesisBuffer = new HypothesisBuffer();
  private bufferTimeOffset: number = 0;
  private committed: WordTuple[] = [];

  constructor(asr: ASR) {
    this.asr = asr;
  }

  /** Appends samples to the pending audio buffer (in place). */
  public insertAudioChunk(audio: number[]) {
    // Push element-wise: `push(...audio)` passes every sample as a call argument and can
    // exceed the engine's argument-count limit for long chunks.
    for (const sample of audio) {
      this.audioBuffer.push(sample);
    }
  }

  /**
   * Transcribes the buffered audio and commits newly confirmed words.
   *
   * @returns `committed`: text committed in this iteration. `nonCommitted`: the
   *   still-tentative tail.
   */
  public async processIter(options: DecodingOptions) {
    const res = await this.asr.transcribe(this.audioBuffer, options);
    const tsw = this.asr.tsWords(res);
    this.transcriptBuffer.insert(tsw, this.bufferTimeOffset);
    const o = this.transcriptBuffer.flush();
    this.committed.push(...o);

    const maxBufferSeconds = 15;
    if (this.audioBuffer.length / this.samplingRate > maxBufferSeconds) {
      this.chunkCompletedSegment(res);
    }

    const committed = this.toFlush(o)[2];
    const nonCommitted = this.transcriptBuffer
      .complete()
      .map((x) => x[2])
      .join('');
    return { committed, nonCommitted };
  }

  /**
   * Trims the buffer at the latest segment end (excluding the last segment) that does not
   * go past the last committed word.
   */
  private chunkCompletedSegment(res: Segment[]) {
    if (this.committed.length === 0) {
      return;
    }

    const ends = this.asr.segmentsEndTs(res);
    const t = this.committed.at(-1)![1];

    if (ends.length > 1) {
      let e = ends.at(-2)! + this.bufferTimeOffset;
      while (ends.length > 2 && e > t) {
        ends.pop();
        e = ends.at(-2)! + this.bufferTimeOffset;
      }

      if (e <= t) {
        this.chunkAt(e);
      }
    }
  }

  /** Drops audio and committed words up to absolute time `time` (seconds). */
  private chunkAt(time: number) {
    this.transcriptBuffer.popCommitted(time);
    const cutSeconds = time - this.bufferTimeOffset;
    this.audioBuffer = this.audioBuffer.slice(
      Math.floor(cutSeconds * this.samplingRate)
    );
    this.bufferTimeOffset = time;
  }

  /** Returns the remaining tentative text as committed and advances the time offset. */
  public async finish() {
    const o = this.transcriptBuffer.complete();
    const f = this.toFlush(o);
    this.bufferTimeOffset += this.audioBuffer.length / this.samplingRate;
    return { committed: f[2] };
  }

  /**
   * Collapses word tuples into `[start, end, text]`, with `null` times when empty.
   * ASR words already carry a leading space, so they are joined with no separator. This
   * matches how the non-committed text is built and avoids doubled spaces.
   */
  private toFlush(words: WordTuple[]): [number | null, number | null, string] {
    const t = words.map((s) => s[2]).join('');
    const b = words.length === 0 ? null : words[0]![0];
    const e = words.length === 0 ? null : words.at(-1)![1];
    return [b, e, t];
  }
}
- }
package/src/utils/SpeechToTextModule/hypothesisBuffer.ts DELETED
@@ -1,79 +0,0 @@
1
- // NOTE: This will be implemented in C++
2
-
3
- import { WordTuple } from '../../types/stt';
4
-
5
/**
 * Local-agreement buffer for streaming transcription.
 *
 * Each new hypothesis (word tuples `[start, end, text]`) is compared with the previous
 * tentative one. The leading words on which both agree are committed; the rest becomes
 * the new tentative buffer.
 */
export class HypothesisBuffer {
  private committedInBuffer: WordTuple[] = [];
  private buffer: WordTuple[] = [];
  private new: WordTuple[] = [];

  private lastCommittedTime: number = 0;
  public lastCommittedWord: string | null = null;

  /**
   * Stages a new hypothesis shifted by `offset` seconds.
   *
   * Keeps only words starting after `lastCommittedTime - 0.5`. If the first kept word
   * starts within 1 s of the last commit, up to 5 leading words that repeat the tail of
   * the committed words are removed.
   */
  public insert(newWords: WordTuple[], offset: number) {
    const minStart = this.lastCommittedTime - 0.5;
    const staged: WordTuple[] = [];
    for (const [start, end, text] of newWords) {
      const shifted: WordTuple = [start + offset, end + offset, text];
      if (shifted[0] > minStart) {
        staged.push(shifted);
      }
    }
    this.new = staged;

    if (this.new.length === 0 || this.committedInBuffer.length === 0) {
      return;
    }
    if (!(Math.abs(this.new[0]![0] - this.lastCommittedTime) < 1)) {
      return;
    }

    const joinTexts = (words: WordTuple[]) => words.map((w) => w[2]).join(' ');
    const limit = Math.min(this.committedInBuffer.length, this.new.length, 5);
    for (let n = 1; n <= limit; n++) {
      const committedTail = joinTexts(this.committedInBuffer.slice(-n));
      const stagedHead = joinTexts(this.new.slice(0, n));
      if (committedTail === stagedHead) {
        this.new.splice(0, n);
        break;
      }
    }
  }

  /**
   * Commits the longest common prefix (by text) of the staged hypothesis and the
   * tentative buffer. The remaining staged words become the new tentative buffer.
   *
   * @returns The words committed by this call.
   */
  public flush(): WordTuple[] {
    let agreed = 0;
    while (
      agreed < this.new.length &&
      agreed < this.buffer.length &&
      this.new[agreed]![2] === this.buffer[agreed]![2]
    ) {
      agreed++;
    }

    const commit = this.new.splice(0, agreed);
    this.buffer.splice(0, agreed);
    if (commit.length > 0) {
      const [, lastEnd, lastText] = commit[commit.length - 1]!;
      this.lastCommittedWord = lastText;
      this.lastCommittedTime = lastEnd;
    }

    this.buffer = this.new;
    this.new = [];
    this.committedInBuffer.push(...commit);
    return commit;
  }

  /** Forgets committed words that end at or before `time`. */
  public popCommitted(time: number) {
    this.committedInBuffer = this.committedInBuffer.filter(
      (word) => word[1] > time
    );
  }

  /** The current tentative (not yet committed) words. */
  public complete(): WordTuple[] {
    return this.buffer;
  }
}