react-native-executorch 0.3.2 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (346) hide show
  1. package/README.md +30 -13
  2. package/android/build.gradle +1 -1
  3. package/android/libs/executorch.aar +0 -0
  4. package/android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt +1 -2
  5. package/android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt +58 -0
  6. package/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt +13 -49
  7. package/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +37 -0
  8. package/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt +1 -1
  9. package/android/src/main/java/com/swmansion/rnexecutorch/TextEmbeddings.kt +51 -0
  10. package/android/src/main/java/com/swmansion/rnexecutorch/Tokenizer.kt +86 -0
  11. package/android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt +3 -4
  12. package/android/src/main/java/com/swmansion/rnexecutorch/models/TextEmbeddings/TextEmbeddingsModel.kt +48 -0
  13. package/android/src/main/java/com/swmansion/rnexecutorch/models/TextEmbeddings/TextEmbeddingsUtils.kt +37 -0
  14. package/android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt +1 -0
  15. package/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt +26 -0
  16. package/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt +142 -0
  17. package/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt +3 -0
  18. package/android/src/main/java/com/swmansion/rnexecutorch/models/{StyleTransferModel.kt → styleTransfer/StyleTransferModel.kt} +2 -1
  19. package/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt +0 -8
  20. package/android/src/main/java/com/swmansion/rnexecutorch/{models/classification/Utils.kt → utils/Numerical.kt} +1 -1
  21. package/ios/ExecutorchLib.xcframework/Info.plist +4 -4
  22. package/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  23. package/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Headers/HuggingFaceTokenizer.h +14 -0
  24. package/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Headers/LLaMARunner.h +1 -23
  25. package/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  26. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  27. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Headers/HuggingFaceTokenizer.h +14 -0
  28. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Headers/LLaMARunner.h +1 -23
  29. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  30. package/ios/RnExecutorch/Classification.mm +15 -18
  31. package/ios/RnExecutorch/ETModule.mm +6 -5
  32. package/ios/RnExecutorch/ImageSegmentation.h +5 -0
  33. package/ios/RnExecutorch/ImageSegmentation.mm +60 -0
  34. package/ios/RnExecutorch/LLM.mm +12 -53
  35. package/ios/RnExecutorch/OCR.mm +39 -43
  36. package/ios/RnExecutorch/ObjectDetection.mm +20 -20
  37. package/ios/RnExecutorch/SpeechToText.mm +6 -7
  38. package/ios/RnExecutorch/StyleTransfer.mm +16 -19
  39. package/ios/RnExecutorch/TextEmbeddings.h +5 -0
  40. package/ios/RnExecutorch/TextEmbeddings.mm +62 -0
  41. package/ios/RnExecutorch/Tokenizer.h +5 -0
  42. package/ios/RnExecutorch/Tokenizer.mm +83 -0
  43. package/ios/RnExecutorch/VerticalOCR.mm +36 -36
  44. package/ios/RnExecutorch/models/BaseModel.h +2 -5
  45. package/ios/RnExecutorch/models/BaseModel.mm +5 -15
  46. package/ios/RnExecutorch/models/classification/ClassificationModel.mm +2 -3
  47. package/ios/RnExecutorch/models/classification/Constants.mm +0 -1
  48. package/ios/RnExecutorch/models/image_segmentation/Constants.h +4 -0
  49. package/ios/RnExecutorch/models/image_segmentation/Constants.mm +8 -0
  50. package/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.h +10 -0
  51. package/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm +146 -0
  52. package/ios/RnExecutorch/models/object_detection/SSDLiteLargeModel.mm +1 -2
  53. package/ios/RnExecutorch/models/ocr/Detector.h +0 -2
  54. package/ios/RnExecutorch/models/ocr/Detector.mm +2 -1
  55. package/ios/RnExecutorch/models/ocr/RecognitionHandler.h +5 -4
  56. package/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +9 -26
  57. package/ios/RnExecutorch/models/ocr/Recognizer.mm +1 -2
  58. package/ios/RnExecutorch/models/ocr/VerticalDetector.h +0 -2
  59. package/ios/RnExecutorch/models/ocr/VerticalDetector.mm +2 -1
  60. package/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm +0 -1
  61. package/ios/RnExecutorch/models/stt/Moonshine.mm +1 -6
  62. package/ios/RnExecutorch/models/stt/SpeechToTextBaseModel.mm +7 -11
  63. package/ios/RnExecutorch/models/stt/Whisper.mm +0 -5
  64. package/ios/RnExecutorch/models/{StyleTransferModel.h → style_transfer/StyleTransferModel.h} +1 -1
  65. package/ios/RnExecutorch/models/{StyleTransferModel.mm → style_transfer/StyleTransferModel.mm} +2 -3
  66. package/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsModel.h +15 -0
  67. package/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsModel.mm +45 -0
  68. package/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsUtils.h +8 -0
  69. package/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsUtils.mm +49 -0
  70. package/ios/RnExecutorch/utils/Conversions.h +15 -0
  71. package/ios/RnExecutorch/utils/ImageProcessor.h +0 -1
  72. package/ios/RnExecutorch/{models/classification/Utils.h → utils/Numerical.h} +0 -2
  73. package/ios/RnExecutorch/{models/classification/Utils.mm → utils/Numerical.mm} +0 -2
  74. package/ios/RnExecutorch/utils/ObjectDetectionUtils.mm +0 -2
  75. package/ios/RnExecutorch/utils/SFFT.mm +1 -1
  76. package/ios/RnExecutorch/utils/ScalarType.h +0 -2
  77. package/lib/module/Error.js +16 -2
  78. package/lib/module/Error.js.map +1 -1
  79. package/lib/module/constants/{llamaDefaults.js → llmDefaults.js} +7 -3
  80. package/lib/module/constants/llmDefaults.js.map +1 -0
  81. package/lib/module/constants/modelUrls.js +88 -27
  82. package/lib/module/constants/modelUrls.js.map +1 -1
  83. package/lib/module/constants/ocr/models.js +290 -0
  84. package/lib/module/constants/ocr/models.js.map +1 -0
  85. package/lib/module/constants/ocr/symbols.js +137 -2
  86. package/lib/module/constants/ocr/symbols.js.map +1 -1
  87. package/lib/module/constants/sttDefaults.js +50 -25
  88. package/lib/module/constants/sttDefaults.js.map +1 -1
  89. package/lib/module/controllers/LLMController.js +205 -0
  90. package/lib/module/controllers/LLMController.js.map +1 -0
  91. package/lib/module/controllers/OCRController.js +5 -10
  92. package/lib/module/controllers/OCRController.js.map +1 -1
  93. package/lib/module/controllers/SpeechToTextController.js +225 -122
  94. package/lib/module/controllers/SpeechToTextController.js.map +1 -1
  95. package/lib/module/controllers/VerticalOCRController.js +6 -10
  96. package/lib/module/controllers/VerticalOCRController.js.map +1 -1
  97. package/lib/module/hooks/computer_vision/useClassification.js +8 -23
  98. package/lib/module/hooks/computer_vision/useClassification.js.map +1 -1
  99. package/lib/module/hooks/computer_vision/useImageSegmentation.js +13 -0
  100. package/lib/module/hooks/computer_vision/useImageSegmentation.js.map +1 -0
  101. package/lib/module/hooks/computer_vision/useOCR.js +11 -6
  102. package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
  103. package/lib/module/hooks/computer_vision/useObjectDetection.js +8 -23
  104. package/lib/module/hooks/computer_vision/useObjectDetection.js.map +1 -1
  105. package/lib/module/hooks/computer_vision/useStyleTransfer.js +8 -23
  106. package/lib/module/hooks/computer_vision/useStyleTransfer.js.map +1 -1
  107. package/lib/module/hooks/computer_vision/useVerticalOCR.js +10 -7
  108. package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
  109. package/lib/module/hooks/general/useExecutorchModule.js +8 -36
  110. package/lib/module/hooks/general/useExecutorchModule.js.map +1 -1
  111. package/lib/module/hooks/natural_language_processing/useLLM.js +54 -63
  112. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  113. package/lib/module/hooks/natural_language_processing/useSpeechToText.js +15 -11
  114. package/lib/module/hooks/natural_language_processing/useSpeechToText.js.map +1 -1
  115. package/lib/module/hooks/natural_language_processing/useTextEmbeddings.js +14 -0
  116. package/lib/module/hooks/natural_language_processing/useTextEmbeddings.js.map +1 -0
  117. package/lib/module/hooks/natural_language_processing/useTokenizer.js +54 -0
  118. package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -0
  119. package/lib/module/hooks/useModule.js +18 -62
  120. package/lib/module/hooks/useModule.js.map +1 -1
  121. package/lib/module/index.js +16 -2
  122. package/lib/module/index.js.map +1 -1
  123. package/lib/module/modules/BaseModule.js +9 -10
  124. package/lib/module/modules/BaseModule.js.map +1 -1
  125. package/lib/module/modules/computer_vision/ClassificationModule.js +8 -5
  126. package/lib/module/modules/computer_vision/ClassificationModule.js.map +1 -1
  127. package/lib/module/modules/computer_vision/ImageSegmentationModule.js +28 -0
  128. package/lib/module/modules/computer_vision/ImageSegmentationModule.js.map +1 -0
  129. package/lib/module/modules/computer_vision/ObjectDetectionModule.js +8 -5
  130. package/lib/module/modules/computer_vision/ObjectDetectionModule.js.map +1 -1
  131. package/lib/module/modules/computer_vision/StyleTransferModule.js +8 -5
  132. package/lib/module/modules/computer_vision/StyleTransferModule.js.map +1 -1
  133. package/lib/module/modules/general/ExecutorchModule.js +8 -5
  134. package/lib/module/modules/general/ExecutorchModule.js.map +1 -1
  135. package/lib/module/modules/natural_language_processing/LLMModule.js +46 -27
  136. package/lib/module/modules/natural_language_processing/LLMModule.js.map +1 -1
  137. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +8 -5
  138. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
  139. package/lib/module/modules/natural_language_processing/TextEmbeddingsModule.js +14 -0
  140. package/lib/module/modules/natural_language_processing/TextEmbeddingsModule.js.map +1 -0
  141. package/lib/module/modules/natural_language_processing/TokenizerModule.js +26 -0
  142. package/lib/module/modules/natural_language_processing/TokenizerModule.js.map +1 -0
  143. package/lib/module/native/NativeClassification.js.map +1 -1
  144. package/lib/module/native/NativeImageSegmentation.js +5 -0
  145. package/lib/module/native/NativeImageSegmentation.js.map +1 -0
  146. package/lib/module/native/NativeLLM.js.map +1 -1
  147. package/lib/module/native/NativeTextEmbeddings.js +5 -0
  148. package/lib/module/native/NativeTextEmbeddings.js.map +1 -0
  149. package/lib/module/native/NativeTokenizer.js +5 -0
  150. package/lib/module/native/NativeTokenizer.js.map +1 -0
  151. package/lib/module/native/RnExecutorchModules.js +18 -113
  152. package/lib/module/native/RnExecutorchModules.js.map +1 -1
  153. package/lib/module/types/common.js.map +1 -1
  154. package/lib/module/types/imageSegmentation.js +29 -0
  155. package/lib/module/types/imageSegmentation.js.map +1 -0
  156. package/lib/module/types/llm.js +7 -0
  157. package/lib/module/types/llm.js.map +1 -0
  158. package/lib/module/types/{object_detection.js → objectDetection.js} +1 -1
  159. package/lib/module/types/objectDetection.js.map +1 -0
  160. package/lib/module/types/ocr.js +2 -0
  161. package/lib/module/types/stt.js +82 -0
  162. package/lib/module/types/stt.js.map +1 -0
  163. package/lib/module/utils/ResourceFetcher.js +156 -0
  164. package/lib/module/utils/ResourceFetcher.js.map +1 -0
  165. package/lib/module/utils/llm.js +25 -0
  166. package/lib/module/utils/llm.js.map +1 -0
  167. package/lib/module/utils/stt.js +22 -0
  168. package/lib/module/utils/stt.js.map +1 -0
  169. package/lib/typescript/Error.d.ts +4 -1
  170. package/lib/typescript/Error.d.ts.map +1 -1
  171. package/lib/typescript/constants/{llamaDefaults.d.ts → llmDefaults.d.ts} +5 -5
  172. package/lib/typescript/constants/llmDefaults.d.ts.map +1 -0
  173. package/lib/typescript/constants/modelUrls.d.ts +74 -28
  174. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  175. package/lib/typescript/constants/ocr/models.d.ts +285 -0
  176. package/lib/typescript/constants/ocr/models.d.ts.map +1 -0
  177. package/lib/typescript/constants/ocr/symbols.d.ts +73 -1
  178. package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
  179. package/lib/typescript/constants/sttDefaults.d.ts +8 -13
  180. package/lib/typescript/constants/sttDefaults.d.ts.map +1 -1
  181. package/lib/typescript/controllers/LLMController.d.ts +46 -0
  182. package/lib/typescript/controllers/LLMController.d.ts.map +1 -0
  183. package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
  184. package/lib/typescript/controllers/SpeechToTextController.d.ts +30 -16
  185. package/lib/typescript/controllers/SpeechToTextController.d.ts.map +1 -1
  186. package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
  187. package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
  188. package/lib/typescript/hooks/computer_vision/useClassification.d.ts +5 -5
  189. package/lib/typescript/hooks/computer_vision/useClassification.d.ts.map +1 -1
  190. package/lib/typescript/hooks/computer_vision/useImageSegmentation.d.ts +37 -0
  191. package/lib/typescript/hooks/computer_vision/useImageSegmentation.d.ts.map +1 -0
  192. package/lib/typescript/hooks/computer_vision/useOCR.d.ts +2 -1
  193. package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
  194. package/lib/typescript/hooks/computer_vision/useObjectDetection.d.ts +5 -4
  195. package/lib/typescript/hooks/computer_vision/useObjectDetection.d.ts.map +1 -1
  196. package/lib/typescript/hooks/computer_vision/useStyleTransfer.d.ts +4 -2
  197. package/lib/typescript/hooks/computer_vision/useStyleTransfer.d.ts.map +1 -1
  198. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +2 -1
  199. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
  200. package/lib/typescript/hooks/general/useExecutorchModule.d.ts +5 -6
  201. package/lib/typescript/hooks/general/useExecutorchModule.d.ts.map +1 -1
  202. package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts +6 -6
  203. package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
  204. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +7 -3
  205. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts.map +1 -1
  206. package/lib/typescript/hooks/natural_language_processing/useTextEmbeddings.d.ts +13 -0
  207. package/lib/typescript/hooks/natural_language_processing/useTextEmbeddings.d.ts.map +1 -0
  208. package/lib/typescript/hooks/natural_language_processing/useTokenizer.d.ts +16 -0
  209. package/lib/typescript/hooks/natural_language_processing/useTokenizer.d.ts.map +1 -0
  210. package/lib/typescript/hooks/useModule.d.ts +11 -10
  211. package/lib/typescript/hooks/useModule.d.ts.map +1 -1
  212. package/lib/typescript/index.d.ts +15 -2
  213. package/lib/typescript/index.d.ts.map +1 -1
  214. package/lib/typescript/modules/BaseModule.d.ts +4 -5
  215. package/lib/typescript/modules/BaseModule.d.ts.map +1 -1
  216. package/lib/typescript/modules/computer_vision/ClassificationModule.d.ts +7 -7
  217. package/lib/typescript/modules/computer_vision/ClassificationModule.d.ts.map +1 -1
  218. package/lib/typescript/modules/computer_vision/ImageSegmentationModule.d.ts +32 -0
  219. package/lib/typescript/modules/computer_vision/ImageSegmentationModule.d.ts.map +1 -0
  220. package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
  221. package/lib/typescript/modules/computer_vision/ObjectDetectionModule.d.ts +6 -5
  222. package/lib/typescript/modules/computer_vision/ObjectDetectionModule.d.ts.map +1 -1
  223. package/lib/typescript/modules/computer_vision/StyleTransferModule.d.ts +6 -5
  224. package/lib/typescript/modules/computer_vision/StyleTransferModule.d.ts.map +1 -1
  225. package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
  226. package/lib/typescript/modules/general/ExecutorchModule.d.ts +4 -3
  227. package/lib/typescript/modules/general/ExecutorchModule.d.ts.map +1 -1
  228. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts +19 -5
  229. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts.map +1 -1
  230. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +7 -4
  231. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
  232. package/lib/typescript/modules/natural_language_processing/TextEmbeddingsModule.d.ts +8 -0
  233. package/lib/typescript/modules/natural_language_processing/TextEmbeddingsModule.d.ts.map +1 -0
  234. package/lib/typescript/modules/natural_language_processing/TokenizerModule.d.ts +12 -0
  235. package/lib/typescript/modules/natural_language_processing/TokenizerModule.d.ts.map +1 -0
  236. package/lib/typescript/native/NativeClassification.d.ts.map +1 -1
  237. package/lib/typescript/native/NativeImageSegmentation.d.ts +10 -0
  238. package/lib/typescript/native/NativeImageSegmentation.d.ts.map +1 -0
  239. package/lib/typescript/native/NativeLLM.d.ts +3 -4
  240. package/lib/typescript/native/NativeLLM.d.ts.map +1 -1
  241. package/lib/typescript/native/NativeObjectDetection.d.ts +1 -1
  242. package/lib/typescript/native/NativeObjectDetection.d.ts.map +1 -1
  243. package/lib/typescript/native/NativeSpeechToText.d.ts +2 -2
  244. package/lib/typescript/native/NativeSpeechToText.d.ts.map +1 -1
  245. package/lib/typescript/native/NativeTextEmbeddings.d.ts +8 -0
  246. package/lib/typescript/native/NativeTextEmbeddings.d.ts.map +1 -0
  247. package/lib/typescript/native/NativeTokenizer.d.ts +12 -0
  248. package/lib/typescript/native/NativeTokenizer.d.ts.map +1 -0
  249. package/lib/typescript/native/RnExecutorchModules.d.ts +18 -41
  250. package/lib/typescript/native/RnExecutorchModules.d.ts.map +1 -1
  251. package/lib/typescript/types/common.d.ts +1 -26
  252. package/lib/typescript/types/common.d.ts.map +1 -1
  253. package/lib/typescript/types/imageSegmentation.d.ts +25 -0
  254. package/lib/typescript/types/imageSegmentation.d.ts.map +1 -0
  255. package/lib/typescript/types/llm.d.ts +38 -0
  256. package/lib/typescript/types/llm.d.ts.map +1 -0
  257. package/lib/typescript/types/{object_detection.d.ts → objectDetection.d.ts} +1 -1
  258. package/lib/typescript/types/objectDetection.d.ts.map +1 -0
  259. package/lib/typescript/types/ocr.d.ts +2 -1
  260. package/lib/typescript/types/ocr.d.ts.map +1 -1
  261. package/lib/typescript/types/stt.d.ts +91 -0
  262. package/lib/typescript/types/stt.d.ts.map +1 -0
  263. package/lib/typescript/utils/ResourceFetcher.d.ts +17 -0
  264. package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -0
  265. package/lib/typescript/utils/llm.d.ts +3 -0
  266. package/lib/typescript/utils/llm.d.ts.map +1 -0
  267. package/lib/typescript/utils/stt.d.ts +2 -0
  268. package/lib/typescript/utils/stt.d.ts.map +1 -0
  269. package/package.json +12 -48
  270. package/react-native-executorch.podspec +1 -1
  271. package/src/Error.ts +16 -3
  272. package/src/constants/llmDefaults.ts +14 -0
  273. package/src/constants/modelUrls.ts +146 -39
  274. package/src/constants/ocr/models.ts +453 -0
  275. package/src/constants/ocr/symbols.ts +147 -3
  276. package/src/constants/sttDefaults.ts +55 -37
  277. package/src/controllers/LLMController.ts +286 -0
  278. package/src/controllers/OCRController.ts +14 -28
  279. package/src/controllers/SpeechToTextController.ts +318 -180
  280. package/src/controllers/VerticalOCRController.ts +17 -32
  281. package/src/hooks/computer_vision/useClassification.ts +11 -26
  282. package/src/hooks/computer_vision/useImageSegmentation.ts +18 -0
  283. package/src/hooks/computer_vision/useOCR.ts +17 -5
  284. package/src/hooks/computer_vision/useObjectDetection.ts +10 -24
  285. package/src/hooks/computer_vision/useStyleTransfer.ts +9 -25
  286. package/src/hooks/computer_vision/useVerticalOCR.ts +11 -4
  287. package/src/hooks/general/useExecutorchModule.ts +10 -50
  288. package/src/hooks/natural_language_processing/useLLM.ts +80 -97
  289. package/src/hooks/natural_language_processing/useSpeechToText.ts +39 -12
  290. package/src/hooks/natural_language_processing/useTextEmbeddings.ts +18 -0
  291. package/src/hooks/natural_language_processing/useTokenizer.ts +61 -0
  292. package/src/hooks/useModule.ts +32 -92
  293. package/src/index.tsx +16 -2
  294. package/src/modules/BaseModule.ts +16 -26
  295. package/src/modules/computer_vision/ClassificationModule.ts +13 -8
  296. package/src/modules/computer_vision/ImageSegmentationModule.ts +39 -0
  297. package/src/modules/computer_vision/ObjectDetectionModule.ts +13 -8
  298. package/src/modules/computer_vision/StyleTransferModule.ts +13 -8
  299. package/src/modules/general/ExecutorchModule.ts +11 -6
  300. package/src/modules/natural_language_processing/LLMModule.ts +64 -51
  301. package/src/modules/natural_language_processing/SpeechToTextModule.ts +25 -10
  302. package/src/modules/natural_language_processing/TextEmbeddingsModule.ts +18 -0
  303. package/src/modules/natural_language_processing/TokenizerModule.ts +34 -0
  304. package/src/native/NativeClassification.ts +0 -1
  305. package/src/native/NativeImageSegmentation.ts +14 -0
  306. package/src/native/NativeLLM.ts +3 -10
  307. package/src/native/NativeObjectDetection.ts +1 -1
  308. package/src/native/NativeSpeechToText.ts +2 -2
  309. package/src/native/NativeTextEmbeddings.ts +9 -0
  310. package/src/native/NativeTokenizer.ts +13 -0
  311. package/src/native/RnExecutorchModules.ts +54 -234
  312. package/src/types/common.ts +1 -44
  313. package/src/types/imageSegmentation.ts +25 -0
  314. package/src/types/llm.ts +57 -0
  315. package/src/types/ocr.ts +3 -1
  316. package/src/types/stt.ts +93 -0
  317. package/src/utils/ResourceFetcher.ts +196 -0
  318. package/src/utils/llm.ts +34 -0
  319. package/src/utils/stt.ts +28 -0
  320. package/android/src/main/java/com/swmansion/rnexecutorch/utils/llms/ConversationManager.kt +0 -68
  321. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/_CodeSignature/CodeResources +0 -124
  322. package/ios/RnExecutorch/utils/llms/Constants.h +0 -6
  323. package/ios/RnExecutorch/utils/llms/Constants.mm +0 -23
  324. package/ios/RnExecutorch/utils/llms/ConversationManager.h +0 -26
  325. package/ios/RnExecutorch/utils/llms/ConversationManager.mm +0 -71
  326. package/lib/module/constants/llamaDefaults.js.map +0 -1
  327. package/lib/module/modules/computer_vision/BaseCVModule.js +0 -14
  328. package/lib/module/modules/computer_vision/BaseCVModule.js.map +0 -1
  329. package/lib/module/types/object_detection.js.map +0 -1
  330. package/lib/module/utils/fetchResource.js +0 -93
  331. package/lib/module/utils/fetchResource.js.map +0 -1
  332. package/lib/module/utils/listDownloadedResources.js +0 -13
  333. package/lib/module/utils/listDownloadedResources.js.map +0 -1
  334. package/lib/typescript/constants/llamaDefaults.d.ts.map +0 -1
  335. package/lib/typescript/modules/computer_vision/BaseCVModule.d.ts +0 -9
  336. package/lib/typescript/modules/computer_vision/BaseCVModule.d.ts.map +0 -1
  337. package/lib/typescript/types/object_detection.d.ts.map +0 -1
  338. package/lib/typescript/utils/fetchResource.d.ts +0 -3
  339. package/lib/typescript/utils/fetchResource.d.ts.map +0 -1
  340. package/lib/typescript/utils/listDownloadedResources.d.ts +0 -3
  341. package/lib/typescript/utils/listDownloadedResources.d.ts.map +0 -1
  342. package/src/constants/llamaDefaults.ts +0 -9
  343. package/src/modules/computer_vision/BaseCVModule.ts +0 -22
  344. package/src/utils/fetchResource.ts +0 -106
  345. package/src/utils/listDownloadedResources.ts +0 -12
  346. /package/src/types/{object_detection.ts → objectDetection.ts} +0 -0
