react-native-executorch 0.3.3 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (344) hide show
  1. package/README.md +30 -13
  2. package/android/build.gradle +1 -1
  3. package/android/libs/executorch.aar +0 -0
  4. package/android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt +1 -2
  5. package/android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt +58 -0
  6. package/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt +13 -49
  7. package/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +37 -0
  8. package/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt +1 -1
  9. package/android/src/main/java/com/swmansion/rnexecutorch/TextEmbeddings.kt +51 -0
  10. package/android/src/main/java/com/swmansion/rnexecutorch/Tokenizer.kt +86 -0
  11. package/android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt +3 -4
  12. package/android/src/main/java/com/swmansion/rnexecutorch/models/TextEmbeddings/TextEmbeddingsModel.kt +48 -0
  13. package/android/src/main/java/com/swmansion/rnexecutorch/models/TextEmbeddings/TextEmbeddingsUtils.kt +37 -0
  14. package/android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt +1 -0
  15. package/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt +26 -0
  16. package/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt +142 -0
  17. package/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt +3 -0
  18. package/android/src/main/java/com/swmansion/rnexecutorch/models/{StyleTransferModel.kt → styleTransfer/StyleTransferModel.kt} +2 -1
  19. package/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt +0 -8
  20. package/android/src/main/java/com/swmansion/rnexecutorch/{models/classification/Utils.kt → utils/Numerical.kt} +1 -1
  21. package/ios/ExecutorchLib.xcframework/Info.plist +4 -4
  22. package/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  23. package/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Headers/HuggingFaceTokenizer.h +14 -0
  24. package/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Headers/LLaMARunner.h +1 -23
  25. package/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  26. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  27. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Headers/HuggingFaceTokenizer.h +14 -0
  28. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Headers/LLaMARunner.h +1 -23
  29. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  30. package/ios/RnExecutorch/Classification.mm +15 -18
  31. package/ios/RnExecutorch/ETModule.mm +6 -5
  32. package/ios/RnExecutorch/ImageSegmentation.h +5 -0
  33. package/ios/RnExecutorch/ImageSegmentation.mm +60 -0
  34. package/ios/RnExecutorch/LLM.mm +12 -53
  35. package/ios/RnExecutorch/OCR.mm +39 -43
  36. package/ios/RnExecutorch/ObjectDetection.mm +20 -20
  37. package/ios/RnExecutorch/SpeechToText.mm +6 -7
  38. package/ios/RnExecutorch/StyleTransfer.mm +16 -19
  39. package/ios/RnExecutorch/TextEmbeddings.h +5 -0
  40. package/ios/RnExecutorch/TextEmbeddings.mm +62 -0
  41. package/ios/RnExecutorch/Tokenizer.h +5 -0
  42. package/ios/RnExecutorch/Tokenizer.mm +83 -0
  43. package/ios/RnExecutorch/VerticalOCR.mm +36 -36
  44. package/ios/RnExecutorch/models/BaseModel.h +2 -5
  45. package/ios/RnExecutorch/models/BaseModel.mm +5 -15
  46. package/ios/RnExecutorch/models/classification/ClassificationModel.mm +2 -3
  47. package/ios/RnExecutorch/models/classification/Constants.mm +0 -1
  48. package/ios/RnExecutorch/models/image_segmentation/Constants.h +4 -0
  49. package/ios/RnExecutorch/models/image_segmentation/Constants.mm +8 -0
  50. package/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.h +10 -0
  51. package/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm +146 -0
  52. package/ios/RnExecutorch/models/object_detection/SSDLiteLargeModel.mm +1 -2
  53. package/ios/RnExecutorch/models/ocr/Detector.h +0 -2
  54. package/ios/RnExecutorch/models/ocr/Detector.mm +2 -1
  55. package/ios/RnExecutorch/models/ocr/RecognitionHandler.h +5 -4
  56. package/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +9 -26
  57. package/ios/RnExecutorch/models/ocr/Recognizer.mm +1 -2
  58. package/ios/RnExecutorch/models/ocr/VerticalDetector.h +0 -2
  59. package/ios/RnExecutorch/models/ocr/VerticalDetector.mm +2 -1
  60. package/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm +0 -1
  61. package/ios/RnExecutorch/models/stt/Moonshine.mm +1 -6
  62. package/ios/RnExecutorch/models/stt/SpeechToTextBaseModel.mm +7 -11
  63. package/ios/RnExecutorch/models/stt/Whisper.mm +0 -5
  64. package/ios/RnExecutorch/models/{StyleTransferModel.h → style_transfer/StyleTransferModel.h} +1 -1
  65. package/ios/RnExecutorch/models/{StyleTransferModel.mm → style_transfer/StyleTransferModel.mm} +2 -3
  66. package/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsModel.h +15 -0
  67. package/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsModel.mm +45 -0
  68. package/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsUtils.h +8 -0
  69. package/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsUtils.mm +49 -0
  70. package/ios/RnExecutorch/utils/Conversions.h +15 -0
  71. package/ios/RnExecutorch/utils/ImageProcessor.h +0 -1
  72. package/ios/RnExecutorch/{models/classification/Utils.h → utils/Numerical.h} +0 -2
  73. package/ios/RnExecutorch/{models/classification/Utils.mm → utils/Numerical.mm} +0 -2
  74. package/ios/RnExecutorch/utils/ObjectDetectionUtils.mm +0 -2
  75. package/ios/RnExecutorch/utils/SFFT.mm +1 -1
  76. package/ios/RnExecutorch/utils/ScalarType.h +0 -2
  77. package/lib/module/Error.js +16 -2
  78. package/lib/module/Error.js.map +1 -1
  79. package/lib/module/constants/{llamaDefaults.js → llmDefaults.js} +7 -3
  80. package/lib/module/constants/llmDefaults.js.map +1 -0
  81. package/lib/module/constants/modelUrls.js +88 -27
  82. package/lib/module/constants/modelUrls.js.map +1 -1
  83. package/lib/module/constants/ocr/models.js +290 -0
  84. package/lib/module/constants/ocr/models.js.map +1 -0
  85. package/lib/module/constants/ocr/symbols.js +137 -2
  86. package/lib/module/constants/ocr/symbols.js.map +1 -1
  87. package/lib/module/constants/sttDefaults.js +50 -25
  88. package/lib/module/constants/sttDefaults.js.map +1 -1
  89. package/lib/module/controllers/LLMController.js +205 -0
  90. package/lib/module/controllers/LLMController.js.map +1 -0
  91. package/lib/module/controllers/OCRController.js +5 -10
  92. package/lib/module/controllers/OCRController.js.map +1 -1
  93. package/lib/module/controllers/SpeechToTextController.js +225 -122
  94. package/lib/module/controllers/SpeechToTextController.js.map +1 -1
  95. package/lib/module/controllers/VerticalOCRController.js +6 -10
  96. package/lib/module/controllers/VerticalOCRController.js.map +1 -1
  97. package/lib/module/hooks/computer_vision/useClassification.js +8 -23
  98. package/lib/module/hooks/computer_vision/useClassification.js.map +1 -1
  99. package/lib/module/hooks/computer_vision/useImageSegmentation.js +13 -0
  100. package/lib/module/hooks/computer_vision/useImageSegmentation.js.map +1 -0
  101. package/lib/module/hooks/computer_vision/useOCR.js +11 -6
  102. package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
  103. package/lib/module/hooks/computer_vision/useObjectDetection.js +8 -23
  104. package/lib/module/hooks/computer_vision/useObjectDetection.js.map +1 -1
  105. package/lib/module/hooks/computer_vision/useStyleTransfer.js +8 -23
  106. package/lib/module/hooks/computer_vision/useStyleTransfer.js.map +1 -1
  107. package/lib/module/hooks/computer_vision/useVerticalOCR.js +10 -7
  108. package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
  109. package/lib/module/hooks/general/useExecutorchModule.js +8 -36
  110. package/lib/module/hooks/general/useExecutorchModule.js.map +1 -1
  111. package/lib/module/hooks/natural_language_processing/useLLM.js +54 -63
  112. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  113. package/lib/module/hooks/natural_language_processing/useSpeechToText.js +15 -11
  114. package/lib/module/hooks/natural_language_processing/useSpeechToText.js.map +1 -1
  115. package/lib/module/hooks/natural_language_processing/useTextEmbeddings.js +14 -0
  116. package/lib/module/hooks/natural_language_processing/useTextEmbeddings.js.map +1 -0
  117. package/lib/module/hooks/natural_language_processing/useTokenizer.js +54 -0
  118. package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -0
  119. package/lib/module/hooks/useModule.js +18 -62
  120. package/lib/module/hooks/useModule.js.map +1 -1
  121. package/lib/module/index.js +16 -2
  122. package/lib/module/index.js.map +1 -1
  123. package/lib/module/modules/BaseModule.js +9 -10
  124. package/lib/module/modules/BaseModule.js.map +1 -1
  125. package/lib/module/modules/computer_vision/ClassificationModule.js +8 -5
  126. package/lib/module/modules/computer_vision/ClassificationModule.js.map +1 -1
  127. package/lib/module/modules/computer_vision/ImageSegmentationModule.js +28 -0
  128. package/lib/module/modules/computer_vision/ImageSegmentationModule.js.map +1 -0
  129. package/lib/module/modules/computer_vision/ObjectDetectionModule.js +8 -5
  130. package/lib/module/modules/computer_vision/ObjectDetectionModule.js.map +1 -1
  131. package/lib/module/modules/computer_vision/StyleTransferModule.js +8 -5
  132. package/lib/module/modules/computer_vision/StyleTransferModule.js.map +1 -1
  133. package/lib/module/modules/general/ExecutorchModule.js +8 -5
  134. package/lib/module/modules/general/ExecutorchModule.js.map +1 -1
  135. package/lib/module/modules/natural_language_processing/LLMModule.js +46 -27
  136. package/lib/module/modules/natural_language_processing/LLMModule.js.map +1 -1
  137. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +8 -5
  138. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
  139. package/lib/module/modules/natural_language_processing/TextEmbeddingsModule.js +14 -0
  140. package/lib/module/modules/natural_language_processing/TextEmbeddingsModule.js.map +1 -0
  141. package/lib/module/modules/natural_language_processing/TokenizerModule.js +26 -0
  142. package/lib/module/modules/natural_language_processing/TokenizerModule.js.map +1 -0
  143. package/lib/module/native/NativeClassification.js.map +1 -1
  144. package/lib/module/native/NativeImageSegmentation.js +5 -0
  145. package/lib/module/native/NativeImageSegmentation.js.map +1 -0
  146. package/lib/module/native/NativeLLM.js.map +1 -1
  147. package/lib/module/native/NativeTextEmbeddings.js +5 -0
  148. package/lib/module/native/NativeTextEmbeddings.js.map +1 -0
  149. package/lib/module/native/NativeTokenizer.js +5 -0
  150. package/lib/module/native/NativeTokenizer.js.map +1 -0
  151. package/lib/module/native/RnExecutorchModules.js +18 -113
  152. package/lib/module/native/RnExecutorchModules.js.map +1 -1
  153. package/lib/module/types/common.js.map +1 -1
  154. package/lib/module/types/imageSegmentation.js +29 -0
  155. package/lib/module/types/imageSegmentation.js.map +1 -0
  156. package/lib/module/types/llm.js +7 -0
  157. package/lib/module/types/llm.js.map +1 -0
  158. package/lib/module/types/{object_detection.js → objectDetection.js} +1 -1
  159. package/lib/module/types/objectDetection.js.map +1 -0
  160. package/lib/module/types/ocr.js +2 -0
  161. package/lib/module/types/stt.js +82 -0
  162. package/lib/module/types/stt.js.map +1 -0
  163. package/lib/module/utils/ResourceFetcher.js +156 -0
  164. package/lib/module/utils/ResourceFetcher.js.map +1 -0
  165. package/lib/module/utils/llm.js +25 -0
  166. package/lib/module/utils/llm.js.map +1 -0
  167. package/lib/module/utils/stt.js +22 -0
  168. package/lib/module/utils/stt.js.map +1 -0
  169. package/lib/typescript/Error.d.ts +4 -1
  170. package/lib/typescript/Error.d.ts.map +1 -1
  171. package/lib/typescript/constants/{llamaDefaults.d.ts → llmDefaults.d.ts} +5 -5
  172. package/lib/typescript/constants/llmDefaults.d.ts.map +1 -0
  173. package/lib/typescript/constants/modelUrls.d.ts +74 -28
  174. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  175. package/lib/typescript/constants/ocr/models.d.ts +285 -0
  176. package/lib/typescript/constants/ocr/models.d.ts.map +1 -0
  177. package/lib/typescript/constants/ocr/symbols.d.ts +73 -1
  178. package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
  179. package/lib/typescript/constants/sttDefaults.d.ts +8 -13
  180. package/lib/typescript/constants/sttDefaults.d.ts.map +1 -1
  181. package/lib/typescript/controllers/LLMController.d.ts +46 -0
  182. package/lib/typescript/controllers/LLMController.d.ts.map +1 -0
  183. package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
  184. package/lib/typescript/controllers/SpeechToTextController.d.ts +30 -16
  185. package/lib/typescript/controllers/SpeechToTextController.d.ts.map +1 -1
  186. package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
  187. package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
  188. package/lib/typescript/hooks/computer_vision/useClassification.d.ts +5 -5
  189. package/lib/typescript/hooks/computer_vision/useClassification.d.ts.map +1 -1
  190. package/lib/typescript/hooks/computer_vision/useImageSegmentation.d.ts +37 -0
  191. package/lib/typescript/hooks/computer_vision/useImageSegmentation.d.ts.map +1 -0
  192. package/lib/typescript/hooks/computer_vision/useOCR.d.ts +2 -1
  193. package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
  194. package/lib/typescript/hooks/computer_vision/useObjectDetection.d.ts +5 -4
  195. package/lib/typescript/hooks/computer_vision/useObjectDetection.d.ts.map +1 -1
  196. package/lib/typescript/hooks/computer_vision/useStyleTransfer.d.ts +4 -2
  197. package/lib/typescript/hooks/computer_vision/useStyleTransfer.d.ts.map +1 -1
  198. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +2 -1
  199. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
  200. package/lib/typescript/hooks/general/useExecutorchModule.d.ts +5 -6
  201. package/lib/typescript/hooks/general/useExecutorchModule.d.ts.map +1 -1
  202. package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts +6 -6
  203. package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
  204. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +7 -3
  205. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts.map +1 -1
  206. package/lib/typescript/hooks/natural_language_processing/useTextEmbeddings.d.ts +13 -0
  207. package/lib/typescript/hooks/natural_language_processing/useTextEmbeddings.d.ts.map +1 -0
  208. package/lib/typescript/hooks/natural_language_processing/useTokenizer.d.ts +16 -0
  209. package/lib/typescript/hooks/natural_language_processing/useTokenizer.d.ts.map +1 -0
  210. package/lib/typescript/hooks/useModule.d.ts +11 -10
  211. package/lib/typescript/hooks/useModule.d.ts.map +1 -1
  212. package/lib/typescript/index.d.ts +15 -2
  213. package/lib/typescript/index.d.ts.map +1 -1
  214. package/lib/typescript/modules/BaseModule.d.ts +4 -5
  215. package/lib/typescript/modules/BaseModule.d.ts.map +1 -1
  216. package/lib/typescript/modules/computer_vision/ClassificationModule.d.ts +7 -7
  217. package/lib/typescript/modules/computer_vision/ClassificationModule.d.ts.map +1 -1
  218. package/lib/typescript/modules/computer_vision/ImageSegmentationModule.d.ts +32 -0
  219. package/lib/typescript/modules/computer_vision/ImageSegmentationModule.d.ts.map +1 -0
  220. package/lib/typescript/modules/computer_vision/ObjectDetectionModule.d.ts +6 -5
  221. package/lib/typescript/modules/computer_vision/ObjectDetectionModule.d.ts.map +1 -1
  222. package/lib/typescript/modules/computer_vision/StyleTransferModule.d.ts +6 -5
  223. package/lib/typescript/modules/computer_vision/StyleTransferModule.d.ts.map +1 -1
  224. package/lib/typescript/modules/general/ExecutorchModule.d.ts +4 -3
  225. package/lib/typescript/modules/general/ExecutorchModule.d.ts.map +1 -1
  226. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts +19 -5
  227. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts.map +1 -1
  228. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +7 -4
  229. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
  230. package/lib/typescript/modules/natural_language_processing/TextEmbeddingsModule.d.ts +8 -0
  231. package/lib/typescript/modules/natural_language_processing/TextEmbeddingsModule.d.ts.map +1 -0
  232. package/lib/typescript/modules/natural_language_processing/TokenizerModule.d.ts +12 -0
  233. package/lib/typescript/modules/natural_language_processing/TokenizerModule.d.ts.map +1 -0
  234. package/lib/typescript/native/NativeClassification.d.ts.map +1 -1
  235. package/lib/typescript/native/NativeImageSegmentation.d.ts +10 -0
  236. package/lib/typescript/native/NativeImageSegmentation.d.ts.map +1 -0
  237. package/lib/typescript/native/NativeLLM.d.ts +3 -4
  238. package/lib/typescript/native/NativeLLM.d.ts.map +1 -1
  239. package/lib/typescript/native/NativeObjectDetection.d.ts +1 -1
  240. package/lib/typescript/native/NativeObjectDetection.d.ts.map +1 -1
  241. package/lib/typescript/native/NativeSpeechToText.d.ts +2 -2
  242. package/lib/typescript/native/NativeSpeechToText.d.ts.map +1 -1
  243. package/lib/typescript/native/NativeTextEmbeddings.d.ts +8 -0
  244. package/lib/typescript/native/NativeTextEmbeddings.d.ts.map +1 -0
  245. package/lib/typescript/native/NativeTokenizer.d.ts +12 -0
  246. package/lib/typescript/native/NativeTokenizer.d.ts.map +1 -0
  247. package/lib/typescript/native/RnExecutorchModules.d.ts +18 -41
  248. package/lib/typescript/native/RnExecutorchModules.d.ts.map +1 -1
  249. package/lib/typescript/types/common.d.ts +1 -26
  250. package/lib/typescript/types/common.d.ts.map +1 -1
  251. package/lib/typescript/types/imageSegmentation.d.ts +25 -0
  252. package/lib/typescript/types/imageSegmentation.d.ts.map +1 -0
  253. package/lib/typescript/types/llm.d.ts +38 -0
  254. package/lib/typescript/types/llm.d.ts.map +1 -0
  255. package/lib/typescript/types/{object_detection.d.ts → objectDetection.d.ts} +1 -1
  256. package/lib/typescript/types/objectDetection.d.ts.map +1 -0
  257. package/lib/typescript/types/ocr.d.ts +2 -1
  258. package/lib/typescript/types/ocr.d.ts.map +1 -1
  259. package/lib/typescript/types/stt.d.ts +91 -0
  260. package/lib/typescript/types/stt.d.ts.map +1 -0
  261. package/lib/typescript/utils/ResourceFetcher.d.ts +17 -0
  262. package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -0
  263. package/lib/typescript/utils/llm.d.ts +3 -0
  264. package/lib/typescript/utils/llm.d.ts.map +1 -0
  265. package/lib/typescript/utils/stt.d.ts +2 -0
  266. package/lib/typescript/utils/stt.d.ts.map +1 -0
  267. package/package.json +13 -49
  268. package/react-native-executorch.podspec +1 -1
  269. package/src/Error.ts +16 -3
  270. package/src/constants/llmDefaults.ts +14 -0
  271. package/src/constants/modelUrls.ts +146 -39
  272. package/src/constants/ocr/models.ts +453 -0
  273. package/src/constants/ocr/symbols.ts +147 -3
  274. package/src/constants/sttDefaults.ts +55 -37
  275. package/src/controllers/LLMController.ts +286 -0
  276. package/src/controllers/OCRController.ts +14 -28
  277. package/src/controllers/SpeechToTextController.ts +318 -180
  278. package/src/controllers/VerticalOCRController.ts +17 -32
  279. package/src/hooks/computer_vision/useClassification.ts +11 -26
  280. package/src/hooks/computer_vision/useImageSegmentation.ts +18 -0
  281. package/src/hooks/computer_vision/useOCR.ts +17 -5
  282. package/src/hooks/computer_vision/useObjectDetection.ts +10 -24
  283. package/src/hooks/computer_vision/useStyleTransfer.ts +9 -25
  284. package/src/hooks/computer_vision/useVerticalOCR.ts +11 -4
  285. package/src/hooks/general/useExecutorchModule.ts +10 -50
  286. package/src/hooks/natural_language_processing/useLLM.ts +80 -97
  287. package/src/hooks/natural_language_processing/useSpeechToText.ts +39 -12
  288. package/src/hooks/natural_language_processing/useTextEmbeddings.ts +18 -0
  289. package/src/hooks/natural_language_processing/useTokenizer.ts +61 -0
  290. package/src/hooks/useModule.ts +32 -92
  291. package/src/index.tsx +16 -2
  292. package/src/modules/BaseModule.ts +16 -26
  293. package/src/modules/computer_vision/ClassificationModule.ts +13 -8
  294. package/src/modules/computer_vision/ImageSegmentationModule.ts +39 -0
  295. package/src/modules/computer_vision/ObjectDetectionModule.ts +13 -8
  296. package/src/modules/computer_vision/StyleTransferModule.ts +13 -8
  297. package/src/modules/general/ExecutorchModule.ts +11 -6
  298. package/src/modules/natural_language_processing/LLMModule.ts +64 -51
  299. package/src/modules/natural_language_processing/SpeechToTextModule.ts +25 -10
  300. package/src/modules/natural_language_processing/TextEmbeddingsModule.ts +18 -0
  301. package/src/modules/natural_language_processing/TokenizerModule.ts +34 -0
  302. package/src/native/NativeClassification.ts +0 -1
  303. package/src/native/NativeImageSegmentation.ts +14 -0
  304. package/src/native/NativeLLM.ts +3 -10
  305. package/src/native/NativeObjectDetection.ts +1 -1
  306. package/src/native/NativeSpeechToText.ts +2 -2
  307. package/src/native/NativeTextEmbeddings.ts +9 -0
  308. package/src/native/NativeTokenizer.ts +13 -0
  309. package/src/native/RnExecutorchModules.ts +54 -234
  310. package/src/types/common.ts +1 -44
  311. package/src/types/imageSegmentation.ts +25 -0
  312. package/src/types/llm.ts +57 -0
  313. package/src/types/ocr.ts +3 -1
  314. package/src/types/stt.ts +93 -0
  315. package/src/utils/ResourceFetcher.ts +196 -0
  316. package/src/utils/llm.ts +34 -0
  317. package/src/utils/stt.ts +28 -0
  318. package/android/src/main/java/com/swmansion/rnexecutorch/utils/llms/ConversationManager.kt +0 -68
  319. package/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/_CodeSignature/CodeResources +0 -124
  320. package/ios/RnExecutorch/utils/llms/Constants.h +0 -6
  321. package/ios/RnExecutorch/utils/llms/Constants.mm +0 -23
  322. package/ios/RnExecutorch/utils/llms/ConversationManager.h +0 -26
  323. package/ios/RnExecutorch/utils/llms/ConversationManager.mm +0 -71
  324. package/lib/module/constants/llamaDefaults.js.map +0 -1
  325. package/lib/module/modules/computer_vision/BaseCVModule.js +0 -14
  326. package/lib/module/modules/computer_vision/BaseCVModule.js.map +0 -1
  327. package/lib/module/types/object_detection.js.map +0 -1
  328. package/lib/module/utils/fetchResource.js +0 -93
  329. package/lib/module/utils/fetchResource.js.map +0 -1
  330. package/lib/module/utils/listDownloadedResources.js +0 -13
  331. package/lib/module/utils/listDownloadedResources.js.map +0 -1
  332. package/lib/typescript/constants/llamaDefaults.d.ts.map +0 -1
  333. package/lib/typescript/modules/computer_vision/BaseCVModule.d.ts +0 -9
  334. package/lib/typescript/modules/computer_vision/BaseCVModule.d.ts.map +0 -1
  335. package/lib/typescript/types/object_detection.d.ts.map +0 -1
  336. package/lib/typescript/utils/fetchResource.d.ts +0 -3
  337. package/lib/typescript/utils/fetchResource.d.ts.map +0 -1
  338. package/lib/typescript/utils/listDownloadedResources.d.ts +0 -3
  339. package/lib/typescript/utils/listDownloadedResources.d.ts.map +0 -1
  340. package/src/constants/llamaDefaults.ts +0 -9
  341. package/src/modules/computer_vision/BaseCVModule.ts +0 -22
  342. package/src/utils/fetchResource.ts +0 -106
  343. package/src/utils/listDownloadedResources.ts +0 -12
  344. /package/src/types/{object_detection.ts → objectDetection.ts} +0 -0
