react-native-executorch 0.3.3 → 0.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (343)
  1. package/README.md +30 -13
  2. package/android/build.gradle +1 -1
  3. package/android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt +1 -2
  4. package/android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt +58 -0
  5. package/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt +13 -49
  6. package/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +37 -0
  7. package/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt +1 -1
  8. package/android/src/main/java/com/swmansion/rnexecutorch/TextEmbeddings.kt +51 -0
  9. package/android/src/main/java/com/swmansion/rnexecutorch/Tokenizer.kt +86 -0
  10. package/android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt +3 -4
  11. package/android/src/main/java/com/swmansion/rnexecutorch/models/TextEmbeddings/TextEmbeddingsModel.kt +48 -0
  12. package/android/src/main/java/com/swmansion/rnexecutorch/models/TextEmbeddings/TextEmbeddingsUtils.kt +37 -0
  13. package/android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt +1 -0
  14. package/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt +26 -0
  15. package/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt +142 -0
  16. package/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt +3 -0
  17. package/android/src/main/java/com/swmansion/rnexecutorch/models/{StyleTransferModel.kt → styleTransfer/StyleTransferModel.kt} +2 -1
  18. package/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt +0 -8
  19. package/android/src/main/java/com/swmansion/rnexecutorch/{models/classification/Utils.kt → utils/Numerical.kt} +1 -1
  20. package/ios/ExecutorchLib.xcframework/Info.plist +4 -4
  21. package/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  22. package/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Headers/HuggingFaceTokenizer.h +14 -0
  23. package/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Headers/LLaMARunner.h +1 -23
  24. package/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  25. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  26. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Headers/HuggingFaceTokenizer.h +14 -0
  27. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Headers/LLaMARunner.h +1 -23
  28. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  29. package/ios/RnExecutorch/Classification.mm +15 -18
  30. package/ios/RnExecutorch/ETModule.mm +6 -5
  31. package/ios/RnExecutorch/ImageSegmentation.h +5 -0
  32. package/ios/RnExecutorch/ImageSegmentation.mm +60 -0
  33. package/ios/RnExecutorch/LLM.mm +12 -53
  34. package/ios/RnExecutorch/OCR.mm +39 -43
  35. package/ios/RnExecutorch/ObjectDetection.mm +20 -20
  36. package/ios/RnExecutorch/SpeechToText.mm +6 -7
  37. package/ios/RnExecutorch/StyleTransfer.mm +16 -19
  38. package/ios/RnExecutorch/TextEmbeddings.h +5 -0
  39. package/ios/RnExecutorch/TextEmbeddings.mm +62 -0
  40. package/ios/RnExecutorch/Tokenizer.h +5 -0
  41. package/ios/RnExecutorch/Tokenizer.mm +83 -0
  42. package/ios/RnExecutorch/VerticalOCR.mm +36 -36
  43. package/ios/RnExecutorch/models/BaseModel.h +2 -5
  44. package/ios/RnExecutorch/models/BaseModel.mm +5 -15
  45. package/ios/RnExecutorch/models/classification/ClassificationModel.mm +2 -3
  46. package/ios/RnExecutorch/models/classification/Constants.mm +0 -1
  47. package/ios/RnExecutorch/models/image_segmentation/Constants.h +4 -0
  48. package/ios/RnExecutorch/models/image_segmentation/Constants.mm +8 -0
  49. package/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.h +10 -0
  50. package/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm +146 -0
  51. package/ios/RnExecutorch/models/object_detection/SSDLiteLargeModel.mm +1 -2
  52. package/ios/RnExecutorch/models/ocr/Detector.h +0 -2
  53. package/ios/RnExecutorch/models/ocr/Detector.mm +2 -1
  54. package/ios/RnExecutorch/models/ocr/RecognitionHandler.h +5 -4
  55. package/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +9 -26
  56. package/ios/RnExecutorch/models/ocr/Recognizer.mm +1 -2
  57. package/ios/RnExecutorch/models/ocr/VerticalDetector.h +0 -2
  58. package/ios/RnExecutorch/models/ocr/VerticalDetector.mm +2 -1
  59. package/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm +0 -1
  60. package/ios/RnExecutorch/models/stt/Moonshine.mm +1 -6
  61. package/ios/RnExecutorch/models/stt/SpeechToTextBaseModel.mm +7 -11
  62. package/ios/RnExecutorch/models/stt/Whisper.mm +0 -5
  63. package/ios/RnExecutorch/models/{StyleTransferModel.h → style_transfer/StyleTransferModel.h} +1 -1
  64. package/ios/RnExecutorch/models/{StyleTransferModel.mm → style_transfer/StyleTransferModel.mm} +2 -3
  65. package/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsModel.h +15 -0
  66. package/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsModel.mm +45 -0
  67. package/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsUtils.h +8 -0
  68. package/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsUtils.mm +49 -0
  69. package/ios/RnExecutorch/utils/Conversions.h +15 -0
  70. package/ios/RnExecutorch/utils/ImageProcessor.h +0 -1
  71. package/ios/RnExecutorch/{models/classification/Utils.h → utils/Numerical.h} +0 -2
  72. package/ios/RnExecutorch/{models/classification/Utils.mm → utils/Numerical.mm} +0 -2
  73. package/ios/RnExecutorch/utils/ObjectDetectionUtils.mm +0 -2
  74. package/ios/RnExecutorch/utils/SFFT.mm +1 -1
  75. package/ios/RnExecutorch/utils/ScalarType.h +0 -2
  76. package/lib/module/Error.js +16 -2
  77. package/lib/module/Error.js.map +1 -1
  78. package/lib/module/constants/{llamaDefaults.js → llmDefaults.js} +7 -3
  79. package/lib/module/constants/llmDefaults.js.map +1 -0
  80. package/lib/module/constants/modelUrls.js +88 -27
  81. package/lib/module/constants/modelUrls.js.map +1 -1
  82. package/lib/module/constants/ocr/models.js +290 -0
  83. package/lib/module/constants/ocr/models.js.map +1 -0
  84. package/lib/module/constants/ocr/symbols.js +137 -2
  85. package/lib/module/constants/ocr/symbols.js.map +1 -1
  86. package/lib/module/constants/sttDefaults.js +50 -25
  87. package/lib/module/constants/sttDefaults.js.map +1 -1
  88. package/lib/module/controllers/LLMController.js +205 -0
  89. package/lib/module/controllers/LLMController.js.map +1 -0
  90. package/lib/module/controllers/OCRController.js +5 -10
  91. package/lib/module/controllers/OCRController.js.map +1 -1
  92. package/lib/module/controllers/SpeechToTextController.js +225 -122
  93. package/lib/module/controllers/SpeechToTextController.js.map +1 -1
  94. package/lib/module/controllers/VerticalOCRController.js +6 -10
  95. package/lib/module/controllers/VerticalOCRController.js.map +1 -1
  96. package/lib/module/hooks/computer_vision/useClassification.js +8 -23
  97. package/lib/module/hooks/computer_vision/useClassification.js.map +1 -1
  98. package/lib/module/hooks/computer_vision/useImageSegmentation.js +13 -0
  99. package/lib/module/hooks/computer_vision/useImageSegmentation.js.map +1 -0
  100. package/lib/module/hooks/computer_vision/useOCR.js +11 -6
  101. package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
  102. package/lib/module/hooks/computer_vision/useObjectDetection.js +8 -23
  103. package/lib/module/hooks/computer_vision/useObjectDetection.js.map +1 -1
  104. package/lib/module/hooks/computer_vision/useStyleTransfer.js +8 -23
  105. package/lib/module/hooks/computer_vision/useStyleTransfer.js.map +1 -1
  106. package/lib/module/hooks/computer_vision/useVerticalOCR.js +10 -7
  107. package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
  108. package/lib/module/hooks/general/useExecutorchModule.js +8 -36
  109. package/lib/module/hooks/general/useExecutorchModule.js.map +1 -1
  110. package/lib/module/hooks/natural_language_processing/useLLM.js +54 -63
  111. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  112. package/lib/module/hooks/natural_language_processing/useSpeechToText.js +15 -11
  113. package/lib/module/hooks/natural_language_processing/useSpeechToText.js.map +1 -1
  114. package/lib/module/hooks/natural_language_processing/useTextEmbeddings.js +14 -0
  115. package/lib/module/hooks/natural_language_processing/useTextEmbeddings.js.map +1 -0
  116. package/lib/module/hooks/natural_language_processing/useTokenizer.js +54 -0
  117. package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -0
  118. package/lib/module/hooks/useModule.js +18 -62
  119. package/lib/module/hooks/useModule.js.map +1 -1
  120. package/lib/module/index.js +16 -2
  121. package/lib/module/index.js.map +1 -1
  122. package/lib/module/modules/BaseModule.js +9 -10
  123. package/lib/module/modules/BaseModule.js.map +1 -1
  124. package/lib/module/modules/computer_vision/ClassificationModule.js +8 -5
  125. package/lib/module/modules/computer_vision/ClassificationModule.js.map +1 -1
  126. package/lib/module/modules/computer_vision/ImageSegmentationModule.js +28 -0
  127. package/lib/module/modules/computer_vision/ImageSegmentationModule.js.map +1 -0
  128. package/lib/module/modules/computer_vision/ObjectDetectionModule.js +8 -5
  129. package/lib/module/modules/computer_vision/ObjectDetectionModule.js.map +1 -1
  130. package/lib/module/modules/computer_vision/StyleTransferModule.js +8 -5
  131. package/lib/module/modules/computer_vision/StyleTransferModule.js.map +1 -1
  132. package/lib/module/modules/general/ExecutorchModule.js +8 -5
  133. package/lib/module/modules/general/ExecutorchModule.js.map +1 -1
  134. package/lib/module/modules/natural_language_processing/LLMModule.js +46 -27
  135. package/lib/module/modules/natural_language_processing/LLMModule.js.map +1 -1
  136. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +8 -5
  137. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
  138. package/lib/module/modules/natural_language_processing/TextEmbeddingsModule.js +14 -0
  139. package/lib/module/modules/natural_language_processing/TextEmbeddingsModule.js.map +1 -0
  140. package/lib/module/modules/natural_language_processing/TokenizerModule.js +26 -0
  141. package/lib/module/modules/natural_language_processing/TokenizerModule.js.map +1 -0
  142. package/lib/module/native/NativeClassification.js.map +1 -1
  143. package/lib/module/native/NativeImageSegmentation.js +5 -0
  144. package/lib/module/native/NativeImageSegmentation.js.map +1 -0
  145. package/lib/module/native/NativeLLM.js.map +1 -1
  146. package/lib/module/native/NativeTextEmbeddings.js +5 -0
  147. package/lib/module/native/NativeTextEmbeddings.js.map +1 -0
  148. package/lib/module/native/NativeTokenizer.js +5 -0
  149. package/lib/module/native/NativeTokenizer.js.map +1 -0
  150. package/lib/module/native/RnExecutorchModules.js +18 -113
  151. package/lib/module/native/RnExecutorchModules.js.map +1 -1
  152. package/lib/module/types/common.js.map +1 -1
  153. package/lib/module/types/imageSegmentation.js +29 -0
  154. package/lib/module/types/imageSegmentation.js.map +1 -0
  155. package/lib/module/types/llm.js +7 -0
  156. package/lib/module/types/llm.js.map +1 -0
  157. package/lib/module/types/{object_detection.js → objectDetection.js} +1 -1
  158. package/lib/module/types/objectDetection.js.map +1 -0
  159. package/lib/module/types/ocr.js +2 -0
  160. package/lib/module/types/stt.js +82 -0
  161. package/lib/module/types/stt.js.map +1 -0
  162. package/lib/module/utils/ResourceFetcher.js +156 -0
  163. package/lib/module/utils/ResourceFetcher.js.map +1 -0
  164. package/lib/module/utils/llm.js +25 -0
  165. package/lib/module/utils/llm.js.map +1 -0
  166. package/lib/module/utils/stt.js +22 -0
  167. package/lib/module/utils/stt.js.map +1 -0
  168. package/lib/typescript/Error.d.ts +4 -1
  169. package/lib/typescript/Error.d.ts.map +1 -1
  170. package/lib/typescript/constants/{llamaDefaults.d.ts → llmDefaults.d.ts} +5 -5
  171. package/lib/typescript/constants/llmDefaults.d.ts.map +1 -0
  172. package/lib/typescript/constants/modelUrls.d.ts +74 -28
  173. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  174. package/lib/typescript/constants/ocr/models.d.ts +285 -0
  175. package/lib/typescript/constants/ocr/models.d.ts.map +1 -0
  176. package/lib/typescript/constants/ocr/symbols.d.ts +73 -1
  177. package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
  178. package/lib/typescript/constants/sttDefaults.d.ts +8 -13
  179. package/lib/typescript/constants/sttDefaults.d.ts.map +1 -1
  180. package/lib/typescript/controllers/LLMController.d.ts +46 -0
  181. package/lib/typescript/controllers/LLMController.d.ts.map +1 -0
  182. package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
  183. package/lib/typescript/controllers/SpeechToTextController.d.ts +30 -16
  184. package/lib/typescript/controllers/SpeechToTextController.d.ts.map +1 -1
  185. package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
  186. package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
  187. package/lib/typescript/hooks/computer_vision/useClassification.d.ts +5 -5
  188. package/lib/typescript/hooks/computer_vision/useClassification.d.ts.map +1 -1
  189. package/lib/typescript/hooks/computer_vision/useImageSegmentation.d.ts +37 -0
  190. package/lib/typescript/hooks/computer_vision/useImageSegmentation.d.ts.map +1 -0
  191. package/lib/typescript/hooks/computer_vision/useOCR.d.ts +2 -1
  192. package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
  193. package/lib/typescript/hooks/computer_vision/useObjectDetection.d.ts +5 -4
  194. package/lib/typescript/hooks/computer_vision/useObjectDetection.d.ts.map +1 -1
  195. package/lib/typescript/hooks/computer_vision/useStyleTransfer.d.ts +4 -2
  196. package/lib/typescript/hooks/computer_vision/useStyleTransfer.d.ts.map +1 -1
  197. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +2 -1
  198. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
  199. package/lib/typescript/hooks/general/useExecutorchModule.d.ts +5 -6
  200. package/lib/typescript/hooks/general/useExecutorchModule.d.ts.map +1 -1
  201. package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts +6 -6
  202. package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
  203. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +7 -3
  204. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts.map +1 -1
  205. package/lib/typescript/hooks/natural_language_processing/useTextEmbeddings.d.ts +13 -0
  206. package/lib/typescript/hooks/natural_language_processing/useTextEmbeddings.d.ts.map +1 -0
  207. package/lib/typescript/hooks/natural_language_processing/useTokenizer.d.ts +16 -0
  208. package/lib/typescript/hooks/natural_language_processing/useTokenizer.d.ts.map +1 -0
  209. package/lib/typescript/hooks/useModule.d.ts +11 -10
  210. package/lib/typescript/hooks/useModule.d.ts.map +1 -1
  211. package/lib/typescript/index.d.ts +15 -2
  212. package/lib/typescript/index.d.ts.map +1 -1
  213. package/lib/typescript/modules/BaseModule.d.ts +4 -5
  214. package/lib/typescript/modules/BaseModule.d.ts.map +1 -1
  215. package/lib/typescript/modules/computer_vision/ClassificationModule.d.ts +7 -7
  216. package/lib/typescript/modules/computer_vision/ClassificationModule.d.ts.map +1 -1
  217. package/lib/typescript/modules/computer_vision/ImageSegmentationModule.d.ts +32 -0
  218. package/lib/typescript/modules/computer_vision/ImageSegmentationModule.d.ts.map +1 -0
  219. package/lib/typescript/modules/computer_vision/ObjectDetectionModule.d.ts +6 -5
  220. package/lib/typescript/modules/computer_vision/ObjectDetectionModule.d.ts.map +1 -1
  221. package/lib/typescript/modules/computer_vision/StyleTransferModule.d.ts +6 -5
  222. package/lib/typescript/modules/computer_vision/StyleTransferModule.d.ts.map +1 -1
  223. package/lib/typescript/modules/general/ExecutorchModule.d.ts +4 -3
  224. package/lib/typescript/modules/general/ExecutorchModule.d.ts.map +1 -1
  225. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts +19 -5
  226. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts.map +1 -1
  227. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +7 -4
  228. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
  229. package/lib/typescript/modules/natural_language_processing/TextEmbeddingsModule.d.ts +8 -0
  230. package/lib/typescript/modules/natural_language_processing/TextEmbeddingsModule.d.ts.map +1 -0
  231. package/lib/typescript/modules/natural_language_processing/TokenizerModule.d.ts +12 -0
  232. package/lib/typescript/modules/natural_language_processing/TokenizerModule.d.ts.map +1 -0
  233. package/lib/typescript/native/NativeClassification.d.ts.map +1 -1
  234. package/lib/typescript/native/NativeImageSegmentation.d.ts +10 -0
  235. package/lib/typescript/native/NativeImageSegmentation.d.ts.map +1 -0
  236. package/lib/typescript/native/NativeLLM.d.ts +3 -4
  237. package/lib/typescript/native/NativeLLM.d.ts.map +1 -1
  238. package/lib/typescript/native/NativeObjectDetection.d.ts +1 -1
  239. package/lib/typescript/native/NativeObjectDetection.d.ts.map +1 -1
  240. package/lib/typescript/native/NativeSpeechToText.d.ts +2 -2
  241. package/lib/typescript/native/NativeSpeechToText.d.ts.map +1 -1
  242. package/lib/typescript/native/NativeTextEmbeddings.d.ts +8 -0
  243. package/lib/typescript/native/NativeTextEmbeddings.d.ts.map +1 -0
  244. package/lib/typescript/native/NativeTokenizer.d.ts +12 -0
  245. package/lib/typescript/native/NativeTokenizer.d.ts.map +1 -0
  246. package/lib/typescript/native/RnExecutorchModules.d.ts +18 -41
  247. package/lib/typescript/native/RnExecutorchModules.d.ts.map +1 -1
  248. package/lib/typescript/types/common.d.ts +1 -26
  249. package/lib/typescript/types/common.d.ts.map +1 -1
  250. package/lib/typescript/types/imageSegmentation.d.ts +25 -0
  251. package/lib/typescript/types/imageSegmentation.d.ts.map +1 -0
  252. package/lib/typescript/types/llm.d.ts +38 -0
  253. package/lib/typescript/types/llm.d.ts.map +1 -0
  254. package/lib/typescript/types/{object_detection.d.ts → objectDetection.d.ts} +1 -1
  255. package/lib/typescript/types/objectDetection.d.ts.map +1 -0
  256. package/lib/typescript/types/ocr.d.ts +2 -1
  257. package/lib/typescript/types/ocr.d.ts.map +1 -1
  258. package/lib/typescript/types/stt.d.ts +91 -0
  259. package/lib/typescript/types/stt.d.ts.map +1 -0
  260. package/lib/typescript/utils/ResourceFetcher.d.ts +17 -0
  261. package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -0
  262. package/lib/typescript/utils/llm.d.ts +3 -0
  263. package/lib/typescript/utils/llm.d.ts.map +1 -0
  264. package/lib/typescript/utils/stt.d.ts +2 -0
  265. package/lib/typescript/utils/stt.d.ts.map +1 -0
  266. package/package.json +13 -49
  267. package/react-native-executorch.podspec +1 -1
  268. package/src/Error.ts +16 -3
  269. package/src/constants/llmDefaults.ts +14 -0
  270. package/src/constants/modelUrls.ts +146 -39
  271. package/src/constants/ocr/models.ts +453 -0
  272. package/src/constants/ocr/symbols.ts +147 -3
  273. package/src/constants/sttDefaults.ts +55 -37
  274. package/src/controllers/LLMController.ts +286 -0
  275. package/src/controllers/OCRController.ts +14 -28
  276. package/src/controllers/SpeechToTextController.ts +318 -180
  277. package/src/controllers/VerticalOCRController.ts +17 -32
  278. package/src/hooks/computer_vision/useClassification.ts +11 -26
  279. package/src/hooks/computer_vision/useImageSegmentation.ts +18 -0
  280. package/src/hooks/computer_vision/useOCR.ts +17 -5
  281. package/src/hooks/computer_vision/useObjectDetection.ts +10 -24
  282. package/src/hooks/computer_vision/useStyleTransfer.ts +9 -25
  283. package/src/hooks/computer_vision/useVerticalOCR.ts +11 -4
  284. package/src/hooks/general/useExecutorchModule.ts +10 -50
  285. package/src/hooks/natural_language_processing/useLLM.ts +80 -97
  286. package/src/hooks/natural_language_processing/useSpeechToText.ts +39 -12
  287. package/src/hooks/natural_language_processing/useTextEmbeddings.ts +18 -0
  288. package/src/hooks/natural_language_processing/useTokenizer.ts +61 -0
  289. package/src/hooks/useModule.ts +32 -92
  290. package/src/index.tsx +16 -2
  291. package/src/modules/BaseModule.ts +16 -26
  292. package/src/modules/computer_vision/ClassificationModule.ts +13 -8
  293. package/src/modules/computer_vision/ImageSegmentationModule.ts +39 -0
  294. package/src/modules/computer_vision/ObjectDetectionModule.ts +13 -8
  295. package/src/modules/computer_vision/StyleTransferModule.ts +13 -8
  296. package/src/modules/general/ExecutorchModule.ts +11 -6
  297. package/src/modules/natural_language_processing/LLMModule.ts +64 -51
  298. package/src/modules/natural_language_processing/SpeechToTextModule.ts +25 -10
  299. package/src/modules/natural_language_processing/TextEmbeddingsModule.ts +18 -0
  300. package/src/modules/natural_language_processing/TokenizerModule.ts +34 -0
  301. package/src/native/NativeClassification.ts +0 -1
  302. package/src/native/NativeImageSegmentation.ts +14 -0
  303. package/src/native/NativeLLM.ts +3 -10
  304. package/src/native/NativeObjectDetection.ts +1 -1
  305. package/src/native/NativeSpeechToText.ts +2 -2
  306. package/src/native/NativeTextEmbeddings.ts +9 -0
  307. package/src/native/NativeTokenizer.ts +13 -0
  308. package/src/native/RnExecutorchModules.ts +54 -234
  309. package/src/types/common.ts +1 -44
  310. package/src/types/imageSegmentation.ts +25 -0
  311. package/src/types/llm.ts +57 -0
  312. package/src/types/ocr.ts +3 -1
  313. package/src/types/stt.ts +93 -0
  314. package/src/utils/ResourceFetcher.ts +196 -0
  315. package/src/utils/llm.ts +34 -0
  316. package/src/utils/stt.ts +28 -0
  317. package/android/src/main/java/com/swmansion/rnexecutorch/utils/llms/ConversationManager.kt +0 -68
  318. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/_CodeSignature/CodeResources +0 -124
  319. package/ios/RnExecutorch/utils/llms/Constants.h +0 -6
  320. package/ios/RnExecutorch/utils/llms/Constants.mm +0 -23
  321. package/ios/RnExecutorch/utils/llms/ConversationManager.h +0 -26
  322. package/ios/RnExecutorch/utils/llms/ConversationManager.mm +0 -71
  323. package/lib/module/constants/llamaDefaults.js.map +0 -1
  324. package/lib/module/modules/computer_vision/BaseCVModule.js +0 -14
  325. package/lib/module/modules/computer_vision/BaseCVModule.js.map +0 -1
  326. package/lib/module/types/object_detection.js.map +0 -1
  327. package/lib/module/utils/fetchResource.js +0 -93
  328. package/lib/module/utils/fetchResource.js.map +0 -1
  329. package/lib/module/utils/listDownloadedResources.js +0 -13
  330. package/lib/module/utils/listDownloadedResources.js.map +0 -1
  331. package/lib/typescript/constants/llamaDefaults.d.ts.map +0 -1
  332. package/lib/typescript/modules/computer_vision/BaseCVModule.d.ts +0 -9
  333. package/lib/typescript/modules/computer_vision/BaseCVModule.d.ts.map +0 -1
  334. package/lib/typescript/types/object_detection.d.ts.map +0 -1
  335. package/lib/typescript/utils/fetchResource.d.ts +0 -3
  336. package/lib/typescript/utils/fetchResource.d.ts.map +0 -1
  337. package/lib/typescript/utils/listDownloadedResources.d.ts +0 -3
  338. package/lib/typescript/utils/listDownloadedResources.d.ts.map +0 -1
  339. package/src/constants/llamaDefaults.ts +0 -9
  340. package/src/modules/computer_vision/BaseCVModule.ts +0 -22
  341. package/src/utils/fetchResource.ts +0 -106
  342. package/src/utils/listDownloadedResources.ts +0 -12
  343. /package/src/types/{object_detection.ts → objectDetection.ts} +0 -0
