react-native-executorch 0.5.15 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (277) hide show
  1. package/README.md +42 -36
  2. package/android/CMakeLists.txt +13 -25
  3. package/android/build.gradle +2 -3
  4. package/android/libs/classes.jar +0 -0
  5. package/android/src/main/cpp/CMakeLists.txt +2 -1
  6. package/common/rnexecutorch/RnExecutorchInstaller.cpp +18 -0
  7. package/common/rnexecutorch/TokenizerModule.cpp +3 -3
  8. package/common/rnexecutorch/data_processing/Numerical.cpp +31 -23
  9. package/common/rnexecutorch/data_processing/Numerical.h +6 -1
  10. package/common/rnexecutorch/data_processing/dsp.cpp +0 -46
  11. package/common/rnexecutorch/host_objects/JsiConversions.h +16 -0
  12. package/common/rnexecutorch/host_objects/ModelHostObject.h +26 -11
  13. package/common/rnexecutorch/jsi/OwningArrayBuffer.h +19 -2
  14. package/common/rnexecutorch/metaprogramming/TypeConcepts.h +0 -20
  15. package/common/rnexecutorch/models/BaseModel.cpp +12 -11
  16. package/common/rnexecutorch/models/BaseModel.h +18 -10
  17. package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +3 -11
  18. package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +0 -1
  19. package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +6 -12
  20. package/common/rnexecutorch/models/llm/LLM.cpp +25 -8
  21. package/common/rnexecutorch/models/llm/LLM.h +4 -4
  22. package/common/rnexecutorch/models/ocr/CTCLabelConverter.h +1 -1
  23. package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +7 -4
  24. package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +8 -13
  25. package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +1 -3
  26. package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +12 -19
  27. package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +4 -5
  28. package/common/rnexecutorch/models/text_to_image/Constants.h +9 -0
  29. package/common/rnexecutorch/models/text_to_image/Decoder.cpp +32 -0
  30. package/common/rnexecutorch/models/text_to_image/Decoder.h +24 -0
  31. package/common/rnexecutorch/models/text_to_image/Encoder.cpp +44 -0
  32. package/common/rnexecutorch/models/text_to_image/Encoder.h +32 -0
  33. package/common/rnexecutorch/models/text_to_image/Scheduler.cpp +152 -0
  34. package/common/rnexecutorch/models/text_to_image/Scheduler.h +41 -0
  35. package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +141 -0
  36. package/common/rnexecutorch/models/text_to_image/TextToImage.h +64 -0
  37. package/common/rnexecutorch/models/text_to_image/UNet.cpp +38 -0
  38. package/common/rnexecutorch/models/text_to_image/UNet.h +28 -0
  39. package/common/rnexecutorch/models/voice_activity_detection/Constants.h +27 -0
  40. package/common/rnexecutorch/models/voice_activity_detection/Types.h +12 -0
  41. package/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +15 -0
  42. package/common/rnexecutorch/models/voice_activity_detection/Utils.h +13 -0
  43. package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +160 -0
  44. package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +36 -0
  45. package/common/rnexecutorch/tests/CMakeLists.txt +30 -0
  46. package/common/rnexecutorch/tests/NumericalTest.cpp +110 -0
  47. package/common/rnexecutorch/tests/README.md +30 -13
  48. package/common/rnexecutorch/threads/GlobalThreadPool.h +4 -0
  49. package/common/runner/arange_util.cpp +44 -0
  50. package/common/runner/arange_util.h +37 -0
  51. package/common/runner/constants.h +28 -0
  52. package/common/runner/io_manager.h +240 -0
  53. package/common/runner/irunner.h +87 -16
  54. package/common/runner/kernel_includes.h +23 -0
  55. package/common/runner/runner.cpp +151 -66
  56. package/common/runner/runner.h +39 -22
  57. package/common/runner/sampler.cpp +8 -1
  58. package/common/runner/sampler.h +4 -2
  59. package/common/runner/stats.h +1 -4
  60. package/common/runner/text_decoder_runner.cpp +26 -12
  61. package/common/runner/text_decoder_runner.h +52 -31
  62. package/common/runner/text_prefiller.cpp +46 -12
  63. package/common/runner/text_prefiller.h +38 -4
  64. package/common/runner/text_token_generator.h +51 -26
  65. package/common/runner/util.h +53 -8
  66. package/ios/RnExecutorch.xcodeproj/project.pbxproj +0 -23
  67. package/lib/module/Error.js +1 -0
  68. package/lib/module/Error.js.map +1 -1
  69. package/lib/module/constants/directories.js +1 -1
  70. package/lib/module/constants/directories.js.map +1 -1
  71. package/lib/module/constants/modelUrls.js +32 -1
  72. package/lib/module/constants/modelUrls.js.map +1 -1
  73. package/lib/module/constants/ocr/models.js +7 -7
  74. package/lib/module/constants/ocr/models.js.map +1 -1
  75. package/lib/module/constants/ocr/symbols.js +3 -2
  76. package/lib/module/constants/ocr/symbols.js.map +1 -1
  77. package/lib/module/controllers/LLMController.js +10 -1
  78. package/lib/module/controllers/LLMController.js.map +1 -1
  79. package/lib/module/controllers/OCRController.js +3 -3
  80. package/lib/module/controllers/OCRController.js.map +1 -1
  81. package/lib/module/controllers/VerticalOCRController.js +2 -2
  82. package/lib/module/controllers/VerticalOCRController.js.map +1 -1
  83. package/lib/module/hooks/computer_vision/useOCR.js +3 -3
  84. package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
  85. package/lib/module/hooks/{useNonStaticModule.js → computer_vision/useTextToImage.js} +21 -16
  86. package/lib/module/hooks/computer_vision/useTextToImage.js.map +1 -0
  87. package/lib/module/hooks/computer_vision/useVerticalOCR.js +3 -3
  88. package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
  89. package/lib/module/hooks/natural_language_processing/useLLM.js +3 -3
  90. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  91. package/lib/module/hooks/natural_language_processing/useTokenizer.js +5 -5
  92. package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -1
  93. package/lib/module/hooks/natural_language_processing/useVAD.js +13 -0
  94. package/lib/module/hooks/natural_language_processing/useVAD.js.map +1 -0
  95. package/lib/module/index.js +7 -2
  96. package/lib/module/index.js.map +1 -1
  97. package/lib/module/modules/computer_vision/OCRModule.js +2 -2
  98. package/lib/module/modules/computer_vision/OCRModule.js.map +1 -1
  99. package/lib/module/modules/computer_vision/TextToImageModule.js +48 -0
  100. package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -0
  101. package/lib/module/modules/computer_vision/VerticalOCRModule.js +2 -2
  102. package/lib/module/modules/computer_vision/VerticalOCRModule.js.map +1 -1
  103. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +7 -4
  104. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
  105. package/lib/module/modules/natural_language_processing/VADModule.js +19 -0
  106. package/lib/module/modules/natural_language_processing/VADModule.js.map +1 -0
  107. package/lib/module/types/llm.js.map +1 -1
  108. package/lib/module/types/vad.js +2 -0
  109. package/lib/module/types/vad.js.map +1 -0
  110. package/lib/module/utils/ResourceFetcher.js +2 -1
  111. package/lib/module/utils/ResourceFetcher.js.map +1 -1
  112. package/lib/module/utils/ResourceFetcherUtils.js +6 -6
  113. package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
  114. package/lib/typescript/Error.d.ts +1 -0
  115. package/lib/typescript/Error.d.ts.map +1 -1
  116. package/lib/typescript/constants/modelUrls.d.ts +23 -0
  117. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  118. package/lib/typescript/constants/ocr/symbols.d.ts +1 -1
  119. package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
  120. package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
  121. package/lib/typescript/controllers/OCRController.d.ts +1 -1
  122. package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
  123. package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
  124. package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
  125. package/lib/typescript/hooks/computer_vision/useOCR.d.ts +1 -1
  126. package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
  127. package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts +22 -0
  128. package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts.map +1 -0
  129. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +1 -1
  130. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
  131. package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
  132. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +2 -2
  133. package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts +16 -0
  134. package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts.map +1 -0
  135. package/lib/typescript/index.d.ts +8 -1
  136. package/lib/typescript/index.d.ts.map +1 -1
  137. package/lib/typescript/modules/computer_vision/OCRModule.d.ts +1 -1
  138. package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
  139. package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts +16 -0
  140. package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -0
  141. package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts +1 -1
  142. package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
  143. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +3 -2
  144. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
  145. package/lib/typescript/modules/natural_language_processing/VADModule.d.ts +10 -0
  146. package/lib/typescript/modules/natural_language_processing/VADModule.d.ts.map +1 -0
  147. package/lib/typescript/types/llm.d.ts +2 -0
  148. package/lib/typescript/types/llm.d.ts.map +1 -1
  149. package/lib/typescript/types/vad.d.ts +5 -0
  150. package/lib/typescript/types/vad.d.ts.map +1 -0
  151. package/lib/typescript/utils/ResourceFetcher.d.ts +29 -0
  152. package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
  153. package/lib/typescript/utils/ResourceFetcherUtils.d.ts +2 -2
  154. package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
  155. package/package.json +11 -8
  156. package/react-native-executorch.podspec +9 -9
  157. package/src/Error.ts +1 -0
  158. package/src/constants/directories.ts +1 -1
  159. package/src/constants/modelUrls.ts +36 -1
  160. package/src/constants/ocr/models.ts +7 -7
  161. package/src/constants/ocr/symbols.ts +3 -2
  162. package/src/controllers/LLMController.ts +12 -1
  163. package/src/controllers/OCRController.ts +3 -3
  164. package/src/controllers/VerticalOCRController.ts +2 -2
  165. package/src/hooks/computer_vision/useOCR.ts +4 -5
  166. package/src/hooks/computer_vision/useTextToImage.ts +92 -0
  167. package/src/hooks/computer_vision/useVerticalOCR.ts +4 -5
  168. package/src/hooks/natural_language_processing/useLLM.ts +3 -4
  169. package/src/hooks/natural_language_processing/useTokenizer.ts +5 -5
  170. package/src/hooks/natural_language_processing/useVAD.ts +15 -0
  171. package/src/index.ts +20 -1
  172. package/src/modules/computer_vision/OCRModule.ts +2 -2
  173. package/src/modules/computer_vision/TextToImageModule.ts +93 -0
  174. package/src/modules/computer_vision/VerticalOCRModule.ts +2 -2
  175. package/src/modules/natural_language_processing/SpeechToTextModule.ts +8 -4
  176. package/src/modules/natural_language_processing/VADModule.ts +27 -0
  177. package/src/types/llm.ts +2 -0
  178. package/src/types/vad.ts +4 -0
  179. package/src/utils/ResourceFetcher.ts +2 -1
  180. package/src/utils/ResourceFetcherUtils.ts +8 -8
  181. package/third-party/android/libs/cpuinfo/arm64-v8a/libcpuinfo.so +0 -0
  182. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  183. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  184. package/third-party/android/libs/pthreadpool/arm64-v8a/libpthreadpool.so +0 -0
  185. package/third-party/include/c10/macros/Export.h +0 -78
  186. package/third-party/include/c10/macros/Macros.h +1 -520
  187. package/third-party/include/c10/util/BFloat16-inl.h +1 -339
  188. package/third-party/include/c10/util/BFloat16.h +1 -122
  189. package/third-party/include/c10/util/Half-inl.h +1 -347
  190. package/third-party/include/c10/util/Half.h +6 -419
  191. package/third-party/include/c10/util/TypeSafeSignMath.h +1 -133
  192. package/third-party/include/c10/util/bit_cast.h +1 -43
  193. package/third-party/include/c10/util/complex.h +1 -568
  194. package/third-party/include/c10/util/floating_point_utils.h +1 -33
  195. package/third-party/include/c10/util/irange.h +1 -1
  196. package/third-party/include/c10/util/llvmMathExtras.h +866 -0
  197. package/third-party/include/c10/util/safe_numerics.h +97 -0
  198. package/third-party/include/executorch/ExecuTorchError.h +6 -7
  199. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLM.h +12 -0
  200. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMConfig.h +56 -0
  201. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMError.h +16 -0
  202. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMMultimodalRunner.h +227 -0
  203. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMTextRunner.h +97 -0
  204. package/third-party/include/executorch/ExecuTorchLLM/module.modulemap +4 -0
  205. package/third-party/include/executorch/ExecuTorchLog.h +1 -0
  206. package/third-party/include/executorch/ExecuTorchModule.h +177 -4
  207. package/third-party/include/executorch/ExecuTorchTensor.h +3 -4
  208. package/third-party/include/executorch/ExecuTorchValue.h +1 -7
  209. package/third-party/include/executorch/extension/module/module.h +139 -8
  210. package/third-party/include/executorch/extension/tensor/tensor.h +1 -0
  211. package/third-party/include/executorch/extension/tensor/tensor_ptr.h +88 -26
  212. package/third-party/include/executorch/extension/threadpool/threadpool.h +4 -1
  213. package/third-party/include/executorch/runtime/backend/backend_init_context.h +6 -0
  214. package/third-party/include/executorch/runtime/backend/interface.h +1 -1
  215. package/third-party/include/executorch/runtime/core/error.h +76 -49
  216. package/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h +18 -4
  217. package/third-party/include/executorch/runtime/core/memory_allocator.h +12 -2
  218. package/third-party/include/executorch/runtime/core/named_data_map.h +1 -11
  219. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Export.h +0 -78
  220. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Macros.h +1 -520
  221. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16-inl.h +1 -339
  222. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16.h +1 -122
  223. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half-inl.h +1 -347
  224. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half.h +6 -419
  225. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/TypeSafeSignMath.h +1 -133
  226. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/bit_cast.h +1 -43
  227. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/complex.h +1 -568
  228. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/floating_point_utils.h +1 -33
  229. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/irange.h +1 -1
  230. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/llvmMathExtras.h +866 -0
  231. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/safe_numerics.h +97 -0
  232. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Export.h +66 -0
  233. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +553 -0
  234. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +477 -0
  235. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +781 -0
  236. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +141 -0
  237. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +49 -0
  238. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +593 -0
  239. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +38 -0
  240. package/third-party/include/executorch/runtime/core/tensor_layout.h +1 -1
  241. package/third-party/include/executorch/runtime/executor/merged_data_map.h +142 -0
  242. package/third-party/include/executorch/runtime/executor/method.h +21 -8
  243. package/third-party/include/executorch/runtime/executor/method_meta.h +20 -2
  244. package/third-party/include/executorch/runtime/executor/program.h +0 -10
  245. package/third-party/include/executorch/runtime/kernel/operator_registry.h +1 -1
  246. package/third-party/include/executorch/runtime/platform/compiler.h +2 -0
  247. package/third-party/include/executorch/schema/extended_header.h +10 -1
  248. package/third-party/include/torch/headeronly/macros/Export.h +66 -0
  249. package/third-party/include/torch/headeronly/macros/Macros.h +553 -0
  250. package/third-party/include/torch/headeronly/util/BFloat16.h +477 -0
  251. package/third-party/include/torch/headeronly/util/Half.h +781 -0
  252. package/third-party/include/torch/headeronly/util/TypeSafeSignMath.h +141 -0
  253. package/third-party/include/torch/headeronly/util/bit_cast.h +49 -0
  254. package/third-party/include/torch/headeronly/util/complex.h +593 -0
  255. package/third-party/include/torch/headeronly/util/floating_point_utils.h +38 -0
  256. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  257. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  258. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  259. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  260. package/common/rnexecutorch/tests/run_all_tests.sh +0 -14
  261. package/common/rnexecutorch/tests/run_test.sh +0 -18
  262. package/ios/RnExecutorch/utils/Conversions.h +0 -14
  263. package/ios/RnExecutorch/utils/ETError.h +0 -26
  264. package/ios/RnExecutorch/utils/ImageProcessor.h +0 -15
  265. package/ios/RnExecutorch/utils/ImageProcessor.mm +0 -147
  266. package/ios/RnExecutorch/utils/Numerical.h +0 -3
  267. package/ios/RnExecutorch/utils/Numerical.mm +0 -18
  268. package/ios/RnExecutorch/utils/ScalarType.h +0 -14
  269. package/ios/RnExecutorch/utils/ScalarType.mm +0 -21
  270. package/lib/module/hooks/useNonStaticModule.js.map +0 -1
  271. package/lib/typescript/hooks/useNonStaticModule.d.ts +0 -21
  272. package/lib/typescript/hooks/useNonStaticModule.d.ts.map +0 -1
  273. package/src/hooks/useNonStaticModule.ts +0 -74
  274. package/third-party/include/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h +0 -181
  275. package/third-party/include/executorch/extension/kernel_util/meta_programming.h +0 -108
  276. package/third-party/include/executorch/extension/kernel_util/type_list.h +0 -137
  277. package/third-party/include/executorch/extension/threadpool/threadpool_guard.h +0 -35