package/README.md CHANGED
@@ -8,13 +8,25 @@
8
8
 
9
9
  **ExecuTorch** is a novel framework created by Meta that enables running AI models on devices such as mobile phones or microcontrollers. React Native ExecuTorch bridges the gap between React Native and native platform capabilities, allowing developers to run AI models locally on mobile devices with state-of-the-art performance, without requiring deep knowledge of native code or machine learning internals.
10
10
 
11
+ **Table of contents:**
12
+
13
+ - [Compatibility](#compatibility)
14
+ - [Ready-made models 🤖](#ready-made-models-)
15
+ - [Documentation 📚](#documentation-)
16
+ - [🦙 Quickstart - Running Llama](#-quickstart---running-llama)
17
+ - [Minimal supported versions](#minimal-supported-versions)
18
+ - [Examples 📲](#examples-)
19
+ - [Warning](#warning)
20
+ - [License](#license)
21
+ - [What's next?](#whats-next)
22
+
11
23
  ## Compatibility
12
24
 
13
25
  React Native Executorch supports only the [New React Native architecture](https://reactnative.dev/architecture/landing-page).
14
26
 
15
27
  If your app still runs on the old architecture, please consider upgrading to the New Architecture.
16
28
 
17
- ## Readymade models 🤖
29
+ ## Ready-made models 🤖
18
30
 
19
31
  To run any AI model in ExecuTorch, you need to export it to a `.pte` format. If you're interested in experimenting with your own models, we highly encourage you to check out the [Python API](https://pypi.org/project/executorch/). If you prefer focusing on developing your React Native app, we will cover several common use cases. For more details, please refer to the documentation.
20
32
 
@@ -43,16 +55,17 @@ Add this to your component file:
43
55
 
44
56
  ```tsx
45
57
  import {
46
- LLAMA3_2_3B_QLORA,
47
- LLAMA3_2_3B_TOKENIZER,
48
58
  useLLM,
59
+ LLAMA3_2_1B,
60
+ LLAMA3_2_TOKENIZER,
+ LLAMA3_2_TOKENIZER_CONFIG,
49
61
  } from 'react-native-executorch';
50
62
 
51
63
  function MyComponent() {
52
64
  // Initialize the model 🚀
53
65
  const llama = useLLM({
54
- modelSource: LLAMA3_2_3B_QLORA,
55
- tokenizerSource: LLAMA3_2_3B_TOKENIZER,
66
+ modelSource: LLAMA3_2_1B,
67
+ tokenizerSource: LLAMA3_2_TOKENIZER,
68
+ tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG,
56
69
  });
57
70
  // ... rest of your component
58
71
  }
@@ -64,11 +77,14 @@ function MyComponent() {
64
77
 
65
78
  ```tsx
66
79
  const handleGenerate = async () => {
67
- const prompt = 'The meaning of life is';
68
-
69
- // Generate text based on your desired prompt
70
- const response = await llama.generate(prompt);
71
- console.log('Llama says:', response);
80
+ const chat = [
81
+ { role: 'system', content: 'You are a helpful assistant' },
82
+ { role: 'user', content: 'What is the meaning of life?' }
83
+ ];
84
+
85
+ // Chat completion
86
+ await llama.generate(chat);
87
+ console.log('Llama says:', llama.response);
72
88
  };
73
89
  ```
74
90
 
@@ -80,13 +96,14 @@ The minimal supported version is 17.0 for iOS and Android 13.
80
96
 
81
97
  https://github.com/user-attachments/assets/27ab3406-c7f1-4618-a981-6c86b53547ee
82
98
 
83
- We currently host two example apps demonstrating use cases of our library:
99
+ We currently host a few example apps demonstrating use cases of our library:
84
100
 
101
+ - examples/llm - chat application showcasing use of LLMs
85
102
  - examples/speech-to-text - Whisper and Moonshine models ready for transcription tasks
86
103
  - examples/computer-vision - computer vision related tasks
87
- - examples/llama - chat applications showcasing use of LLMs
104
+ - examples/text-embeddings - computing text representations for semantic search
88
105
 
89
- If you would like to run it, navigate to it's project directory, for example `examples/llama` from the repository root and install dependencies with:
106
+ If you would like to run one of them, navigate to its project directory, for example `examples/llm` from the repository root and install dependencies with:
90
107
 
91
108
  ```bash
92
109
  yarn
@@ -103,7 +103,7 @@ dependencies {
103
103
  implementation "com.facebook.react:react-android:+"
104
104
  implementation 'org.opencv:opencv:4.10.0'
105
105
  implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
106
- implementation 'com.github.software-mansion:react-native-executorch:main-SNAPSHOT'
106
+ implementation(files("libs/executorch.aar"))
107
107
  implementation 'org.opencv:opencv:4.10.0'
108
108
  implementation("com.squareup.okhttp3:okhttp:4.9.2")
109
109
  }
Binary file
@@ -9,7 +9,6 @@ import com.swmansion.rnexecutorch.utils.ETError
9
9
  import com.swmansion.rnexecutorch.utils.TensorUtils
10
10
  import org.pytorch.executorch.EValue
11
11
  import org.pytorch.executorch.Module
12
- import java.net.URL
13
12
 
14
13
  class ETModule(
15
14
  reactContext: ReactApplicationContext,
@@ -23,7 +22,7 @@ class ETModule(
23
22
  modelSource: String,
24
23
  promise: Promise,
25
24
  ) {
26
- module = Module.load(URL(modelSource).path)
25
+ module = Module.load(modelSource)
27
26
  promise.resolve(0)
28
27
  }
29
28
 
@@ -0,0 +1,58 @@
1
+ package com.swmansion.rnexecutorch
2
+
3
+ import android.util.Log
4
+ import com.facebook.react.bridge.Promise
5
+ import com.facebook.react.bridge.ReactApplicationContext
6
+ import com.facebook.react.bridge.ReadableArray
7
+ import com.swmansion.rnexecutorch.models.imagesegmentation.ImageSegmentationModel
8
+ import com.swmansion.rnexecutorch.utils.ETError
9
+ import com.swmansion.rnexecutorch.utils.ImageProcessor
10
+ import org.opencv.android.OpenCVLoader
11
+
12
+ class ImageSegmentation(
13
+ reactContext: ReactApplicationContext,
14
+ ) : NativeImageSegmentationSpec(reactContext) {
15
+ private lateinit var model: ImageSegmentationModel
16
+
17
+ companion object {
18
+ const val NAME = "ImageSegmentation"
19
+
20
+ init {
21
+ if (!OpenCVLoader.initLocal()) {
22
+ Log.d("rn_executorch", "OpenCV not loaded")
23
+ } else {
24
+ Log.d("rn_executorch", "OpenCV loaded")
25
+ }
26
+ }
27
+ }
28
+
29
+ override fun loadModule(
30
+ modelSource: String,
31
+ promise: Promise,
32
+ ) {
33
+ try {
34
+ model = ImageSegmentationModel(reactApplicationContext)
35
+ model.loadModel(modelSource)
36
+ promise.resolve(0)
37
+ } catch (e: Exception) {
38
+ promise.reject(e.message!!, ETError.InvalidModelSource.toString())
39
+ }
40
+ }
41
+
42
+ override fun forward(
43
+ input: String,
44
+ classesOfInterest: ReadableArray,
45
+ resize: Boolean,
46
+ promise: Promise,
47
+ ) {
48
+ try {
49
+ val output =
50
+ model.runModel(Triple(ImageProcessor.readImage(input), classesOfInterest, resize))
51
+ promise.resolve(output)
52
+ } catch (e: Exception) {
53
+ promise.reject(e.message!!, e.message)
54
+ }
55
+ }
56
+
57
+ override fun getName(): String = NAME
58
+ }
@@ -3,22 +3,13 @@ package com.swmansion.rnexecutorch
3
3
  import android.util.Log
4
4
  import com.facebook.react.bridge.Promise
5
5
  import com.facebook.react.bridge.ReactApplicationContext
6
- import com.facebook.react.bridge.ReadableArray
7
- import com.swmansion.rnexecutorch.utils.ArrayUtils
8
- import com.swmansion.rnexecutorch.utils.llms.ChatRole
9
- import com.swmansion.rnexecutorch.utils.llms.ConversationManager
10
- import com.swmansion.rnexecutorch.utils.llms.END_OF_TEXT_TOKEN
11
- import org.pytorch.executorch.LlamaCallback
12
- import org.pytorch.executorch.LlamaModule
13
- import java.net.URL
6
+ import org.pytorch.executorch.extension.llm.LlmCallback
7
+ import org.pytorch.executorch.extension.llm.LlmModule
14
8
 
15
9
  class LLM(
16
10
  reactContext: ReactApplicationContext,
17
- ) : NativeLLMSpec(reactContext),
18
- LlamaCallback {
19
- private var llamaModule: LlamaModule? = null
20
- private var tempLlamaResponse = StringBuilder()
21
- private lateinit var conversationManager: ConversationManager
11
+ ) : NativeLLMSpec(reactContext), LlmCallback {
12
+ private var llmModule: LlmModule? = null
22
13
 
23
14
  override fun getName(): String = NAME
24
15
 
@@ -28,7 +19,6 @@ class LLM(
28
19
 
29
20
  override fun onResult(result: String) {
30
21
  emitOnToken(result)
31
- this.tempLlamaResponse.append(result)
32
22
  }
33
23
 
34
24
  override fun onStats(tps: Float) {
@@ -38,59 +28,33 @@ class LLM(
38
28
  override fun loadLLM(
39
29
  modelSource: String,
40
30
  tokenizerSource: String,
41
- systemPrompt: String,
42
- messageHistory: ReadableArray,
43
- contextWindowLength: Double,
44
31
  promise: Promise,
45
32
  ) {
46
33
  try {
47
- this.conversationManager =
48
- ConversationManager(
49
- contextWindowLength.toInt(),
50
- systemPrompt,
51
- ArrayUtils.createMapArray<String>(messageHistory),
52
- )
53
- llamaModule = LlamaModule(1, URL(modelSource).path, URL(tokenizerSource).path, 0.7f)
54
- this.tempLlamaResponse.clear()
34
+ llmModule = LlmModule(modelSource, tokenizerSource, 0.7f)
55
35
  promise.resolve("Model loaded successfully")
56
36
  } catch (e: Exception) {
57
37
  promise.reject("Model loading failed", e.message)
58
38
  }
59
39
  }
60
40
 
61
- override fun runInference(
41
+ override fun forward(
62
42
  input: String,
63
43
  promise: Promise,
64
44
  ) {
65
- this.conversationManager.addResponse(input, ChatRole.USER)
66
- val conversation = this.conversationManager.getConversation()
67
-
68
45
  Thread {
69
- llamaModule!!.generate(conversation, (conversation.length * 0.75).toInt() + 64, this, false)
70
-
71
- // When we call .interrupt(), the LLM doesn't produce EOT token, that also could happen when the
72
- // generated sequence length is larger than specified in the JNI callback, hence we check if EOT
73
- // is there and if not, we append it to the output and emit the EOT token to the JS side.
74
- if (!this.tempLlamaResponse.endsWith(END_OF_TEXT_TOKEN)) {
75
- this.onResult(END_OF_TEXT_TOKEN)
76
- }
77
-
78
- // We want to add the LLM response to the conversation once all the tokens are generated.
79
- // Each token is appended to the tempLlamaResponse StringBuilder in onResult callback.
80
- this.conversationManager.addResponse(this.tempLlamaResponse.toString(), ChatRole.ASSISTANT)
81
- this.tempLlamaResponse.clear()
82
- Log.d("ExecutorchLib", this.conversationManager.getConversation())
83
- }.start()
84
-
85
- promise.resolve("Inference completed successfully")
46
+ llmModule!!.generate(input, this)
47
+ promise.resolve("Inference completed successfully")
48
+ }
49
+ .start()
86
50
  }
87
51
 
88
52
  override fun interrupt() {
89
- llamaModule!!.stop()
53
+ llmModule!!.stop()
90
54
  }
91
55
 
92
- override fun deleteModule() {
93
- llamaModule = null
56
+ override fun releaseResources() {
57
+ llmModule = null
94
58
  }
95
59
 
96
60
  companion object {
@@ -30,6 +30,12 @@ class RnExecutorchPackage : TurboReactPackage() {
30
30
  OCR(reactContext)
31
31
  } else if (name == VerticalOCR.NAME) {
32
32
  VerticalOCR(reactContext)
33
+ } else if (name == ImageSegmentation.NAME) {
34
+ ImageSegmentation(reactContext)
35
+ } else if (name == Tokenizer.NAME) {
36
+ Tokenizer(reactContext)
37
+ } else if (name == TextEmbeddings.NAME) {
38
+ TextEmbeddings(reactContext)
33
39
  } else {
34
40
  null
35
41
  }
@@ -115,6 +121,37 @@ class RnExecutorchPackage : TurboReactPackage() {
115
121
  false, // isCxxModule
116
122
  true,
117
123
  )
124
+
125
+ moduleInfos[ImageSegmentation.NAME] =
126
+ ReactModuleInfo(
127
+ ImageSegmentation.NAME,
128
+ ImageSegmentation.NAME,
129
+ false, // canOverrideExistingModule
130
+ false, // needsEagerInit
131
+ false, // isCxxModule
132
+ true,
133
+ )
134
+
135
+ moduleInfos[Tokenizer.NAME] =
136
+ ReactModuleInfo(
137
+ Tokenizer.NAME,
138
+ Tokenizer.NAME,
139
+ false, // canOverrideExistingModule
140
+ false, // needsEagerInit
141
+ false, // isCxxModule
142
+ true,
143
+ )
144
+
145
+ moduleInfos[TextEmbeddings.NAME] =
146
+ ReactModuleInfo(
147
+ TextEmbeddings.NAME,
148
+ TextEmbeddings.NAME,
149
+ false, // canOverrideExistingModule
150
+ false, // needsEagerInit
151
+ false, // isCxxModule
152
+ true,
153
+ )
154
+
118
155
  moduleInfos
119
156
  }
120
157
  }
@@ -3,7 +3,7 @@ package com.swmansion.rnexecutorch
3
3
  import android.util.Log
4
4
  import com.facebook.react.bridge.Promise
5
5
  import com.facebook.react.bridge.ReactApplicationContext
6
- import com.swmansion.rnexecutorch.models.StyleTransferModel
6
+ import com.swmansion.rnexecutorch.models.styletransfer.StyleTransferModel
7
7
  import com.swmansion.rnexecutorch.utils.ETError
8
8
  import com.swmansion.rnexecutorch.utils.ImageProcessor
9
9
  import org.opencv.android.OpenCVLoader
@@ -0,0 +1,51 @@
1
+ package com.swmansion.rnexecutorch
2
+
3
+ import com.facebook.react.bridge.Promise
4
+ import com.facebook.react.bridge.ReactApplicationContext
5
+ import com.facebook.react.bridge.WritableNativeArray
6
+ import com.swmansion.rnexecutorch.models.textEmbeddings.TextEmbeddingsModel
7
+ import com.swmansion.rnexecutorch.utils.ETError
8
+
9
/**
 * React Native TurboModule that turns a string into a sentence-embedding vector using a
 * [TextEmbeddingsModel] plus its HuggingFace tokenizer.
 *
 * All failures are reported through the JS promise. The rejection text falls back to
 * `Exception.toString()` when an exception carries no message, so the catch blocks can
 * never throw themselves.
 */
class TextEmbeddings(
  reactContext: ReactApplicationContext,
) : NativeTextEmbeddingsSpec(reactContext) {
  private lateinit var textEmbeddingsModel: TextEmbeddingsModel

  /**
   * Creates a fresh model, then loads the program at [modelSource] and the tokenizer at
   * [tokenizerSource].
   *
   * Resolves with `0` on success. On failure, rejects with the exception text as the code
   * and [ETError.InvalidModelSource] as the message.
   */
  override fun loadModule(
    modelSource: String,
    tokenizerSource: String,
    promise: Promise,
  ) {
    try {
      textEmbeddingsModel = TextEmbeddingsModel(reactApplicationContext)

      textEmbeddingsModel.loadModel(modelSource)
      textEmbeddingsModel.loadTokenizer(tokenizerSource)

      promise.resolve(0)
    } catch (e: Exception) {
      promise.reject(errorText(e), ETError.InvalidModelSource.toString())
    }
  }

  /**
   * Embeds [input] and resolves with the embedding as a JS array of numbers.
   * If [loadModule] has not succeeded yet, the resulting exception is reported as a rejection.
   */
  override fun forward(
    input: String,
    promise: Promise,
  ) {
    try {
      val embedding = textEmbeddingsModel.runModel(input)
      val writableArray = WritableNativeArray()
      embedding.forEach { writableArray.pushDouble(it) }

      promise.resolve(writableArray)
    } catch (e: Exception) {
      val message = errorText(e)
      promise.reject(message, message)
    }
  }

  override fun getName(): String = NAME

  /** Non-null text for [e]: its message, or its `toString()` when the message is null. */
  private fun errorText(e: Exception): String = e.message ?: e.toString()

  companion object {
    const val NAME = "TextEmbeddings"
  }
}
@@ -0,0 +1,86 @@
1
+ package com.swmansion.rnexecutorch
2
+
3
+ import com.facebook.react.bridge.Promise
4
+ import com.facebook.react.bridge.ReactApplicationContext
5
+ import com.facebook.react.bridge.ReadableArray
6
+ import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.createIntArray
7
+ import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.createReadableArrayFromIntArray
8
+ import com.swmansion.rnexecutorch.utils.ETError
9
+ import org.pytorch.executorch.HuggingFaceTokenizer
10
+
11
/**
 * React Native TurboModule wrapping a [HuggingFaceTokenizer] so JS can encode and decode
 * text and query the vocabulary.
 *
 * [loadModule] must succeed before any other method is called. Otherwise the
 * uninitialized-property exception is reported through the promise. Every method settles
 * its promise exactly once, and the rejection code is never null.
 */
class Tokenizer(
  reactContext: ReactApplicationContext,
) : NativeTokenizerSpec(reactContext) {
  private lateinit var tokenizer: HuggingFaceTokenizer

  /** Loads the tokenizer definition at [tokenizerSource]; resolves with `0`. */
  override fun loadModule(
    tokenizerSource: String,
    promise: Promise,
  ) {
    settle(promise, ETError.InvalidModelSource.toString()) {
      tokenizer = HuggingFaceTokenizer(tokenizerSource)
      0
    }
  }

  /**
   * Decodes the token ids in [input] back to text, optionally dropping special tokens.
   * Each JS number is converted to an `Int` id.
   */
  override fun decode(
    input: ReadableArray,
    skipSpecialTokens: Boolean,
    promise: Promise,
  ) {
    settle(promise, ETError.UndefinedError.toString()) {
      tokenizer.decode(createIntArray(input), skipSpecialTokens)
    }
  }

  /** Encodes [input] and resolves with the resulting token ids as a JS array. */
  override fun encode(
    input: String,
    promise: Promise,
  ) {
    settle(promise, ETError.UndefinedError.toString()) {
      createReadableArrayFromIntArray(tokenizer.encode(input))
    }
  }

  /** Resolves with the tokenizer's vocabulary size. */
  override fun getVocabSize(promise: Promise) {
    settle(promise, ETError.UndefinedError.toString()) { tokenizer.vocabSize }
  }

  /** Resolves with the token string for [id]; the JS number is truncated to an `Int`. */
  override fun idToToken(
    id: Double,
    promise: Promise,
  ) {
    settle(promise, ETError.UndefinedError.toString()) { tokenizer.idToToken(id.toInt()) }
  }

  /** Resolves with the id of [token]. */
  override fun tokenToId(
    token: String,
    promise: Promise,
  ) {
    settle(promise, ETError.UndefinedError.toString()) { tokenizer.tokenToId(token) }
  }

  override fun getName(): String = NAME

  /**
   * Resolves [promise] with the value produced by [action]. If [action] (or the resolve
   * call) throws, rejects it instead, using the exception text as the code and [errorMessage]
   * as the message. The exception text falls back to `toString()` when the message is null,
   * so this never throws from inside the catch.
   */
  private inline fun settle(
    promise: Promise,
    errorMessage: String,
    action: () -> Any?,
  ) {
    try {
      promise.resolve(action())
    } catch (e: Exception) {
      promise.reject(e.message ?: e.toString(), errorMessage)
    }
  }

  companion object {
    const val NAME = "Tokenizer"
  }
}
@@ -5,7 +5,6 @@ import com.swmansion.rnexecutorch.utils.ETError
5
5
  import org.pytorch.executorch.EValue
6
6
  import org.pytorch.executorch.Module
7
7
  import org.pytorch.executorch.Tensor
8
- import java.net.URL
9
8
 
10
9
  abstract class BaseModel<Input, Output>(
11
10
  val context: Context,
@@ -13,12 +12,12 @@ abstract class BaseModel<Input, Output>(
13
12
  protected lateinit var module: Module
14
13
 
15
14
  fun loadModel(modelSource: String) {
16
- module = Module.load(URL(modelSource).path)
15
+ module = Module.load(modelSource)
17
16
  }
18
17
 
19
- protected fun forward(input: EValue): Array<EValue> {
18
+ protected fun forward(vararg inputs: EValue): Array<EValue> {
20
19
  try {
21
- val result = module.forward(input)
20
+ val result = module.forward(*inputs)
22
21
  return result
23
22
  } catch (e: IllegalArgumentException) {
24
23
  // The error is thrown when transformation to Tensor fails
@@ -0,0 +1,48 @@
1
+ package com.swmansion.rnexecutorch.models.textEmbeddings
2
+
3
+ import com.facebook.react.bridge.ReactApplicationContext
4
+ import com.swmansion.rnexecutorch.models.BaseModel
5
+ import org.pytorch.executorch.EValue
6
+ import org.pytorch.executorch.HuggingFaceTokenizer
7
+ import org.pytorch.executorch.Tensor
8
+
9
/**
 * Sentence-embedding model. It tokenizes a string with a HuggingFace tokenizer, feeds
 * `(inputIds, attentionMask)` to the ExecuTorch module, then mean-pools and L2-normalizes
 * the per-token output into one embedding vector.
 */
class TextEmbeddingsModel(
  reactApplicationContext: ReactApplicationContext,
) : BaseModel<String, DoubleArray>(reactApplicationContext) {
  private lateinit var tokenizer: HuggingFaceTokenizer

  /** Loads the tokenizer definition from [tokenizerSource]; must be called before [runModel]. */
  fun loadTokenizer(tokenizerSource: String) {
    tokenizer = HuggingFaceTokenizer(tokenizerSource)
  }

  /**
   * Encodes [input] into `[inputIds, attentionMask]`. Both arrays have one entry per token.
   *
   * NOTE(review): the mask is 0 wherever the token id is 0, i.e. it assumes id 0 is padding.
   * That presumably holds for BERT-style vocabularies but not for vocabularies where id 0 is
   * a real token (e.g. `<s>`). Confirm against the tokenizers of the supported models.
   */
  fun preprocess(input: String): Array<LongArray> {
    val encoded = tokenizer.encode(input)
    val ids = LongArray(encoded.size) { index -> encoded[index].toLong() }
    val mask = LongArray(ids.size) { index -> if (ids[index] == 0L) 0L else 1L }
    return arrayOf(ids, mask)
  }

  /**
   * Widens [modelOutput] to doubles, mean-pools it over the tokens selected by
   * [attentionMask], and returns the unit-length result.
   */
  fun postprocess(
    modelOutput: FloatArray, // [tokens * embedding_dim]
    attentionMask: LongArray, // [tokens]
  ): DoubleArray {
    val widened = DoubleArray(modelOutput.size) { index -> modelOutput[index].toDouble() }
    val pooled = TextEmbeddingsUtils.meanPooling(widened, attentionMask)
    return TextEmbeddingsUtils.normalize(pooled)
  }

  /** Embeds [input]: preprocess, run the module, then postprocess the first output tensor. */
  override fun runModel(input: String): DoubleArray {
    val (ids, mask) = preprocess(input)

    // Both tensors are passed as a batch of one, i.e. shape [1, tokens].
    val idsTensor = Tensor.fromBlob(ids, longArrayOf(1L, ids.size.toLong()))
    val maskTensor = Tensor.fromBlob(mask, longArrayOf(1L, mask.size.toLong()))

    val outputs = forward(EValue.from(idsTensor), EValue.from(maskTensor))
    return postprocess(outputs[0].toTensor().dataAsFloatArray, mask)
  }
}
@@ -0,0 +1,37 @@
1
+ package com.swmansion.rnexecutorch.models.textEmbeddings
2
+
3
+ import kotlin.math.sqrt
4
+
5
/** Helpers that turn per-token model output into a single sentence embedding. */
class TextEmbeddingsUtils {
  companion object {
    /**
     * Averages the token embeddings selected by [attentionMask].
     *
     * @param modelOutput flattened row-major `[tokens, embeddingDim]` token embeddings,
     *   where `tokens == attentionMask.size`.
     * @param attentionMask one weight per token (1 to include the token, 0 to skip it).
     * @return a vector of length `embeddingDim` whose entry `d` is
     *   `sum_t(mask[t] * output[t, d]) / max(sum(mask), 1e-9)`.
     * @throws IllegalArgumentException if [attentionMask] is empty, or if [modelOutput]'s size
     *   is not a multiple of the token count (which would otherwise silently mis-slice rows).
     */
    fun meanPooling(
      modelOutput: DoubleArray,
      attentionMask: LongArray,
    ): DoubleArray {
      val tokenCount = attentionMask.size
      require(tokenCount > 0) { "attentionMask must not be empty" }
      require(modelOutput.size % tokenCount == 0) {
        "modelOutput size ${modelOutput.size} is not a multiple of token count $tokenCount"
      }
      val embeddingDim = modelOutput.size / tokenCount

      // Clamp the divisor so an all-zero mask yields a zero vector instead of NaN.
      val maskSum = maxOf(attentionMask.sum().toDouble(), 1e-9)

      // Walk token-major so modelOutput is read sequentially. Each dimension still
      // accumulates tokens in ascending order, so the sums are bit-identical.
      val sums = DoubleArray(embeddingDim)
      for (token in 0 until tokenCount) {
        val weight = attentionMask[token]
        val rowStart = token * embeddingDim
        for (dim in 0 until embeddingDim) {
          sums[dim] += modelOutput[rowStart + dim] * weight
        }
      }
      return DoubleArray(embeddingDim) { dim -> sums[dim] / maskSum }
    }

    /**
     * Scales [embeddings] to unit L2 norm. The norm is clamped to at least 1e-9, so a zero
     * vector stays zero instead of becoming NaN.
     */
    fun normalize(embeddings: DoubleArray): DoubleArray {
      val norm = maxOf(sqrt(embeddings.sumOf { it * it }), 1e-9)
      return embeddings.map { it / norm }.toDoubleArray()
    }
  }
}
@@ -3,6 +3,7 @@ package com.swmansion.rnexecutorch.models.classification
3
3
  import com.facebook.react.bridge.ReactApplicationContext
4
4
  import com.swmansion.rnexecutorch.models.BaseModel
5
5
  import com.swmansion.rnexecutorch.utils.ImageProcessor
6
+ import com.swmansion.rnexecutorch.utils.softmax
6
7
  import org.opencv.core.Mat
7
8
  import org.opencv.core.Size
8
9
  import org.opencv.imgproc.Imgproc
@@ -0,0 +1,26 @@
1
+ package com.swmansion.rnexecutorch.models.imagesegmentation
2
+
3
/**
 * Class labels for the DeepLabV3-ResNet50 segmentation model: 21 entries, with
 * `BACKGROUND` at index 0.
 *
 * The position of each label presumably must match the model's class index, so do not
 * reorder entries. The comments give the index range of each row.
 */
val deeplabv3_resnet50_labels: Array<String> =
  arrayOf(
    // 0-4
    "BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT",
    // 5-9
    "BOTTLE", "BUS", "CAR", "CAT", "CHAIR",
    // 10-14
    "COW", "DININGTABLE", "DOG", "HORSE", "MOTORBIKE",
    // 15-19
    "PERSON", "POTTEDPLANT", "SHEEP", "SOFA", "TRAIN",
    // 20
    "TVMONITOR",
  )