@@ -1,6 +1,8 @@
1
- import { useEffect, useState } from 'react';
1
+ import { useEffect, useMemo, useState } from 'react';
2
2
  import { SpeechToTextController } from '../../controllers/SpeechToTextController';
3
3
  import { ResourceSource } from '../../types/common';
4
+ import { STREAMING_ACTION } from '../../constants/sttDefaults';
5
+ import { AvailableModels, SpeechToTextLanguage } from '../../types/stt';
4
6
 
5
7
  interface SpeechToTextModule {
6
8
  isReady: boolean;
@@ -10,8 +12,14 @@ interface SpeechToTextModule {
10
12
  configureStreaming: SpeechToTextController['configureStreaming'];
11
13
  error: Error | undefined;
12
14
  transcribe: (
13
- input: number[]
15
+ input: number[],
16
+ audioLanguage?: SpeechToTextLanguage
14
17
  ) => ReturnType<SpeechToTextController['transcribe']>;
18
+ streamingTranscribe: (
19
+ streamAction: STREAMING_ACTION,
20
+ input?: number[],
21
+ audioLanguage?: SpeechToTextLanguage
22
+ ) => ReturnType<SpeechToTextController['streamingTranscribe']>;
15
23
  }
16
24
 
17
25
  export const useSpeechToText = ({
@@ -22,8 +30,9 @@ export const useSpeechToText = ({
22
30
  overlapSeconds,
23
31
  windowSize,
24
32
  streamingConfig,
33
+ preventLoad = false,
25
34
  }: {
26
- modelName: 'moonshine' | 'whisper';
35
+ modelName: AvailableModels;
27
36
  encoderSource?: ResourceSource;
28
37
  decoderSource?: ResourceSource;
29
38
  tokenizerSource?: ResourceSource;
@@ -36,6 +45,7 @@ export const useSpeechToText = ({
36
45
  streamingConfig?: ConstructorParameters<
37
46
  typeof SpeechToTextController
38
47
  >['0']['streamingConfig'];
48
+ preventLoad?: boolean;
39
49
  }): SpeechToTextModule => {
40
50
  const [sequence, setSequence] = useState<string>('');
41
51
  const [isReady, setIsReady] = useState(false);
@@ -43,20 +53,22 @@ export const useSpeechToText = ({
43
53
  const [isGenerating, setIsGenerating] = useState(false);
44
54
  const [error, setError] = useState<Error | undefined>();
45
55
 
46
- const [model, _] = useState(
56
+ const model = useMemo(
47
57
  () =>
48
58
  new SpeechToTextController({
49
59
  transcribeCallback: setSequence,
50
60
  isReadyCallback: setIsReady,
51
61
  isGeneratingCallback: setIsGenerating,
52
62
  onErrorCallback: setError,
53
- modelDownloadProgessCallback: setDownloadProgress,
54
- overlapSeconds: overlapSeconds,
55
- windowSize: windowSize,
56
- streamingConfig: streamingConfig,
57
- })
63
+ modelDownloadProgressCallback: setDownloadProgress,
64
+ }),
65
+ []
58
66
  );
59
67
 
68
+ useEffect(() => {
69
+ model.configureStreaming(overlapSeconds, windowSize, streamingConfig);
70
+ }, [model, overlapSeconds, windowSize, streamingConfig]);
71
+
60
72
  useEffect(() => {
61
73
  const loadModel = async () => {
62
74
  await model.loadModel(
@@ -66,8 +78,17 @@ export const useSpeechToText = ({
66
78
  tokenizerSource
67
79
  );
68
80
  };
69
- loadModel();
70
- }, [model, modelName, encoderSource, decoderSource, tokenizerSource]);
81
+ if (!preventLoad) {
82
+ loadModel();
83
+ }
84
+ }, [
85
+ model,
86
+ modelName,
87
+ encoderSource,
88
+ decoderSource,
89
+ tokenizerSource,
90
+ preventLoad,
91
+ ]);
71
92
 
72
93
  return {
73
94
  isReady,
@@ -76,6 +97,12 @@ export const useSpeechToText = ({
76
97
  configureStreaming: model.configureStreaming,
77
98
  sequence,
78
99
  error,
79
- transcribe: (waveform: number[]) => model.transcribe(waveform),
100
+ transcribe: (waveform: number[], audioLanguage?: SpeechToTextLanguage) =>
101
+ model.transcribe(waveform, audioLanguage),
102
+ streamingTranscribe: (
103
+ streamAction: STREAMING_ACTION,
104
+ waveform?: number[],
105
+ audioLanguage?: SpeechToTextLanguage
106
+ ) => model.streamingTranscribe(streamAction, waveform, audioLanguage),
80
107
  };
81
108
  };
@@ -0,0 +1,18 @@
1
+ import { TextEmbeddingsModule } from '../../modules/natural_language_processing/TextEmbeddingsModule';
2
+ import { ResourceSource } from '../../types/common';
3
+ import { useModule } from '../useModule';
4
+
5
+ export const useTextEmbeddings = ({
6
+ modelSource,
7
+ tokenizerSource,
8
+ preventLoad = false,
9
+ }: {
10
+ modelSource: ResourceSource;
11
+ tokenizerSource: ResourceSource;
12
+ preventLoad?: boolean;
13
+ }) =>
14
+ useModule({
15
+ module: TextEmbeddingsModule,
16
+ loadArgs: [modelSource, tokenizerSource],
17
+ preventLoad,
18
+ });
@@ -0,0 +1,61 @@
1
+ import { useEffect, useState } from 'react';
2
+ import { TokenizerModule } from '../../modules/natural_language_processing/TokenizerModule';
3
+ import { ResourceSource } from '../../types/common';
4
+ import { ETError, getError } from '../../Error';
5
+
6
+ export const useTokenizer = ({
7
+ tokenizerSource,
8
+ preventLoad = false,
9
+ }: {
10
+ tokenizerSource: ResourceSource;
11
+ preventLoad?: boolean;
12
+ }) => {
13
+ const [error, setError] = useState<null | string>(null);
14
+ const [isReady, setIsReady] = useState(false);
15
+ const [isGenerating, setIsGenerating] = useState(false);
16
+ const [downloadProgress, setDownloadProgress] = useState(0);
17
+
18
+ useEffect(() => {
19
+ const loadModule = async () => {
20
+ try {
21
+ setIsReady(false);
22
+ TokenizerModule.onDownloadProgress(setDownloadProgress);
23
+ await TokenizerModule.load(tokenizerSource);
24
+ setIsReady(true);
25
+ } catch (err) {
26
+ setError((err as Error).message);
27
+ }
28
+ };
29
+ if (!preventLoad) {
30
+ loadModule();
31
+ }
32
+ }, [tokenizerSource, preventLoad]);
33
+
34
+ const stateWrapper = <T extends (...args: any[]) => Promise<any>>(fn: T) => {
35
+ const boundFn = fn.bind(TokenizerModule);
36
+
37
+ return async (...args: Parameters<T>): Promise<ReturnType<T>> => {
38
+ if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
39
+ if (isGenerating) throw new Error(getError(ETError.ModelGenerating));
40
+
41
+ setIsGenerating(true);
42
+ try {
43
+ return await boundFn(...args);
44
+ } finally {
45
+ setIsGenerating(false);
46
+ }
47
+ };
48
+ };
49
+
50
+ return {
51
+ error,
52
+ isReady,
53
+ isGenerating,
54
+ downloadProgress,
55
+ decode: stateWrapper(TokenizerModule.decode),
56
+ encode: stateWrapper(TokenizerModule.encode),
57
+ getVocabSize: stateWrapper(TokenizerModule.getVocabSize),
58
+ idToToken: stateWrapper(TokenizerModule.idToToken),
59
+ tokenToId: stateWrapper(TokenizerModule.tokenToId),
60
+ };
61
+ };
@@ -1,124 +1,64 @@
1
1
  import { useEffect, useState } from 'react';
2
- import { fetchResource } from '../utils/fetchResource';
3
2
  import { ETError, getError } from '../Error';
4
- import { ETInput, Module } from '../types/common';
5
- import { _ETModule } from '../native/RnExecutorchModules';
6
- import { getTypeIdentifier } from '../types/common';
7
3
 
8
/**
 * Minimal static contract a module class must expose to be driven by
 * `useModule`.
 */
interface Module {
  load: (...args: any[]) => Promise<void>;
  forward: (...input: any[]) => Promise<any>;
  onDownloadProgress: (cb: (progress: number) => void) => void;
}

/**
 * Generic React hook that loads `module` and exposes a guarded `forward`.
 *
 * Unless `preventLoad` is set, `module.load(...loadArgs)` runs whenever the
 * module, any load argument, or `preventLoad` changes.
 *
 * @param module - Module object/class satisfying `Module`.
 * @param loadArgs - Arguments spread into `module.load`.
 * @param preventLoad - When `true`, no load is started. Defaults to `false`.
 * @returns `error` (message of the most recent failed load, or `null`),
 *   `isReady`, `isGenerating`, `downloadProgress`, and `forward`. `forward`
 *   throws if the module is not ready or another call is already running.
 */
export const useModule = <
  M extends Module,
  LoadArgs extends Parameters<M['load']>,
  ForwardArgs extends Parameters<M['forward']>,
  ForwardReturn extends Awaited<ReturnType<M['forward']>>,
>({
  module,
  loadArgs,
  preventLoad = false,
}: {
  module: M;
  loadArgs: LoadArgs;
  preventLoad?: boolean;
}) => {
  const [error, setError] = useState<null | string>(null);
  const [isReady, setIsReady] = useState(false);
  const [isGenerating, setIsGenerating] = useState(false);
  const [downloadProgress, setDownloadProgress] = useState(0);

  useEffect(() => {
    const loadModule = async () => {
      try {
        setIsReady(false);
        // Drop any error left by an earlier failed load; otherwise a later
        // successful load would report `isReady` and `error` simultaneously.
        setError(null);
        module.onDownloadProgress(setDownloadProgress);
        await module.load(...loadArgs);
        setIsReady(true);
      } catch (err) {
        // Non-Error throwables have no `.message`; stringify them instead of
        // storing `undefined`.
        setError(err instanceof Error ? err.message : String(err));
      }
    };
    if (!preventLoad) {
      // loadModule handles its own failures, so the promise is intentionally
      // not awaited here.
      void loadModule();
    }
    // Load arguments are spread so a change to any single argument reloads;
    // `module` is listed so that swapping the module also triggers a reload.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [module, ...loadArgs, preventLoad]);

  const forward = async (...input: ForwardArgs): Promise<ForwardReturn> => {
    if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
    if (isGenerating) throw new Error(getError(ETError.ModelGenerating));
    try {
      setIsGenerating(true);
      return await module.forward(...input);
    } finally {
      setIsGenerating(false);
    }
  };

  return {
    error,
    isReady,
    isGenerating,
    downloadProgress,
    forward,
  };
};
package/src/index.tsx CHANGED
@@ -1,12 +1,17 @@
1
+ import { SpeechToTextLanguage } from './types/stt';
2
+
1
3
  // hooks
2
4
  export * from './hooks/computer_vision/useClassification';
3
5
  export * from './hooks/computer_vision/useObjectDetection';
4
6
  export * from './hooks/computer_vision/useStyleTransfer';
7
+ export * from './hooks/computer_vision/useImageSegmentation';
5
8
  export * from './hooks/computer_vision/useOCR';
6
9
  export * from './hooks/computer_vision/useVerticalOCR';
7
10
 
8
11
  export * from './hooks/natural_language_processing/useLLM';
9
12
  export * from './hooks/natural_language_processing/useSpeechToText';
13
+ export * from './hooks/natural_language_processing/useTextEmbeddings';
14
+ export * from './hooks/natural_language_processing/useTokenizer';
10
15
 
11
16
  export * from './hooks/general/useExecutorchModule';
12
17
 
@@ -14,20 +19,29 @@ export * from './hooks/general/useExecutorchModule';
14
19
  export * from './modules/computer_vision/ClassificationModule';
15
20
  export * from './modules/computer_vision/ObjectDetectionModule';
16
21
  export * from './modules/computer_vision/StyleTransferModule';
22
+ export * from './modules/computer_vision/ImageSegmentationModule';
17
23
  export * from './modules/computer_vision/OCRModule';
18
24
  export * from './modules/computer_vision/VerticalOCRModule';
19
25
 
20
26
  export * from './modules/natural_language_processing/LLMModule';
21
27
  export * from './modules/natural_language_processing/SpeechToTextModule';
28
+ export * from './modules/natural_language_processing/TextEmbeddingsModule';
29
+ export * from './modules/natural_language_processing/TokenizerModule';
22
30
 
23
31
  export * from './modules/general/ExecutorchModule';
24
32
 
25
33
  // utils
26
- export * from './utils/listDownloadedResources';
34
+ export * from './utils/ResourceFetcher';
27
35
 
28
36
  // types
29
- export * from './types/object_detection';
37
+ export * from './types/objectDetection';
30
38
  export * from './types/ocr';
39
+ export * from './types/imageSegmentation';
40
+ export * from './types/llm';
41
+ export { SpeechToTextLanguage };
31
42
 
32
43
  // constants
33
44
  export * from './constants/modelUrls';
45
+ export * from './constants/ocr/models';
46
+ export * from './constants/llmDefaults';
47
+ export { STREAMING_ACTION, MODES } from './constants/sttDefaults';
@@ -1,38 +1,28 @@
1
- import {
2
- _StyleTransferModule,
3
- _ObjectDetectionModule,
4
- _ClassificationModule,
5
- _ETModule,
6
- } from '../native/RnExecutorchModules';
7
- import { fetchResource } from '../utils/fetchResource';
8
- import { ResourceSource } from '../types/common';
1
+ import { ResourceFetcher } from '../utils/ResourceFetcher';
9
2
  import { getError } from '../Error';
3
+ import { ResourceSource } from '../types/common';
10
4
 
11
5
  export class BaseModule {
12
- static module:
13
- | _StyleTransferModule
14
- | _ObjectDetectionModule
15
- | _ClassificationModule
16
- | _ETModule;
17
-
18
- static onDownloadProgressCallback = (_downloadProgress: number) => {};
19
-
20
- static async load(modelSource: ResourceSource) {
21
- if (!modelSource) return;
6
+ protected static nativeModule: any;
7
+ static onDownloadProgressCallback: (downloadProgress: number) => void =
8
+ () => {};
22
9
 
10
+ static async load(...sources: ResourceSource[]): Promise<void> {
23
11
  try {
24
- const fileUri = await fetchResource(
25
- modelSource,
26
- this.onDownloadProgressCallback
12
+ const paths = await ResourceFetcher.fetchMultipleResources(
13
+ this.onDownloadProgressCallback,
14
+ ...sources
27
15
  );
28
- await this.module.loadModule(fileUri);
29
- } catch (e) {
30
- throw new Error(getError(e));
16
+ await this.nativeModule.loadModule(...paths);
17
+ } catch (error) {
18
+ throw new Error(getError(error));
31
19
  }
32
20
  }
33
21
 
34
- static async forward(..._: any[]): Promise<any> {
35
- throw new Error('The forward method is not implemented.');
22
+ protected static async forward(..._args: any[]): Promise<any> {
23
+ throw new Error(
24
+ 'forward method is not implemented in the BaseModule class. Please implement it in the derived class.'
25
+ );
36
26
  }
37
27
 
38
28
  static onDownloadProgress(callback: (downloadProgress: number) => void) {
@@ -1,12 +1,17 @@
1
- import { BaseCVModule } from './BaseCVModule';
2
- import { _ClassificationModule } from '../../native/RnExecutorchModules';
1
+ import { ClassificationNativeModule } from '../../native/RnExecutorchModules';
2
+ import { ResourceSource } from '../../types/common';
3
+ import { BaseModule } from '../BaseModule';
3
4
 
4
/**
 * Image classification model driven through `ClassificationNativeModule`.
 *
 * All members are static: `load` the model once, then call `forward`.
 */
export class ClassificationModule extends BaseModule {
  protected static override nativeModule = ClassificationNativeModule;

  /**
   * Loads the model through `BaseModule.load`.
   *
   * `BaseModule.load` fetches the resource and passes the resulting path to
   * the native `loadModule`.
   *
   * @param modelSource - Where to obtain the classification model from.
   */
  static override async load(modelSource: ResourceSource) {
    return await super.load(modelSource);
  }

  /**
   * Runs the native classifier on `input`.
   *
   * @param input - Passed unchanged to the native `forward`.
   * @returns Whatever the native `forward` resolves to.
   */
  static override async forward(
    input: string
  ): ReturnType<typeof ClassificationNativeModule.forward> {
    const { nativeModule } = this;
    const output = await nativeModule.forward(input);
    return output;
  }
}
@@ -0,0 +1,39 @@
1
+ import { BaseModule } from '../BaseModule';
2
+ import { getError } from '../../Error';
3
+ import { DeeplabLabel } from '../../types/imageSegmentation';
4
+ import { ResourceSource } from '../../types/common';
5
+ import { ImageSegmentationNativeModule } from '../../native/RnExecutorchModules';
6
+
7
/**
 * Semantic segmentation model driven through `ImageSegmentationNativeModule`.
 *
 * The native side exchanges labels by *name*. `forward` translates
 * `DeeplabLabel` values to names on the way in and names back to enum values
 * on the way out.
 */
export class ImageSegmentationModule extends BaseModule {
  protected static override nativeModule = ImageSegmentationNativeModule;

  /**
   * Loads the model through `BaseModule.load`.
   *
   * @param modelSource - Where to obtain the segmentation model from.
   */
  static override async load(modelSource: ResourceSource) {
    return await super.load(modelSource);
  }

  /**
   * Runs segmentation on `input`.
   *
   * @param input - Passed unchanged to the native `forward`.
   * @param classesOfInterest - Labels to request, sent to the native side as
   *   label names. Defaults to an empty list.
   * @param resize - Forwarded to the native `forward`. Defaults to `false`.
   * @returns An object keyed by `DeeplabLabel` holding each label's `number[]`
   *   from the native result. Native keys that are not label names are dropped.
   * @throws Error whose message comes from `getError` when the native call
   *   fails.
   */
  static override async forward(
    input: string,
    classesOfInterest?: DeeplabLabel[],
    resize?: boolean
  ) {
    try {
      const stringDict = await (this.nativeModule.forward(
        input,
        // Reverse-map enum values to their names, which the native side expects.
        (classesOfInterest ?? []).map((label) => DeeplabLabel[label]),
        resize ?? false
      ) as ReturnType<(typeof this.nativeModule)['forward']>);

      const enumDict: { [key in DeeplabLabel]?: number[] } = {};

      for (const key in stringDict) {
        // Resolve the key as a label name and keep it only if it yields a
        // numeric enum value. A bare `key in DeeplabLabel` check would also
        // accept the numeric enum's reverse-mapping keys ("0", "1", ...),
        // which resolve to strings. It would likewise accept inherited members
        // such as `toString`, which resolve to functions.
        const enumKey = DeeplabLabel[key as keyof typeof DeeplabLabel];
        if (typeof enumKey === 'number') {
          enumDict[enumKey] = stringDict[key];
        }
      }
      return enumDict;
    } catch (e) {
      throw new Error(getError(e));
    }
  }
}
@@ -1,12 +1,17 @@
1
- import { BaseCVModule } from './BaseCVModule';
2
- import { _ObjectDetectionModule } from '../../native/RnExecutorchModules';
1
+ import { ObjectDetectionNativeModule } from '../../native/RnExecutorchModules';
2
+ import { ResourceSource } from '../../types/common';
3
+ import { BaseModule } from '../BaseModule';
3
4
 
4
/**
 * Object detection model driven through `ObjectDetectionNativeModule`.
 *
 * All members are static: `load` the model once, then call `forward`.
 */
export class ObjectDetectionModule extends BaseModule {
  protected static override nativeModule = ObjectDetectionNativeModule;

  /**
   * Loads the model through `BaseModule.load`.
   *
   * `BaseModule.load` fetches the resource and passes the resulting path to
   * the native `loadModule`.
   *
   * @param modelSource - Where to obtain the detection model from.
   */
  static override async load(modelSource: ResourceSource) {
    await super.load(modelSource);
  }

  /**
   * Runs the native detector on `input`.
   *
   * @param input - Passed unchanged to the native `forward`.
   * @returns Whatever the native `forward` resolves to.
   */
  static override async forward(
    input: string
  ): ReturnType<typeof this.nativeModule.forward> {
    const { nativeModule } = this;
    const detections = await nativeModule.forward(input);
    return detections;
  }
}
@@ -1,12 +1,17 @@
1
- import { BaseCVModule } from './BaseCVModule';
2
- import { _StyleTransferModule } from '../../native/RnExecutorchModules';
1
+ import { StyleTransferNativeModule } from '../../native/RnExecutorchModules';
2
+ import { ResourceSource } from '../../types/common';
3
+ import { BaseModule } from '../BaseModule';
3
4
 
4
/**
 * Style transfer model driven through `StyleTransferNativeModule`.
 *
 * All members are static: `load` the model once, then call `forward`.
 */
export class StyleTransferModule extends BaseModule {
  protected static override nativeModule = StyleTransferNativeModule;

  /**
   * Loads the model through `BaseModule.load`.
   *
   * `BaseModule.load` fetches the resource and passes the resulting path to
   * the native `loadModule`.
   *
   * @param modelSource - Where to obtain the style transfer model from.
   */
  static override async load(modelSource: ResourceSource) {
    await super.load(modelSource);
  }

  /**
   * Applies the loaded style transfer model to `input` on the native side.
   *
   * @param input - Passed unchanged to the native `forward`.
   * @returns Whatever the native `forward` resolves to.
   */
  static override async forward(
    input: string
  ): ReturnType<typeof this.nativeModule.forward> {
    const { nativeModule } = this;
    const stylized = await nativeModule.forward(input);
    return stylized;
  }
}
@@ -1,13 +1,18 @@
1
- import { BaseModule } from '../BaseModule';
2
1
  import { ETError, getError } from '../../Error';
3
- import { _ETModule } from '../../native/RnExecutorchModules';
2
+ import { ETModuleNativeModule } from '../../native/RnExecutorchModules';
3
+ import { ResourceSource } from '../../types/common';
4
4
  import { ETInput } from '../../types/common';
5
5
  import { getTypeIdentifier } from '../../types/common';
6
+ import { BaseModule } from '../BaseModule';
6
7
 
7
8
  export class ExecutorchModule extends BaseModule {
8
- static module = new _ETModule();
9
+ protected static override nativeModule = ETModuleNativeModule;
10
+
11
+ static override async load(modelSource: ResourceSource) {
12
+ return await super.load(modelSource);
13
+ }
9
14
 
10
- static async forward(input: ETInput[] | ETInput, shape: number[][]) {
15
+ static override async forward(input: ETInput[] | ETInput, shape: number[][]) {
11
16
  if (!Array.isArray(input)) {
12
17
  input = [input];
13
18
  }
@@ -25,7 +30,7 @@ export class ExecutorchModule extends BaseModule {
25
30
  }
26
31
 
27
32
  try {
28
- return await this.module.forward(
33
+ return await this.nativeModule.forward(
29
34
  modelInputs,
30
35
  shape,
31
36
  inputTypeIdentifiers
@@ -37,7 +42,7 @@ export class ExecutorchModule extends BaseModule {
37
42
 
38
43
  static async loadMethod(methodName: string) {
39
44
  try {
40
- await this.module.loadMethod(methodName);
45
+ await this.nativeModule.loadMethod(methodName);
41
46
  } catch (e) {
42
47
  throw new Error(getError(e));
43
48
  }