@@ -8,13 +8,14 @@ namespace rnexecutorch::models {
8
8
 
9
9
  using namespace facebook;
10
10
  using namespace executorch::extension;
11
+ using ::executorch::extension::module::Module;
11
12
  using ::executorch::runtime::Error;
12
13
 
13
14
  BaseModel::BaseModel(const std::string &modelSource,
14
- std::shared_ptr<react::CallInvoker> callInvoker)
15
+ std::shared_ptr<react::CallInvoker> callInvoker,
16
+ Module::LoadMode loadMode)
15
17
  : callInvoker(callInvoker),
16
- module_(std::make_unique<Module>(
17
- modelSource, Module::LoadMode::MmapUseMlockIgnoreErrors)) {
18
+ module_(std::make_unique<Module>(modelSource, loadMode)) {
18
19
  Error loadError = module_->load();
19
20
  if (loadError != Error::Ok) {
20
21
  throw std::runtime_error("Failed to load model: Error " +
@@ -29,7 +30,7 @@ BaseModel::BaseModel(const std::string &modelSource,
29
30
  }
30
31
 
31
32
  std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
32
- int32_t index) {
33
+ int32_t index) const {
33
34
  if (!module_) {
34
35
  throw std::runtime_error("Model not loaded: Cannot get input shape");
35
36
  }
@@ -55,7 +56,7 @@ std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
55
56
  }
56
57
 
57
58
  std::vector<std::vector<int32_t>>
58
- BaseModel::getAllInputShapes(std::string methodName) {
59
+ BaseModel::getAllInputShapes(std::string methodName) const {
59
60
  if (!module_) {
60
61
  throw std::runtime_error("Model not loaded: Cannot get all input shapes");
61
62
  }
@@ -87,7 +88,7 @@ BaseModel::getAllInputShapes(std::string methodName) {
87
88
  /// to JS. It is not meant to be used within C++. If you want to call forward
88
89
  /// from C++ on a BaseModel, please use BaseModel::forward.
89
90
  std::vector<JSTensorViewOut>
90
- BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
91
+ BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) const {
91
92
  if (!module_) {
92
93
  throw std::runtime_error("Model not loaded: Cannot perform forward pass");
93
94
  }
@@ -126,8 +127,8 @@ BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
126
127
  auto &outputTensor = outputs[i].toTensor();
127
128
  std::vector<int32_t> sizes = getTensorShape(outputTensor);
128
129
  size_t bufferSize = outputTensor.numel() * outputTensor.element_size();
129
- auto buffer = std::make_shared<OwningArrayBuffer>(bufferSize);
130
- std::memcpy(buffer->data(), outputTensor.const_data_ptr(), bufferSize);
130
+ auto buffer = std::make_shared<OwningArrayBuffer>(
131
+ outputTensor.const_data_ptr(), bufferSize);
131
132
  auto jsTensor = JSTensorViewOut(sizes, outputTensor.scalar_type(), buffer);
132
133
  output.emplace_back(jsTensor);
133
134
  }
@@ -135,7 +136,7 @@ BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
135
136
  }
136
137
 
137
138
  Result<executorch::runtime::MethodMeta>
138
- BaseModel::getMethodMeta(const std::string &methodName) {
139
+ BaseModel::getMethodMeta(const std::string &methodName) const {
139
140
  if (!module_) {
140
141
  throw std::runtime_error("Model not loaded: Cannot get method meta!");
141
142
  }
@@ -160,7 +161,7 @@ BaseModel::forward(const std::vector<EValue> &input_evalues) const {
160
161
 
161
162
  Result<std::vector<EValue>>
162
163
  BaseModel::execute(const std::string &methodName,
163
- const std::vector<EValue> &input_value) {
164
+ const std::vector<EValue> &input_value) const {
164
165
  if (!module_) {
165
166
  throw std::runtime_error("Model not loaded, cannot run execute.");
166
167
  }
@@ -174,7 +175,7 @@ std::size_t BaseModel::getMemoryLowerBound() const noexcept {
174
175
  void BaseModel::unload() noexcept { module_.reset(nullptr); }
175
176
 
176
177
  std::vector<int32_t>
177
- BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) {
178
+ BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) const {
178
179
  auto sizes = tensor.sizes();
179
180
  return std::vector<int32_t>(sizes.begin(), sizes.end());
180
181
  }
@@ -13,26 +13,32 @@
13
13
  namespace rnexecutorch {
14
14
  namespace models {
15
15
  using namespace facebook;
16
+ using executorch::extension::module::Module;
16
17
  using executorch::runtime::EValue;
17
18
  using executorch::runtime::Result;
19
+
18
20
  class BaseModel {
19
21
  public:
20
- BaseModel(const std::string &modelSource,
21
- std::shared_ptr<react::CallInvoker> callInvoker);
22
+ BaseModel(
23
+ const std::string &modelSource,
24
+ std::shared_ptr<react::CallInvoker> callInvoker,
25
+ Module::LoadMode loadMode = Module::LoadMode::MmapUseMlockIgnoreErrors);
22
26
  std::size_t getMemoryLowerBound() const noexcept;
23
27
  void unload() noexcept;
24
- std::vector<int32_t> getInputShape(std::string method_name, int32_t index);
28
+ std::vector<int32_t> getInputShape(std::string method_name,
29
+ int32_t index) const;
25
30
  std::vector<std::vector<int32_t>>
26
- getAllInputShapes(std::string methodName = "forward");
31
+ getAllInputShapes(std::string methodName = "forward") const;
27
32
  std::vector<JSTensorViewOut>
28
- forwardJS(std::vector<JSTensorViewIn> tensorViewVec);
33
+ forwardJS(std::vector<JSTensorViewIn> tensorViewVec) const;
29
34
  Result<std::vector<EValue>> forward(const EValue &input_value) const;
30
35
  Result<std::vector<EValue>>
31
36
  forward(const std::vector<EValue> &input_value) const;
32
- Result<std::vector<EValue>> execute(const std::string &methodName,
33
- const std::vector<EValue> &input_value);
37
+ Result<std::vector<EValue>>
38
+ execute(const std::string &methodName,
39
+ const std::vector<EValue> &input_value) const;
34
40
  Result<executorch::runtime::MethodMeta>
35
- getMethodMeta(const std::string &methodName);
41
+ getMethodMeta(const std::string &methodName) const;
36
42
 
37
43
  protected:
38
44
  // If possible, models should not use the JS runtime to keep JSI internals
@@ -42,9 +48,11 @@ protected:
42
48
  std::shared_ptr<react::CallInvoker> callInvoker;
43
49
  std::unique_ptr<executorch::extension::Module> module_;
44
50
 
45
- private:
46
51
  std::size_t memorySizeLowerBound{0};
47
- std::vector<int32_t> getTensorShape(const executorch::aten::Tensor &tensor);
52
+
53
+ private:
54
+ std::vector<int32_t>
55
+ getTensorShape(const executorch::aten::Tensor &tensor) const;
48
56
  };
49
57
  } // namespace models
50
58
 
@@ -11,17 +11,9 @@ BaseEmbeddings::BaseEmbeddings(const std::string &modelSource,
11
11
  std::shared_ptr<OwningArrayBuffer>
12
12
  BaseEmbeddings::postprocess(const Result<std::vector<EValue>> &forwardResult) {
13
13
  auto forwardResultTensor = forwardResult->at(0).toTensor();
14
- auto dataPtr = forwardResultTensor.mutable_data_ptr();
15
- auto outputNumel = forwardResultTensor.numel();
16
-
17
- std::span<float> modelOutput(static_cast<float *>(dataPtr), outputNumel);
18
-
19
- auto createBuffer = [](const auto &data, size_t size) {
20
- auto buffer = std::make_shared<OwningArrayBuffer>(size);
21
- std::memcpy(buffer->data(), data, size);
22
- return buffer;
23
- };
24
- return createBuffer(modelOutput.data(), modelOutput.size_bytes());
14
+ auto buffer = std::make_shared<OwningArrayBuffer>(
15
+ forwardResultTensor.const_data_ptr(), forwardResultTensor.nbytes());
16
+ return buffer;
25
17
  }
26
18
 
27
19
  } // namespace rnexecutorch::models::embeddings
@@ -48,7 +48,6 @@ TextEmbeddings::generate(const std::string input) {
48
48
  attnMaskShape, preprocessed.attentionMask.data(), ScalarType::Long);
49
49
 
50
50
  auto forwardResult = BaseModel::forward({tokenIds, attnMask});
51
-
52
51
  if (!forwardResult.ok()) {
53
52
  throw std::runtime_error(
54
53
  "Function forward in TextEmbeddings failed with error code: " +
@@ -62,11 +62,9 @@ std::shared_ptr<jsi::Object> ImageSegmentation::postprocess(
62
62
  std::vector<std::shared_ptr<OwningArrayBuffer>> resultClasses;
63
63
  resultClasses.reserve(numClasses);
64
64
  for (std::size_t cl = 0; cl < numClasses; ++cl) {
65
- auto classBuffer =
66
- std::make_shared<OwningArrayBuffer>(numModelPixels * sizeof(float));
65
+ auto classBuffer = std::make_shared<OwningArrayBuffer>(
66
+ &resultData[cl * numModelPixels], numModelPixels * sizeof(float));
67
67
  resultClasses.push_back(classBuffer);
68
- std::memcpy(classBuffer->data(), &resultData[cl * numModelPixels],
69
- numModelPixels * sizeof(float));
70
68
  }
71
69
 
72
70
  // Apply softmax per each pixel across all classes
@@ -112,18 +110,14 @@ std::shared_ptr<jsi::Object> ImageSegmentation::postprocess(
112
110
  cv::Mat argmaxMat(modelImageSize, CV_32SC1, argmax->data());
113
111
  cv::resize(argmaxMat, argmaxMat, originalSize, 0, 0,
114
112
  cv::InterpolationFlags::INTER_NEAREST);
115
- argmax = std::make_shared<OwningArrayBuffer>(originalSize.area() *
116
- sizeof(int32_t));
117
- std::memcpy(argmax->data(), argmaxMat.data,
118
- originalSize.area() * sizeof(int32_t));
113
+ argmax = std::make_shared<OwningArrayBuffer>(
114
+ argmaxMat.data, originalSize.area() * sizeof(int32_t));
119
115
 
120
116
  for (auto &[label, arrayBuffer] : *buffersToReturn) {
121
117
  cv::Mat classMat(modelImageSize, CV_32FC1, arrayBuffer->data());
122
118
  cv::resize(classMat, classMat, originalSize);
123
- arrayBuffer = std::make_shared<OwningArrayBuffer>(originalSize.area() *
124
- sizeof(float));
125
- std::memcpy(arrayBuffer->data(), classMat.data,
126
- originalSize.area() * sizeof(float));
119
+ arrayBuffer = std::make_shared<OwningArrayBuffer>(
120
+ classMat.data, originalSize.area() * sizeof(float));
127
121
  }
128
122
  }
129
123
  return populateDictionary(argmax, buffersToReturn);
@@ -1,30 +1,33 @@
1
1
  #include "LLM.h"
2
2
 
3
- #include <atomic>
4
3
  #include <executorch/extension/tensor/tensor.h>
5
4
  #include <filesystem>
6
5
  #include <rnexecutorch/threads/GlobalThreadPool.h>
7
6
 
8
7
  namespace rnexecutorch::models::llm {
8
+ namespace llm = ::executorch::extension::llm;
9
+ namespace fs = std::filesystem;
9
10
  using namespace facebook;
10
11
  using executorch::extension::TensorPtr;
12
+ using executorch::extension::module::Module;
11
13
  using executorch::runtime::Error;
12
14
 
13
15
  LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource,
14
16
  std::shared_ptr<react::CallInvoker> callInvoker)
15
- : runner(std::make_unique<example::Runner>(modelSource, tokenizerSource)),
16
- callInvoker(callInvoker) {
17
-
17
+ : BaseModel(modelSource, callInvoker, Module::LoadMode::File),
18
+ runner(
19
+ std::make_unique<example::Runner>(module_.get(), tokenizerSource)) {
18
20
  auto loadResult = runner->load();
19
21
  if (loadResult != Error::Ok) {
20
22
  throw std::runtime_error("Failed to load LLM runner, error code: " +
21
23
  std::to_string(static_cast<int>(loadResult)));
22
24
  }
23
- memorySizeLowerBound =
24
- std::filesystem::file_size(std::filesystem::path(modelSource)) +
25
- std::filesystem::file_size(std::filesystem::path(tokenizerSource));
25
+
26
+ memorySizeLowerBound = fs::file_size(fs::path(modelSource)) +
27
+ fs::file_size(fs::path(tokenizerSource));
26
28
  }
27
29
 
30
+ // TODO: add a way to manipulate the generation config with params
28
31
  void LLM::generate(std::string input, std::shared_ptr<jsi::Function> callback) {
29
32
  if (!runner || !runner->is_loaded()) {
30
33
  throw std::runtime_error("Runner is not loaded");
@@ -37,7 +40,8 @@ void LLM::generate(std::string input, std::shared_ptr<jsi::Function> callback) {
37
40
  });
38
41
  };
39
42
 
40
- auto error = runner->generate(input, nativeCallback, {}, false);
43
+ auto config = llm::GenerationConfig{.echo = false, .warming = false};
44
+ auto error = runner->generate(input, config, nativeCallback, {});
41
45
  if (error != executorch::runtime::Error::Ok) {
42
46
  throw std::runtime_error("Failed to generate text, error code: " +
43
47
  std::to_string(static_cast<int>(error)));
@@ -76,6 +80,19 @@ void LLM::setTimeInterval(size_t timeInterval) {
76
80
  runner->set_time_interval(timeInterval);
77
81
  }
78
82
 
83
+ void LLM::setTemperature(float temperature) {
84
+ if (!runner || !runner->is_loaded()) {
85
+ throw std::runtime_error("Can't configure a model that's not loaded!");
86
+ }
87
+ runner->set_temperature(temperature);
88
+ };
89
+
90
+ void LLM::setTopp(float topp) {
91
+ if (!runner || !runner->is_loaded()) {
92
+ throw std::runtime_error("Can't configure a model that's not loaded!");
93
+ }
94
+ runner->set_topp(topp);
95
+ }
79
96
  void LLM::unload() noexcept { runner.reset(nullptr); }
80
97
 
81
98
  } // namespace rnexecutorch::models::llm
@@ -3,16 +3,16 @@
3
3
  #include <memory>
4
4
  #include <string>
5
5
 
6
- #include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
7
6
  #include <ReactCommon/CallInvoker.h>
8
7
  #include <jsi/jsi.h>
8
+ #include <rnexecutorch/models/BaseModel.h>
9
9
  #include <runner/runner.h>
10
10
 
11
11
  namespace rnexecutorch {
12
12
  namespace models::llm {
13
13
  using namespace facebook;
14
14
 
15
- class LLM {
15
+ class LLM : public BaseModel {
16
16
  public:
17
17
  explicit LLM(const std::string &modelSource,
18
18
  const std::string &tokenizerSource,
@@ -24,12 +24,12 @@ public:
24
24
  size_t getGeneratedTokenCount() const noexcept;
25
25
  size_t getMemoryLowerBound() const noexcept;
26
26
  void setCountInterval(size_t countInterval);
27
+ void setTemperature(float temperature);
28
+ void setTopp(float topp);
27
29
  void setTimeInterval(size_t timeInterval);
28
30
 
29
31
  private:
30
- size_t memorySizeLowerBound;
31
32
  std::unique_ptr<example::Runner> runner;
32
- std::shared_ptr<react::CallInvoker> callInvoker;
33
33
  };
34
34
  } // namespace models::llm
35
35
 
@@ -23,7 +23,7 @@ public:
23
23
  size_t length);
24
24
 
25
25
  private:
26
- std::vector<std::string> character;
27
26
  int32_t ignoreIdx;
27
+ std::vector<std::string> character;
28
28
  };
29
29
  } // namespace rnexecutorch::models::ocr
@@ -60,16 +60,19 @@ cv::Mat cropImage(types::DetectorBBox box, cv::Mat &image,
60
60
  cv::warpAffine(image, rotatedImage, rotationMatrix, image.size(),
61
61
  cv::INTER_LINEAR);
62
62
 
63
- cv::Mat rectMat(4, 2, CV_32FC2);
63
+ constexpr int32_t rows = 4;
64
+ constexpr int32_t cols = 2;
65
+ cv::Mat rectMat(rows, cols, CV_32FC2);
64
66
  #pragma unroll
65
- for (int32_t i = 0; i < rectMat.rows; ++i) {
67
+ for (int32_t i = 0; i < rows; ++i) {
66
68
  rectMat.at<cv::Vec2f>(i, 0) = cv::Vec2f(rectPoints[i].x, rectPoints[i].y);
67
69
  }
68
70
  cv::transform(rectMat, rectMat, rotationMatrix);
69
71
 
70
- std::vector<cv::Point2f> transformedPoints(4);
72
+ constexpr size_t transformedPointsSize = 4;
73
+ std::vector<cv::Point2f> transformedPoints(transformedPointsSize);
71
74
  #pragma unroll
72
- for (std::size_t i = 0; i < transformedPoints.size(); ++i) {
75
+ for (std::size_t i = 0; i < transformedPointsSize; ++i) {
73
76
  cv::Vec2f point = rectMat.at<cv::Vec2f>(i, 0);
74
77
  transformedPoints[i] = cv::Point2f(point[0], point[1]);
75
78
  }
@@ -23,17 +23,22 @@ SpeechToText::SpeechToText(const std::string &encoderSource,
23
23
  processor(std::make_unique<OnlineASRProcessor>(this->asr.get())),
24
24
  isStreaming(false), readyToProcess(false) {}
25
25
 
26
+ void SpeechToText::unload() noexcept {
27
+ this->encoder->unload();
28
+ this->decoder->unload();
29
+ }
30
+
26
31
  std::shared_ptr<OwningArrayBuffer>
27
32
  SpeechToText::encode(std::span<float> waveform) const {
28
33
  std::vector<float> encoderOutput = this->asr->encode(waveform);
29
- return this->makeOwningBuffer(encoderOutput);
34
+ return std::make_shared<OwningArrayBuffer>(encoderOutput);
30
35
  }
31
36
 
32
37
  std::shared_ptr<OwningArrayBuffer>
33
38
  SpeechToText::decode(std::span<int32_t> tokens,
34
39
  std::span<float> encoderOutput) const {
35
40
  std::vector<float> decoderOutput = this->asr->decode(tokens, encoderOutput);
36
- return this->makeOwningBuffer(decoderOutput);
41
+ return std::make_shared<OwningArrayBuffer>(decoderOutput);
37
42
  }
38
43
 
39
44
  std::vector<char> SpeechToText::transcribe(std::span<float> waveform,
@@ -61,17 +66,7 @@ std::vector<char> SpeechToText::transcribe(std::span<float> waveform,
61
66
 
62
67
  size_t SpeechToText::getMemoryLowerBound() const noexcept {
63
68
  return this->encoder->getMemoryLowerBound() +
64
- this->decoder->getMemoryLowerBound() +
65
- this->tokenizer->getMemoryLowerBound();
66
- }
67
-
68
- std::shared_ptr<OwningArrayBuffer>
69
- SpeechToText::makeOwningBuffer(std::span<const float> vectorView) const {
70
- auto owningArrayBuffer =
71
- std::make_shared<OwningArrayBuffer>(vectorView.size_bytes());
72
- std::memcpy(owningArrayBuffer->data(), vectorView.data(),
73
- vectorView.size_bytes());
74
- return owningArrayBuffer;
69
+ this->decoder->getMemoryLowerBound();
75
70
  }
76
71
 
77
72
  void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
@@ -16,6 +16,7 @@ public:
16
16
  const std::string &tokenizerSource,
17
17
  std::shared_ptr<react::CallInvoker> callInvoker);
18
18
 
19
+ void unload() noexcept;
19
20
  std::shared_ptr<OwningArrayBuffer> encode(std::span<float> waveform) const;
20
21
  std::shared_ptr<OwningArrayBuffer>
21
22
  decode(std::span<int32_t> tokens, std::span<float> encoderOutput) const;
@@ -37,9 +38,6 @@ private:
37
38
  std::unique_ptr<TokenizerModule> tokenizer;
38
39
  std::unique_ptr<asr::ASR> asr;
39
40
 
40
- std::shared_ptr<OwningArrayBuffer>
41
- makeOwningBuffer(std::span<const float> vectorView) const;
42
-
43
41
  // Stream
44
42
  std::unique_ptr<stream::OnlineASRProcessor> processor;
45
43
  bool isStreaming;
@@ -4,7 +4,6 @@
4
4
  #include "ASR.h"
5
5
  #include "executorch/extension/tensor/tensor_ptr.h"
6
6
  #include "rnexecutorch/data_processing/Numerical.h"
7
- #include "rnexecutorch/data_processing/dsp.h"
8
7
  #include "rnexecutorch/data_processing/gzip.h"
9
8
 
10
9
  namespace rnexecutorch::models::speech_to_text::asr {
@@ -37,8 +36,7 @@ ASR::getInitialSequence(const DecodingOptions &options) const {
37
36
  return seq;
38
37
  }
39
38
 
40
- GenerationResult ASR::generate(std::span<const float> waveform,
41
- float temperature,
39
+ GenerationResult ASR::generate(std::span<float> waveform, float temperature,
42
40
  const DecodingOptions &options) const {
43
41
  std::vector<float> encoderOutput = this->encode(waveform);
44
42
 
@@ -94,7 +92,7 @@ float ASR::getCompressionRatio(const std::string &text) const {
94
92
  }
95
93
 
96
94
  std::vector<Segment>
97
- ASR::generateWithFallback(std::span<const float> waveform,
95
+ ASR::generateWithFallback(std::span<float> waveform,
98
96
  const DecodingOptions &options) const {
99
97
  std::vector<float> temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
100
98
  std::vector<int32_t> bestTokens;
@@ -209,7 +207,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span<const int32_t> tokens,
209
207
  return wordObjs;
210
208
  }
211
209
 
212
- std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
210
+ std::vector<Segment> ASR::transcribe(std::span<float> waveform,
213
211
  const DecodingOptions &options) const {
214
212
  int32_t seek = 0;
215
213
  std::vector<Segment> results;
@@ -218,7 +216,7 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
218
216
  int32_t start = seek * ASR::kSamplingRate;
219
217
  const auto end = std::min<int32_t>(
220
218
  (seek + ASR::kChunkSize) * ASR::kSamplingRate, waveform.size());
221
- std::span<const float> chunk = waveform.subspan(start, end - start);
219
+ auto chunk = waveform.subspan(start, end - start);
222
220
 
223
221
  if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) {
224
222
  break;
@@ -246,19 +244,12 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
246
244
  return results;
247
245
  }
248
246
 
249
- std::vector<float> ASR::encode(std::span<const float> waveform) const {
250
- constexpr int32_t fftWindowSize = 512;
251
- constexpr int32_t stftHopLength = 160;
252
- constexpr int32_t innerDim = 256;
253
-
254
- std::vector<float> preprocessedData =
255
- dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
256
- const auto numFrames =
257
- static_cast<int32_t>(preprocessedData.size()) / innerDim;
258
- std::vector<int32_t> inputShape = {numFrames, innerDim};
247
+ std::vector<float> ASR::encode(std::span<float> waveform) const {
248
+ auto inputShape = {static_cast<int32_t>(waveform.size())};
259
249
 
260
250
  const auto modelInputTensor = executorch::extension::make_tensor_ptr(
261
- std::move(inputShape), std::move(preprocessedData));
251
+ std::move(inputShape), waveform.data(),
252
+ executorch::runtime::etensor::ScalarType::Float);
262
253
  const auto encoderResult = this->encoder->forward(modelInputTensor);
263
254
 
264
255
  if (!encoderResult.ok()) {
@@ -268,7 +259,7 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
268
259
  }
269
260
 
270
261
  const auto decoderOutputTensor = encoderResult.get().at(0).toTensor();
271
- const int32_t outputNumel = decoderOutputTensor.numel();
262
+ const auto outputNumel = decoderOutputTensor.numel();
272
263
 
273
264
  const float *const dataPtr = decoderOutputTensor.const_data_ptr<float>();
274
265
  return {dataPtr, dataPtr + outputNumel};
@@ -277,8 +268,10 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
277
268
  std::vector<float> ASR::decode(std::span<int32_t> tokens,
278
269
  std::span<float> encoderOutput) const {
279
270
  std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
271
+ auto tokensLong = std::vector<int64_t>(tokens.begin(), tokens.end());
272
+
280
273
  auto tokenTensor = executorch::extension::make_tensor_ptr(
281
- std::move(tokenShape), tokens.data(), ScalarType::Int);
274
+ tokenShape, tokensLong.data(), ScalarType::Long);
282
275
 
283
276
  const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
284
277
  std::vector<int32_t> encShape = {1, ASR::kNumFrames,
@@ -14,9 +14,9 @@ public:
14
14
  const models::BaseModel *decoder,
15
15
  const TokenizerModule *tokenizer);
16
16
  std::vector<types::Segment>
17
- transcribe(std::span<const float> waveform,
17
+ transcribe(std::span<float> waveform,
18
18
  const types::DecodingOptions &options) const;
19
- std::vector<float> encode(std::span<const float> waveform) const;
19
+ std::vector<float> encode(std::span<float> waveform) const;
20
20
  std::vector<float> decode(std::span<int32_t> tokens,
21
21
  std::span<float> encoderOutput) const;
22
22
 
@@ -44,11 +44,10 @@ private:
44
44
 
45
45
  std::vector<int32_t>
46
46
  getInitialSequence(const types::DecodingOptions &options) const;
47
- types::GenerationResult generate(std::span<const float> waveform,
48
- float temperature,
47
+ types::GenerationResult generate(std::span<float> waveform, float temperature,
49
48
  const types::DecodingOptions &options) const;
50
49
  std::vector<types::Segment>
51
- generateWithFallback(std::span<const float> waveform,
50
+ generateWithFallback(std::span<float> waveform,
52
51
  const types::DecodingOptions &options) const;
53
52
  std::vector<types::Segment>
54
53
  calculateWordLevelTimestamps(std::span<const int32_t> tokens,
@@ -0,0 +1,9 @@
1
+ #pragma once
2
+
3
+ #include <string_view>
4
+
5
+ namespace rnexecutorch::models::text_to_image::constants {
6
+
7
+ inline constexpr std::string_view kBosToken = "<|startoftext|>";
8
+
9
+ } // namespace rnexecutorch::models::text_to_image::constants
@@ -0,0 +1,32 @@
1
+ #include "Decoder.h"
2
+
3
+ #include <cmath>
4
+
5
+ #include <executorch/extension/tensor/tensor_ptr_maker.h>
6
+
7
+ namespace rnexecutorch::models::text_to_image {
8
+
9
+ using namespace executorch::extension;
10
+
11
+ Decoder::Decoder(const std::string &modelSource,
12
+ std::shared_ptr<react::CallInvoker> callInvoker)
13
+ : BaseModel(modelSource, callInvoker) {}
14
+
15
+ std::vector<float> Decoder::generate(std::vector<float> &input) const {
16
+ std::vector<int32_t> inputShape = {1, numChannels, latentImageSize,
17
+ latentImageSize};
18
+ auto inputTensor =
19
+ make_tensor_ptr(inputShape, input.data(), ScalarType::Float);
20
+
21
+ auto forwardResult = BaseModel::forward(inputTensor);
22
+ if (!forwardResult.ok()) {
23
+ throw std::runtime_error(
24
+ "Function forward in decoder failed with error code: " +
25
+ std::to_string(static_cast<uint32_t>(forwardResult.error())));
26
+ }
27
+
28
+ auto forwardResultTensor = forwardResult->at(0).toTensor();
29
+ const auto *dataPtr = forwardResultTensor.const_data_ptr<float>();
30
+ return {dataPtr, dataPtr + forwardResultTensor.numel()};
31
+ }
32
+ } // namespace rnexecutorch::models::text_to_image
@@ -0,0 +1,24 @@
1
+ #pragma once
2
+
3
+ #include <memory>
4
+ #include <string>
5
+ #include <vector>
6
+
7
+ #include <ReactCommon/CallInvoker.h>
8
+
9
+ #include <rnexecutorch/models/BaseModel.h>
10
+
11
+ namespace rnexecutorch::models::text_to_image {
12
+
13
+ class Decoder final : public BaseModel {
14
+ public:
15
+ explicit Decoder(const std::string &modelSource,
16
+ std::shared_ptr<react::CallInvoker> callInvoker);
17
+ std::vector<float> generate(std::vector<float> &input) const;
18
+
19
+ int32_t latentImageSize;
20
+
21
+ private:
22
+ static constexpr int32_t numChannels = 4;
23
+ };
24
+ } // namespace rnexecutorch::models::text_to_image
@@ -0,0 +1,44 @@
1
+ #include "Encoder.h"
2
+
3
+ #include <cmath>
4
+ #include <random>
5
+ #include <span>
6
+
7
+ #include <rnexecutorch/models/text_to_image/Constants.h>
8
+
9
+ namespace rnexecutorch::models::text_to_image {
10
+
11
+ Encoder::Encoder(const std::string &tokenizerSource,
12
+ const std::string &encoderSource,
13
+ std::shared_ptr<react::CallInvoker> callInvoker)
14
+ : callInvoker(callInvoker),
15
+ encoder(std::make_unique<embeddings::TextEmbeddings>(
16
+ encoderSource, tokenizerSource, callInvoker)) {}
17
+
18
+ std::vector<float> Encoder::generate(std::string input) {
19
+ std::shared_ptr<OwningArrayBuffer> embeddingsText = encoder->generate(input);
20
+ std::shared_ptr<OwningArrayBuffer> embeddingsUncond =
21
+ encoder->generate(std::string(constants::kBosToken));
22
+
23
+ assert(embeddingsText->size() == embeddingsUncond->size());
24
+ size_t embeddingsSize = embeddingsText->size() / sizeof(float);
25
+ auto *embeddingsTextPtr = reinterpret_cast<float *>(embeddingsText->data());
26
+ auto *embeddingsUncondPtr =
27
+ reinterpret_cast<float *>(embeddingsUncond->data());
28
+
29
+ std::vector<float> embeddingsConcat;
30
+ embeddingsConcat.reserve(embeddingsSize * 2);
31
+ embeddingsConcat.insert(embeddingsConcat.end(), embeddingsUncondPtr,
32
+ embeddingsUncondPtr + embeddingsSize);
33
+ embeddingsConcat.insert(embeddingsConcat.end(), embeddingsTextPtr,
34
+ embeddingsTextPtr + embeddingsSize);
35
+ return embeddingsConcat;
36
+ }
37
+
38
+ size_t Encoder::getMemoryLowerBound() const noexcept {
39
+ return encoder->getMemoryLowerBound();
40
+ }
41
+
42
+ void Encoder::unload() noexcept { encoder->unload(); }
43
+
44
+ } // namespace rnexecutorch::models::text_to_image