@@ -1,68 +1,49 @@
1
- import { _SpeechToTextModule } from '../native/RnExecutorchModules';
2
- import * as FileSystem from 'expo-file-system';
3
- import { fetchResource } from '../utils/fetchResource';
4
- import { ResourceSource } from '../types/common';
5
1
  import {
6
2
  HAMMING_DIST_THRESHOLD,
7
- SECOND,
8
3
  MODEL_CONFIGS,
9
- ModelConfig,
4
+ SECOND,
10
5
  MODES,
6
+ NUM_TOKENS_TO_TRIM,
7
+ STREAMING_ACTION,
11
8
  } from '../constants/sttDefaults';
12
-
13
- const longCommonInfPref = (seq1: number[], seq2: number[]) => {
14
- let maxInd = 0;
15
- let maxLength = 0;
16
-
17
- for (let i = 0; i < seq1.length; i++) {
18
- let j = 0;
19
- let hammingDist = 0;
20
- while (
21
- j < seq2.length &&
22
- i + j < seq1.length &&
23
- (seq1[i + j] === seq2[j] || hammingDist < HAMMING_DIST_THRESHOLD)
24
- ) {
25
- if (seq1[i + j] !== seq2[j]) {
26
- hammingDist++;
27
- }
28
- j++;
29
- }
30
- if (j >= maxLength) {
31
- maxLength = j;
32
- maxInd = i;
33
- }
34
- }
35
- return maxInd;
36
- };
9
+ import { AvailableModels, ModelConfig } from '../types/stt';
10
+ import { SpeechToTextNativeModule } from '../native/RnExecutorchModules';
11
+ import { TokenizerModule } from '../modules/natural_language_processing/TokenizerModule';
12
+ import { ResourceSource } from '../types/common';
13
+ import { ResourceFetcher } from '../utils/ResourceFetcher';
14
+ import { longCommonInfPref } from '../utils/stt';
15
+ import { SpeechToTextLanguage } from '../types/stt';
16
+ import { ETError, getError } from '../Error';
37
17
 
38
18
  export class SpeechToTextController {
39
- private nativeModule: _SpeechToTextModule;
19
+ private speechToTextNativeModule = SpeechToTextNativeModule;
40
20
 
41
- private overlapSeconds!: number;
42
- private windowSize!: number;
43
-
44
- private chunks: number[][] = [];
45
21
  public sequence: number[] = [];
46
22
  public isReady = false;
47
23
  public isGenerating = false;
48
- private modelName!: 'moonshine' | 'whisper';
49
24
 
50
- // tokenizer tokens to string mapping used for decoding sequence
51
- private tokenMapping!: { [key: number]: string };
25
+ private overlapSeconds!: number;
26
+ private windowSize!: number;
27
+ private chunks: number[][] = [];
28
+ private seqs: number[][] = [];
29
+ private prevSeq: number[] = [];
30
+ private waveform: number[] = [];
31
+ private numOfChunks = 0;
32
+ private streaming = false;
52
33
 
53
34
  // User callbacks
54
35
  private decodedTranscribeCallback: (sequence: number[]) => void;
55
- private modelDownloadProgessCallback:
36
+ private modelDownloadProgressCallback:
56
37
  | ((downloadProgress: number) => void)
57
38
  | undefined;
58
39
  private isReadyCallback: (isReady: boolean) => void;
59
40
  private isGeneratingCallback: (isGenerating: boolean) => void;
60
- private onErrorCallback: ((error: any) => void) | undefined;
41
+ private onErrorCallback: (error: any) => void;
61
42
  private config!: ModelConfig;
62
43
 
63
44
  constructor({
64
45
  transcribeCallback,
65
- modelDownloadProgessCallback,
46
+ modelDownloadProgressCallback,
66
47
  isReadyCallback,
67
48
  isGeneratingCallback,
68
49
  onErrorCallback,
@@ -71,7 +52,7 @@ export class SpeechToTextController {
71
52
  streamingConfig,
72
53
  }: {
73
54
  transcribeCallback: (sequence: string) => void;
74
- modelDownloadProgessCallback?: (downloadProgress: number) => void;
55
+ modelDownloadProgressCallback?: (downloadProgress: number) => void;
75
56
  isReadyCallback?: (isReady: boolean) => void;
76
57
  isGeneratingCallback?: (isGenerating: boolean) => void;
77
58
  onErrorCallback?: (error: Error | undefined) => void;
@@ -79,9 +60,9 @@ export class SpeechToTextController {
79
60
  windowSize?: number;
80
61
  streamingConfig?: keyof typeof MODES;
81
62
  }) {
82
- this.decodedTranscribeCallback = (seq) =>
83
- transcribeCallback(this.decodeSeq(seq));
84
- this.modelDownloadProgessCallback = modelDownloadProgessCallback;
63
+ this.decodedTranscribeCallback = async (seq) =>
64
+ transcribeCallback(await this.tokenIdsToText(seq));
65
+ this.modelDownloadProgressCallback = modelDownloadProgressCallback;
85
66
  this.isReadyCallback = (isReady) => {
86
67
  this.isReady = isReady;
87
68
  isReadyCallback?.(isReady);
@@ -90,8 +71,14 @@ export class SpeechToTextController {
90
71
  this.isGenerating = isGenerating;
91
72
  isGeneratingCallback?.(isGenerating);
92
73
  };
93
- this.onErrorCallback = onErrorCallback;
94
- this.nativeModule = new _SpeechToTextModule();
74
+ this.onErrorCallback = (error) => {
75
+ if (onErrorCallback) {
76
+ onErrorCallback(error ? new Error(getError(error)) : undefined);
77
+ return;
78
+ } else {
79
+ throw new Error(getError(error));
80
+ }
81
+ };
95
82
  this.configureStreaming(
96
83
  overlapSeconds,
97
84
  windowSize,
@@ -99,54 +86,49 @@ export class SpeechToTextController {
99
86
  );
100
87
  }
101
88
 
102
- private async fetchTokenizer(
103
- localUri?: ResourceSource
104
- ): Promise<{ [key: number]: string }> {
105
- let tokenzerUri = await fetchResource(
106
- localUri || this.config.tokenizer.source
107
- );
108
- return JSON.parse(await FileSystem.readAsStringAsync(tokenzerUri));
109
- }
110
-
111
89
  public async loadModel(
112
- modelName: 'moonshine' | 'whisper',
90
+ modelName: AvailableModels,
113
91
  encoderSource?: ResourceSource,
114
92
  decoderSource?: ResourceSource,
115
93
  tokenizerSource?: ResourceSource
116
94
  ) {
117
- this.onErrorCallback?.(undefined);
95
+ this.onErrorCallback(undefined);
118
96
  this.isReadyCallback(false);
119
97
  this.config = MODEL_CONFIGS[modelName];
120
- this.modelName = modelName;
121
98
 
122
99
  try {
123
- this.tokenMapping = await this.fetchTokenizer(tokenizerSource);
124
- encoderSource = await fetchResource(
125
- encoderSource || this.config.sources.encoder,
126
- (progress) => this.modelDownloadProgessCallback?.(progress / 2)
100
+ await TokenizerModule.load(
101
+ tokenizerSource || this.config.tokenizer.source
127
102
  );
128
103
 
129
- decoderSource = await fetchResource(
130
- decoderSource || this.config.sources.decoder,
131
- (progress) => this.modelDownloadProgessCallback?.(0.5 + progress / 2)
132
- );
104
+ [encoderSource, decoderSource] =
105
+ await ResourceFetcher.fetchMultipleResources(
106
+ this.modelDownloadProgressCallback,
107
+ encoderSource || this.config.sources.encoder,
108
+ decoderSource || this.config.sources.decoder
109
+ );
133
110
  } catch (e) {
134
- this.onErrorCallback?.(e);
111
+ this.onErrorCallback(e);
135
112
  return;
136
113
  }
137
114
 
115
+ if (modelName === 'whisperMultilingual') {
116
+ // The underlying native class is instantiated based on the name of the model. There is no need to
117
+ // create a separate class for multilingual version of Whisper, since it is the same. We just need
118
+ // the distinction here, in TS, for start tokens and such. If we introduce
119
+ // more versions of Whisper, such as the small one, this should be refactored.
120
+ modelName = 'whisper';
121
+ }
122
+
138
123
  try {
139
- await this.nativeModule.loadModule(modelName, [
124
+ await this.speechToTextNativeModule.loadModule(modelName, [
140
125
  encoderSource!,
141
126
  decoderSource!,
142
127
  ]);
143
- this.modelDownloadProgessCallback?.(1);
128
+ this.modelDownloadProgressCallback?.(1);
144
129
  this.isReadyCallback(true);
145
130
  } catch (e) {
146
- this.onErrorCallback?.(
147
- new Error(`Error when loading the SpeechToTextController! ${e}`)
148
- );
149
- console.error('Error when loading the SpeechToTextController!', e);
131
+ this.onErrorCallback(e);
150
132
  }
151
133
  }
152
134
 
@@ -174,142 +156,298 @@ export class SpeechToTextController {
174
156
  }
175
157
  }
176
158
 
177
- private chunkWaveform(waveform: number[]) {
178
- this.chunks = [];
179
- const numOfChunks = Math.ceil(waveform.length / this.windowSize);
180
- for (let i = 0; i < numOfChunks; i++) {
181
- let chunk = waveform.slice(
182
- Math.max(this.windowSize * i - this.overlapSeconds, 0),
183
- Math.min(
184
- this.windowSize * (i + 1) + this.overlapSeconds,
185
- waveform.length
186
- )
159
+ private chunkWaveform() {
160
+ this.numOfChunks = Math.ceil(this.waveform.length / this.windowSize);
161
+ for (let i = 0; i < this.numOfChunks; i++) {
162
+ let chunk: number[] = [];
163
+ const left = Math.max(this.windowSize * i - this.overlapSeconds, 0);
164
+ const right = Math.min(
165
+ this.windowSize * (i + 1) + this.overlapSeconds,
166
+ this.waveform.length
187
167
  );
188
-
168
+ chunk = this.waveform.slice(left, right);
189
169
  this.chunks.push(chunk);
190
170
  }
191
171
  }
192
172
 
193
- public async transcribe(waveform: number[]): Promise<string> {
194
- if (!this.isReady) {
195
- this.onErrorCallback?.(new Error('Model is not yet ready'));
196
- return '';
197
- }
198
- if (this.isGenerating) {
199
- this.onErrorCallback?.(new Error('Model is already transcribing'));
200
- return '';
201
- }
202
- this.onErrorCallback?.(undefined);
203
- this.isGeneratingCallback(true);
204
-
173
+ private resetState() {
205
174
  this.sequence = [];
175
+ this.seqs = [];
176
+ this.waveform = [];
177
+ this.prevSeq = [];
178
+ this.chunks = [];
179
+ this.decodedTranscribeCallback([]);
180
+ this.onErrorCallback(undefined);
181
+ }
206
182
 
207
- if (!waveform) {
208
- this.isGeneratingCallback(false);
183
+ private expectedChunkLength() {
184
+ //only first chunk can be of shorter length, for first chunk there are no seqs decoded
185
+ return this.seqs.length
186
+ ? this.windowSize + 2 * this.overlapSeconds
187
+ : this.windowSize + this.overlapSeconds;
188
+ }
209
189
 
210
- this.onErrorCallback?.(
211
- new Error(
212
- `Nothing to transcribe, perhaps you forgot to call this.loadAudio().`
213
- )
214
- );
190
+ private async getStartingTokenIds(audioLanguage?: string): Promise<number[]> {
191
+ // We need different starting token ids based on the multilingualism of the model.
192
+ // The eng version only needs BOS token, while the multilingual one needs:
193
+ // [BOS, LANG, TRANSCRIBE]. Optionally we should also set notimestamps token, as timestamps
194
+ // is not yet supported.
195
+ if (!audioLanguage) {
196
+ return [this.config.tokenizer.bos];
215
197
  }
198
+ // FIXME: I should use .getTokenId for the BOS as well, should remove it from config
199
+ const langTokenId = await TokenizerModule.tokenToId(`<|${audioLanguage}|>`);
200
+ const transcribeTokenId = await TokenizerModule.tokenToId('<|transcribe|>');
201
+ const noTimestampsTokenId =
202
+ await TokenizerModule.tokenToId('<|notimestamps|>');
203
+ const startingTokenIds = [
204
+ this.config.tokenizer.bos,
205
+ langTokenId,
206
+ transcribeTokenId,
207
+ noTimestampsTokenId,
208
+ ];
209
+ return startingTokenIds;
210
+ }
216
211
 
217
- this.chunkWaveform(waveform);
218
-
219
- let seqs: number[][] = [];
220
- let prevseq: number[] = [];
221
- for (let chunkId = 0; chunkId < this.chunks.length; chunkId++) {
222
- let lastToken = this.config.tokenizer.sos;
223
- let prevSeqTokenIdx = 0;
224
- let finalSeq: number[] = [];
225
- let seq = [lastToken];
212
+ private async decodeChunk(
213
+ chunk: number[],
214
+ audioLanguage?: SpeechToTextLanguage
215
+ ): Promise<number[]> {
216
+ const seq = await this.getStartingTokenIds(audioLanguage);
217
+ let prevSeqTokenIdx = 0;
218
+ this.prevSeq = this.sequence.slice();
219
+ try {
220
+ await this.encode(chunk);
221
+ } catch (error) {
222
+ this.onErrorCallback(new Error(getError(error) + ' encoding error'));
223
+ return [];
224
+ }
225
+ let lastToken = seq.at(-1) as number;
226
+ while (lastToken !== this.config.tokenizer.eos) {
226
227
  try {
227
- await this.nativeModule.encode(this.chunks!.at(chunkId)!);
228
+ lastToken = await this.decode(seq);
228
229
  } catch (error) {
229
- this.onErrorCallback?.(`Encode ${error}`);
230
- return '';
230
+ this.onErrorCallback(new Error(getError(error) + ' decoding error'));
231
+ return [...seq, this.config.tokenizer.eos];
231
232
  }
232
- while (lastToken !== this.config.tokenizer.eos) {
233
- try {
234
- lastToken = await this.nativeModule.decode(seq);
235
- } catch (error) {
236
- this.onErrorCallback?.(`Decode ${error}`);
237
- return '';
238
- }
239
- seq = [...seq, lastToken];
240
- if (
241
- seqs.length > 0 &&
242
- seq.length < seqs.at(-1)!.length &&
243
- seq.length % 3 !== 0
244
- ) {
245
- prevseq = [...prevseq, seqs.at(-1)![prevSeqTokenIdx++]!];
246
- this.decodedTranscribeCallback(prevseq);
247
- }
233
+ seq.push(lastToken);
234
+ if (
235
+ this.seqs.length > 0 &&
236
+ seq.length < this.seqs.at(-1)!.length &&
237
+ seq.length % 3 !== 0
238
+ ) {
239
+ this.prevSeq.push(this.seqs.at(-1)![prevSeqTokenIdx++]!);
240
+ this.decodedTranscribeCallback(this.prevSeq);
248
241
  }
242
+ }
243
+ return seq;
244
+ }
245
+
246
+ private async handleOverlaps(seqs: number[][]): Promise<number[]> {
247
+ const maxInd = longCommonInfPref(
248
+ seqs.at(-2)!,
249
+ seqs.at(-1)!,
250
+ HAMMING_DIST_THRESHOLD
251
+ );
252
+ this.sequence = [...this.sequence, ...seqs.at(-2)!.slice(0, maxInd)];
253
+ this.decodedTranscribeCallback(this.sequence);
254
+ return this.sequence.slice();
255
+ }
256
+
257
+ private trimLeft(numOfTokensToTrim: number) {
258
+ const idx = this.seqs.length - 1;
259
+ if (this.seqs[idx]![0] === this.config.tokenizer.bos) {
260
+ this.seqs[idx] = this.seqs[idx]!.slice(numOfTokensToTrim);
261
+ }
262
+ }
263
+
264
+ private trimRight(numOfTokensToTrim: number) {
265
+ const idx = this.seqs.length - 2;
266
+ if (this.seqs[idx]!.at(-1) === this.config.tokenizer.eos) {
267
+ this.seqs[idx] = this.seqs[idx]!.slice(0, -numOfTokensToTrim);
268
+ }
269
+ }
270
+
271
+ // since we are calling this every time (except first) after a new seq is pushed to this.seqs
272
+ // we can only trim left the last seq and trim right the second to last seq
273
+ private async trimSequences(audioLanguage?: string) {
274
+ const numSpecialTokens = (await this.getStartingTokenIds(audioLanguage))
275
+ .length;
276
+ this.trimLeft(numSpecialTokens + NUM_TOKENS_TO_TRIM);
277
+ this.trimRight(numSpecialTokens + NUM_TOKENS_TO_TRIM);
278
+ }
279
+
280
+ // if last chunk is too short combine it with second to last to improve quality
281
+ private validateAndFixLastChunk() {
282
+ const lastChunkLength = this.chunks.at(-1)!.length / SECOND;
283
+ const secondToLastChunkLength = this.chunks.at(-2)!.length / SECOND;
284
+ if (lastChunkLength < 5 && secondToLastChunkLength + lastChunkLength < 30) {
285
+ this.chunks[this.chunks.length - 2] = [
286
+ ...this.chunks.at(-2)!.slice(0, -this.overlapSeconds * 2),
287
+ ...this.chunks.at(-1)!,
288
+ ];
289
+ this.chunks = this.chunks.slice(0, -1);
290
+ }
291
+ }
292
+
293
+ private async tokenIdsToText(tokenIds: number[]): Promise<string> {
294
+ try {
295
+ return TokenizerModule.decode(tokenIds, true);
296
+ } catch (e) {
297
+ this.onErrorCallback(
298
+ new Error(`An error has occurred when decoding the token ids: ${e}`)
299
+ );
300
+ return '';
301
+ }
302
+ }
249
303
 
304
+ public async transcribe(
305
+ waveform: number[],
306
+ audioLanguage?: SpeechToTextLanguage
307
+ ): Promise<string> {
308
+ try {
309
+ if (!this.isReady) throw Error(getError(ETError.ModuleNotLoaded));
310
+ if (this.isGenerating || this.streaming)
311
+ throw Error(getError(ETError.ModelGenerating));
312
+ if (!!audioLanguage !== this.config.isMultilingual)
313
+ throw new Error(getError(ETError.MultilingualConfiguration));
314
+ } catch (e) {
315
+ this.onErrorCallback(e);
316
+ return '';
317
+ }
318
+
319
+ // Making sure that the error is not set when we get there
320
+ this.isGeneratingCallback(true);
321
+
322
+ this.resetState();
323
+ this.waveform = waveform;
324
+ this.chunkWaveform();
325
+ this.validateAndFixLastChunk();
326
+
327
+ for (let chunkId = 0; chunkId < this.chunks.length; chunkId++) {
328
+ const seq = await this.decodeChunk(this.chunks!.at(chunkId)!);
329
+ // whole audio is inside one chunk, no processing required
250
330
  if (this.chunks.length === 1) {
251
- finalSeq = seq;
252
- this.sequence = finalSeq;
253
- this.decodedTranscribeCallback(finalSeq);
331
+ this.sequence = seq;
332
+ this.decodedTranscribeCallback(seq);
254
333
  break;
255
334
  }
256
- // remove sos/eos token and 3 additional ones
257
- if (seqs.length === 0) {
258
- seqs = [seq.slice(0, -4)];
259
- } else if (seqs.length === this.chunks.length - 1) {
260
- seqs = [...seqs, seq.slice(4)];
261
- } else {
262
- seqs = [...seqs, seq.slice(4, -4)];
263
- }
264
- if (seqs.length < 2) {
265
- continue;
266
- }
335
+ this.seqs.push(seq);
336
+
337
+ if (this.seqs.length < 2) continue;
267
338
 
268
- const maxInd = longCommonInfPref(seqs.at(-2)!, seqs.at(-1)!);
269
- finalSeq = [...this.sequence, ...seqs.at(-2)!.slice(0, maxInd)];
270
- this.sequence = finalSeq;
271
- this.decodedTranscribeCallback(finalSeq);
272
- prevseq = finalSeq;
273
-
274
- //last sequence processed
275
- if (seqs.length === this.chunks.length) {
276
- finalSeq = [...this.sequence, ...seqs.at(-1)!];
277
- this.sequence = finalSeq;
278
- this.decodedTranscribeCallback(finalSeq);
279
- prevseq = finalSeq;
339
+ // Remove starting tokenIds and some additional ones
340
+ await this.trimSequences(audioLanguage);
341
+
342
+ this.prevSeq = await this.handleOverlaps(this.seqs);
343
+
344
+ // last sequence processed
345
+ // overlaps are already handled, so just append the last seq
346
+ if (this.seqs.length === this.chunks.length) {
347
+ this.sequence = [...this.sequence, ...this.seqs.at(-1)!];
348
+ this.decodedTranscribeCallback(this.sequence);
349
+ this.prevSeq = this.sequence;
280
350
  }
281
351
  }
282
- const decodedSeq = this.decodeSeq(this.sequence);
352
+ const decodedText = await this.tokenIdsToText(this.sequence);
283
353
  this.isGeneratingCallback(false);
284
- return decodedSeq;
354
+ return decodedText;
285
355
  }
286
356
 
287
- public decodeSeq(seq?: number[]): string {
288
- if (!this.modelName) {
289
- this.onErrorCallback?.(
290
- new Error('Model is not loaded, call `loadModel` first')
291
- );
357
+ public async streamingTranscribe(
358
+ streamAction: STREAMING_ACTION,
359
+ waveform?: number[],
360
+ audioLanguage?: SpeechToTextLanguage
361
+ ): Promise<string> {
362
+ try {
363
+ if (!this.isReady) throw Error(getError(ETError.ModuleNotLoaded));
364
+ if (!!audioLanguage !== this.config.isMultilingual)
365
+ throw new Error(getError(ETError.MultilingualConfiguration));
366
+
367
+ if (
368
+ streamAction === STREAMING_ACTION.START &&
369
+ !this.streaming &&
370
+ this.isGenerating
371
+ )
372
+ throw Error(getError(ETError.ModelGenerating));
373
+ if (streamAction === STREAMING_ACTION.START && this.streaming)
374
+ throw Error(getError(ETError.ModelGenerating));
375
+ if (streamAction === STREAMING_ACTION.DATA && !this.streaming)
376
+ throw Error(getError(ETError.StreamingNotStarted));
377
+ if (streamAction === STREAMING_ACTION.STOP && !this.streaming)
378
+ throw Error(getError(ETError.StreamingNotStarted));
379
+ if (streamAction === STREAMING_ACTION.DATA && !waveform)
380
+ throw new Error(getError(ETError.MissingDataChunk));
381
+ } catch (e) {
382
+ this.onErrorCallback(e);
292
383
  return '';
293
384
  }
294
- this.onErrorCallback?.(undefined);
295
- if (!seq) seq = this.sequence;
296
-
297
- return seq
298
- .filter(
299
- (token) =>
300
- token !== this.config.tokenizer.eos &&
301
- token !== this.config.tokenizer.sos
302
- )
303
- .map((token) => this.tokenMapping[token])
304
- .join('')
305
- .replaceAll(this.config.tokenizer.specialChar, ' ');
385
+
386
+ if (streamAction === STREAMING_ACTION.START) {
387
+ this.resetState();
388
+ this.streaming = true;
389
+ this.isGeneratingCallback(true);
390
+ }
391
+
392
+ this.waveform = [...this.waveform, ...(waveform || [])];
393
+
394
+ // while buffer has at least required size get chunk and decode
395
+ while (this.waveform.length >= this.expectedChunkLength()) {
396
+ const chunk = this.waveform.slice(
397
+ 0,
398
+ this.windowSize +
399
+ this.overlapSeconds * (1 + Number(this.seqs.length > 0))
400
+ );
401
+ this.chunks = [chunk]; //save last chunk for STREAMING_ACTION.STOP
402
+ this.waveform = this.waveform.slice(
403
+ this.windowSize - this.overlapSeconds * Number(this.seqs.length === 0)
404
+ );
405
+ const seq = await this.decodeChunk(chunk, audioLanguage);
406
+ this.seqs.push(seq);
407
+
408
+ if (this.seqs.length < 2) continue;
409
+
410
+ await this.trimSequences(audioLanguage);
411
+ await this.handleOverlaps(this.seqs);
412
+ }
413
+
414
+ // got final package, process all remaining waveform data
415
+ // since we run the loop above the waveform has at most one chunk in it
416
+ if (streamAction === STREAMING_ACTION.STOP) {
417
+ // pad remaining waveform data with previous chunk to this.windowSize + 2 * this.overlapSeconds
418
+ const chunk = this.chunks.length
419
+ ? [
420
+ ...this.chunks[0]!.slice(0, this.windowSize),
421
+ ...this.waveform,
422
+ ].slice(-this.windowSize - 2 * this.overlapSeconds)
423
+ : this.waveform;
424
+
425
+ this.waveform = [];
426
+ const seq = await this.decodeChunk(chunk, audioLanguage);
427
+ this.seqs.push(seq);
428
+
429
+ if (this.seqs.length === 1) {
430
+ this.sequence = this.seqs[0]!;
431
+ } else {
432
+ await this.trimSequences(audioLanguage);
433
+ await this.handleOverlaps(this.seqs);
434
+ this.sequence = [...this.sequence, ...this.seqs.at(-1)!];
435
+ }
436
+ this.decodedTranscribeCallback(this.sequence);
437
+ this.isGeneratingCallback(false);
438
+ this.streaming = false;
439
+ }
440
+
441
+ const decodedText = await this.tokenIdsToText(this.sequence);
442
+
443
+ return decodedText;
306
444
  }
307
445
 
308
446
  public async encode(waveform: number[]) {
309
- return await this.nativeModule.encode(waveform);
447
+ return await this.speechToTextNativeModule.encode(waveform);
310
448
  }
311
449
 
312
450
  public async decode(seq: number[], encodings?: number[]) {
313
- return await this.nativeModule.decode(seq, encodings);
451
+ return await this.speechToTextNativeModule.decode(seq, encodings || []);
314
452
  }
315
453
  }