react-native-executorch 0.7.0 → 0.7.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. package/common/rnexecutorch/TokenizerModule.cpp +3 -2
  2. package/common/rnexecutorch/TokenizerModule.h +1 -1
  3. package/lib/module/modules/computer_vision/TextToImageModule.js +8 -4
  4. package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -1
  5. package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -1
  6. package/package.json +4 -3
  7. package/src/modules/computer_vision/TextToImageModule.ts +9 -4
  8. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  9. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  10. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/bpe_model.h +84 -0
  11. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/bpe_tokenizer_base.h +6 -87
  12. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/hf_tokenizer.h +28 -176
  13. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/map_utils.h +174 -0
  14. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/model.h +151 -0
  15. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/normalizer.h +55 -1
  16. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/padding.h +112 -0
  17. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/post_processor.h +101 -42
  18. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/pre_tokenizer.h +25 -9
  19. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/token_decoder.h +33 -6
  20. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/tokenizer.h +2 -2
  21. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/truncation.h +92 -0
  22. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/wordpiece_model.h +74 -0
  23. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  24. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  25. package/common/rnexecutorch/tests/CMakeLists.txt +0 -253
  26. package/common/rnexecutorch/tests/README.md +0 -73
  27. package/common/rnexecutorch/tests/integration/BaseModelTest.cpp +0 -207
  28. package/common/rnexecutorch/tests/integration/BaseModelTests.h +0 -120
  29. package/common/rnexecutorch/tests/integration/ClassificationTest.cpp +0 -117
  30. package/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp +0 -122
  31. package/common/rnexecutorch/tests/integration/ImageSegmentationTest.cpp +0 -152
  32. package/common/rnexecutorch/tests/integration/LLMTest.cpp +0 -155
  33. package/common/rnexecutorch/tests/integration/OCRTest.cpp +0 -128
  34. package/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp +0 -135
  35. package/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp +0 -97
  36. package/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +0 -112
  37. package/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp +0 -164
  38. package/common/rnexecutorch/tests/integration/TextToImageTest.cpp +0 -149
  39. package/common/rnexecutorch/tests/integration/TokenizerModuleTest.cpp +0 -98
  40. package/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp +0 -238
  41. package/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp +0 -99
  42. package/common/rnexecutorch/tests/integration/assets/test_audio_float.raw +0 -0
  43. package/common/rnexecutorch/tests/integration/assets/we_are_software_mansion.jpg +0 -0
  44. package/common/rnexecutorch/tests/integration/libs/libfbjni.so +0 -0
  45. package/common/rnexecutorch/tests/integration/stubs/jsi_stubs.cpp +0 -45
  46. package/common/rnexecutorch/tests/integration/utils/TestUtils.h +0 -36
  47. package/common/rnexecutorch/tests/run_tests.sh +0 -333
  48. package/common/rnexecutorch/tests/unit/FileUtilsTest.cpp +0 -32
  49. package/common/rnexecutorch/tests/unit/LogTest.cpp +0 -529
  50. package/common/rnexecutorch/tests/unit/NumericalTest.cpp +0 -107
@@ -1,253 +0,0 @@
1
- if(NOT ANDROID_ABI)
2
- message(FATAL_ERROR "Tests can be only built for Android simulator")
3
- endif()
4
-
5
- cmake_minimum_required(VERSION 3.13)
6
- project(RNExecutorchTests)
7
-
8
- set(CMAKE_CXX_STANDARD 20)
9
- set(CMAKE_CXX_STANDARD_REQUIRED TRUE)
10
-
11
- # tests/ <- CMAKE_SOURCE_DIR (this file's location)
12
- # rnexecutorch/ <- RNEXECUTORCH_DIR (parent of tests)
13
- # common/ <- COMMON_DIR
14
- # react-native-executorch/ <- PACKAGE_ROOT
15
- # <monorepo-root>/ <- MONOREPO_ROOT
16
- # <monorepo-root>/third-party/ <- THIRD_PARTY_DIR
17
- set(RNEXECUTORCH_DIR "${CMAKE_SOURCE_DIR}/..")
18
- set(COMMON_DIR "${RNEXECUTORCH_DIR}/..")
19
- set(PACKAGE_ROOT "${COMMON_DIR}/..")
20
- set(MONOREPO_ROOT "${PACKAGE_ROOT}/../..")
21
- set(THIRD_PARTY_DIR "${MONOREPO_ROOT}/third-party")
22
- set(REACT_NATIVE_DIR "${MONOREPO_ROOT}/node_modules/react-native")
23
- set(ANDROID_THIRD_PARTY "${PACKAGE_ROOT}/third-party/android/libs/")
24
- set(TOKENIZERS_DIR "${PACKAGE_ROOT}/third-party/include/executorch/extension/llm/tokenizers/include")
25
-
26
- # Add Gtest as a subdirectory
27
- add_subdirectory(${THIRD_PARTY_DIR}/googletest ${PROJECT_BINARY_DIR}/googletest)
28
-
29
- # ExecuTorch Prebuilt binaries
30
- add_library(executorch_prebuilt SHARED IMPORTED)
31
- set_target_properties(executorch_prebuilt PROPERTIES
32
- IMPORTED_LOCATION "${ANDROID_THIRD_PARTY}/executorch/${ANDROID_ABI}/libexecutorch.so"
33
- )
34
-
35
- # pthreadpool and cpuinfo (needed for OpenMP/OpenCV)
36
- if(ANDROID_ABI STREQUAL "arm64-v8a")
37
- add_library(pthreadpool SHARED IMPORTED)
38
- set_target_properties(pthreadpool PROPERTIES
39
- IMPORTED_LOCATION "${ANDROID_THIRD_PARTY}/pthreadpool/${ANDROID_ABI}/libpthreadpool.so"
40
- )
41
-
42
- add_library(cpuinfo SHARED IMPORTED)
43
- set_target_properties(cpuinfo PROPERTIES
44
- IMPORTED_LOCATION "${ANDROID_THIRD_PARTY}/cpuinfo/${ANDROID_ABI}/libcpuinfo.so"
45
- )
46
-
47
- set(EXECUTORCH_LIBS pthreadpool cpuinfo)
48
- else()
49
- set(EXECUTORCH_LIBS "")
50
- endif()
51
-
52
- # OpenCV (Interface Library)
53
- set(OPENCV_LIBS_DIR "${ANDROID_THIRD_PARTY}/opencv/${ANDROID_ABI}")
54
- set(OPENCV_THIRD_PARTY_DIR "${ANDROID_THIRD_PARTY}/opencv-third-party/${ANDROID_ABI}")
55
-
56
- if(ANDROID_ABI STREQUAL "arm64-v8a")
57
- set(OPENCV_THIRD_PARTY_LIBS
58
- "${OPENCV_THIRD_PARTY_DIR}/libkleidicv_hal.a"
59
- "${OPENCV_THIRD_PARTY_DIR}/libkleidicv_thread.a"
60
- "${OPENCV_THIRD_PARTY_DIR}/libkleidicv.a"
61
- )
62
- elseif(ANDROID_ABI STREQUAL "x86_64")
63
- set(OPENCV_THIRD_PARTY_LIBS "")
64
- endif()
65
-
66
-
67
- add_library(opencv_deps INTERFACE)
68
- target_link_libraries(opencv_deps INTERFACE
69
- ${OPENCV_LIBS_DIR}/libopencv_core.a
70
- ${OPENCV_LIBS_DIR}/libopencv_features2d.a
71
- ${OPENCV_LIBS_DIR}/libopencv_highgui.a
72
- ${OPENCV_LIBS_DIR}/libopencv_imgproc.a
73
- ${OPENCV_LIBS_DIR}/libopencv_photo.a
74
- ${OPENCV_LIBS_DIR}/libopencv_video.a
75
- ${OPENCV_THIRD_PARTY_LIBS}
76
- ${EXECUTORCH_LIBS}
77
- z
78
- dl
79
- m
80
- log
81
- )
82
- target_link_options(opencv_deps INTERFACE -fopenmp -static-openmp)
83
-
84
- add_library(tokenizers_deps INTERFACE)
85
- target_include_directories(tokenizers_deps INTERFACE "${TOKENIZERS_DIR}")
86
-
87
- # Source Definitions
88
- set(CORE_SOURCES
89
- ${RNEXECUTORCH_DIR}/models/BaseModel.cpp
90
- ${RNEXECUTORCH_DIR}/data_processing/Numerical.cpp
91
- ${CMAKE_SOURCE_DIR}/integration/stubs/jsi_stubs.cpp
92
- )
93
-
94
- set(IMAGE_UTILS_SOURCES
95
- ${RNEXECUTORCH_DIR}/data_processing/ImageProcessing.cpp
96
- ${RNEXECUTORCH_DIR}/data_processing/base64.cpp
97
- ${COMMON_DIR}/ada/ada.cpp
98
- )
99
-
100
- set(TOKENIZER_SOURCES ${RNEXECUTORCH_DIR}/TokenizerModule.cpp)
101
- set(DSP_SOURCES ${RNEXECUTORCH_DIR}/data_processing/dsp.cpp)
102
-
103
- # Core Library
104
- add_library(rntests_core STATIC ${CORE_SOURCES})
105
-
106
- target_include_directories(rntests_core PUBLIC
107
- ${RNEXECUTORCH_DIR}/data_processing
108
- ${TOKENIZERS_DIR}
109
- ${RNEXECUTORCH_DIR}
110
- ${COMMON_DIR}
111
- ${PACKAGE_ROOT}/third-party/include
112
- ${REACT_NATIVE_DIR}/ReactCommon
113
- ${REACT_NATIVE_DIR}/ReactCommon/jsi
114
- ${REACT_NATIVE_DIR}/ReactCommon/callinvoker
115
- ${COMMON_DIR}/ada
116
- )
117
-
118
- target_link_libraries(rntests_core PUBLIC
119
- executorch_prebuilt
120
- gtest
121
- log
122
- )
123
-
124
- enable_testing()
125
- function(add_rn_test TEST_TARGET TEST_FILENAME)
126
- cmake_parse_arguments(ARG "" "" "SOURCES;LIBS" ${ARGN})
127
- # Create executable using the explicit filename provided
128
- add_executable(${TEST_TARGET} ${TEST_FILENAME} ${ARG_SOURCES})
129
-
130
- target_compile_definitions(${TEST_TARGET} PRIVATE TEST_BUILD)
131
- target_link_libraries(${TEST_TARGET} PRIVATE rntests_core gtest_main ${ARG_LIBS})
132
- target_link_options(${TEST_TARGET} PRIVATE "LINKER:-z,max-page-size=16384")
133
-
134
- add_test(NAME ${TEST_TARGET} COMMAND ${TEST_TARGET})
135
- endfunction()
136
-
137
- add_rn_test(NumericalTests unit/NumericalTest.cpp)
138
- add_rn_test(LogTests unit/LogTest.cpp)
139
- add_rn_test(BaseModelTests integration/BaseModelTest.cpp)
140
-
141
- add_rn_test(ClassificationTests integration/ClassificationTest.cpp
142
- SOURCES
143
- ${RNEXECUTORCH_DIR}/models/classification/Classification.cpp
144
- ${IMAGE_UTILS_SOURCES}
145
- LIBS opencv_deps
146
- )
147
-
148
- add_rn_test(ObjectDetectionTests integration/ObjectDetectionTest.cpp
149
- SOURCES
150
- ${RNEXECUTORCH_DIR}/models/object_detection/ObjectDetection.cpp
151
- ${RNEXECUTORCH_DIR}/models/object_detection/Utils.cpp
152
- ${IMAGE_UTILS_SOURCES}
153
- LIBS opencv_deps
154
- )
155
-
156
- add_rn_test(ImageEmbeddingsTests integration/ImageEmbeddingsTest.cpp
157
- SOURCES
158
- ${RNEXECUTORCH_DIR}/models/embeddings/image/ImageEmbeddings.cpp
159
- ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp
160
- ${IMAGE_UTILS_SOURCES}
161
- LIBS opencv_deps
162
- )
163
-
164
- add_rn_test(TextEmbeddingsTests integration/TextEmbeddingsTest.cpp
165
- SOURCES
166
- ${RNEXECUTORCH_DIR}/models/embeddings/text/TextEmbeddings.cpp
167
- ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp
168
- ${TOKENIZER_SOURCES}
169
- LIBS tokenizers_deps
170
- )
171
-
172
- add_rn_test(StyleTransferTests integration/StyleTransferTest.cpp
173
- SOURCES
174
- ${RNEXECUTORCH_DIR}/models/style_transfer/StyleTransfer.cpp
175
- ${IMAGE_UTILS_SOURCES}
176
- LIBS opencv_deps
177
- )
178
-
179
- add_rn_test(VADTests integration/VoiceActivityDetectionTest.cpp
180
- SOURCES
181
- ${RNEXECUTORCH_DIR}/models/voice_activity_detection/VoiceActivityDetection.cpp
182
- ${RNEXECUTORCH_DIR}/models/voice_activity_detection/Utils.cpp
183
- ${DSP_SOURCES}
184
- )
185
-
186
- add_rn_test(TokenizerModuleTests integration/TokenizerModuleTest.cpp
187
- SOURCES ${TOKENIZER_SOURCES}
188
- LIBS tokenizers_deps
189
- )
190
-
191
- add_rn_test(SpeechToTextTests integration/SpeechToTextTest.cpp
192
- SOURCES
193
- ${RNEXECUTORCH_DIR}/models/speech_to_text/SpeechToText.cpp
194
- ${RNEXECUTORCH_DIR}/models/speech_to_text/asr/ASR.cpp
195
- ${RNEXECUTORCH_DIR}/models/speech_to_text/stream/HypothesisBuffer.cpp
196
- ${RNEXECUTORCH_DIR}/models/speech_to_text/stream/OnlineASRProcessor.cpp
197
- ${RNEXECUTORCH_DIR}/data_processing/gzip.cpp
198
- ${TOKENIZER_SOURCES}
199
- ${DSP_SOURCES}
200
- LIBS tokenizers_deps z
201
- )
202
-
203
- add_rn_test(LLMTests integration/LLMTest.cpp
204
- SOURCES
205
- ${RNEXECUTORCH_DIR}/models/llm/LLM.cpp
206
- ${COMMON_DIR}/runner/runner.cpp
207
- ${COMMON_DIR}/runner/text_prefiller.cpp
208
- ${COMMON_DIR}/runner/text_decoder_runner.cpp
209
- ${COMMON_DIR}/runner/sampler.cpp
210
- ${COMMON_DIR}/runner/arange_util.cpp
211
- LIBS tokenizers_deps
212
- )
213
-
214
- add_rn_test(TextToImageTests integration/TextToImageTest.cpp
215
- SOURCES
216
- ${RNEXECUTORCH_DIR}/models/text_to_image/TextToImage.cpp
217
- ${RNEXECUTORCH_DIR}/models/text_to_image/Encoder.cpp
218
- ${RNEXECUTORCH_DIR}/models/text_to_image/UNet.cpp
219
- ${RNEXECUTORCH_DIR}/models/text_to_image/Decoder.cpp
220
- ${RNEXECUTORCH_DIR}/models/text_to_image/Scheduler.cpp
221
- ${RNEXECUTORCH_DIR}/models/embeddings/text/TextEmbeddings.cpp
222
- ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp
223
- ${TOKENIZER_SOURCES}
224
- LIBS tokenizers_deps
225
- )
226
-
227
- add_rn_test(OCRTests integration/OCRTest.cpp
228
- SOURCES
229
- ${RNEXECUTORCH_DIR}/models/ocr/OCR.cpp
230
- ${RNEXECUTORCH_DIR}/models/ocr/CTCLabelConverter.cpp
231
- ${RNEXECUTORCH_DIR}/models/ocr/Detector.cpp
232
- ${RNEXECUTORCH_DIR}/models/ocr/RecognitionHandler.cpp
233
- ${RNEXECUTORCH_DIR}/models/ocr/Recognizer.cpp
234
- ${RNEXECUTORCH_DIR}/models/ocr/utils/DetectorUtils.cpp
235
- ${RNEXECUTORCH_DIR}/models/ocr/utils/RecognitionHandlerUtils.cpp
236
- ${RNEXECUTORCH_DIR}/models/ocr/utils/RecognizerUtils.cpp
237
- ${IMAGE_UTILS_SOURCES}
238
- LIBS opencv_deps
239
- )
240
-
241
- add_rn_test(VerticalOCRTests integration/VerticalOCRTest.cpp
242
- SOURCES
243
- ${RNEXECUTORCH_DIR}/models/vertical_ocr/VerticalOCR.cpp
244
- ${RNEXECUTORCH_DIR}/models/vertical_ocr/VerticalDetector.cpp
245
- ${RNEXECUTORCH_DIR}/models/ocr/Detector.cpp
246
- ${RNEXECUTORCH_DIR}/models/ocr/CTCLabelConverter.cpp
247
- ${RNEXECUTORCH_DIR}/models/ocr/Recognizer.cpp
248
- ${RNEXECUTORCH_DIR}/models/ocr/utils/DetectorUtils.cpp
249
- ${RNEXECUTORCH_DIR}/models/ocr/utils/RecognitionHandlerUtils.cpp
250
- ${RNEXECUTORCH_DIR}/models/ocr/utils/RecognizerUtils.cpp
251
- ${IMAGE_UTILS_SOURCES}
252
- LIBS opencv_deps
253
- )
@@ -1,73 +0,0 @@
1
- ## Native Test
2
- This guide provide information on how functions are tested, how to install all needed dependencies and how to run tests.
3
-
4
- ### Used Tools
5
- To test the native code we use [`googletest`](https://github.com/google/googletest). It's a flexible tool for creating unit tests.
6
-
7
- ### Installation
8
- The googletest is already in repo in `react-native-executorch/third-party/googletest`. Firstly, you need to fetch googletest locally, run from root directory of project:
9
- * `git submodule update --init --recursive third-party/googletest`
10
-
11
- ### Running tests
12
-
13
- #### Prerequisites
14
-
15
- - **Android NDK**: The `ANDROID_NDK` environment variable must be set
16
- - **wget**: Must be in your PATH
17
- - **Android emulator**: Must be running before executing tests
18
- - **Device requirements**:
19
- - 16GB disk storage (minimum)
20
- - 8GB RAM (minimum)
21
-
22
- #### First-time setup
23
-
24
- Before running tests, you need to build an app to generate required native libraries (`libfbjni.so` and `libc++_shared.so`). The test script automatically searches for these in the monorepo.
25
-
26
- If the script reports missing libraries, build any example app:
27
- ```bash
28
- cd apps/computer-vision/android
29
- ./gradlew assembleDebug
30
- # or
31
- ./gradlew assembleRelease
32
- ```
33
-
34
- #### Running the tests
35
-
36
- Navigate to the tests directory:
37
- ```bash
38
- cd packages/react-native-executorch/common/rnexecutorch/tests
39
- ```
40
-
41
- Run the test script:
42
- ```bash
43
- bash ./run_tests.sh
44
- ```
45
-
46
- This script:
47
- - Downloads all needed models
48
- - Pushes executables, models, assets, and shared libraries via ADB to the running emulator
49
- - Runs the pre-compiled test executables
50
-
51
- #### Available flags
52
-
53
- * `--refresh-models` - Forcefully downloads all the models. By default, models are not downloaded unless they are missing from the specified directory.
54
- * `--skip-build` - Skips the cmake build step.
55
-
56
- ### How to add a new test
57
- To add new test you need to:
58
- * Add a new .cpp file to either integration/ or unit/, depending on the type of the test.
59
- * In `CMakeLists.txt`, add all executables and link all the needed libraries against the executable, for example you can use the `add_rn_test`, which is a helper function that links core libs. Example:
60
- ```cmake
61
- # unit
62
- add_rn_test(BaseModelTests integration/BaseModelTest.cpp)
63
-
64
- # integration
65
- add_rn_test(ClassificationTests integration/ClassificationTest.cpp
66
- SOURCES
67
- ${RNEXECUTORCH_DIR}/models/classification/Classification.cpp
68
- ${IMAGE_UTILS_SOURCES}
69
- LIBS opencv_deps
70
- )
71
- ```
72
- * Lastly, add the test executable name to the run_tests script along with all the needed URL and assets.
73
-
@@ -1,207 +0,0 @@
1
- #include "BaseModelTests.h"
2
- #include <executorch/extension/tensor/tensor.h>
3
- #include <gtest/gtest.h>
4
- #include <limits>
5
- #include <rnexecutorch/Error.h>
6
- #include <rnexecutorch/models/BaseModel.h>
7
- #include <vector>
8
-
9
- using namespace rnexecutorch;
10
- using namespace rnexecutorch::models;
11
- using namespace executorch::extension;
12
- using namespace model_tests;
13
- using executorch::runtime::EValue;
14
-
15
- constexpr auto kValidStyleTransferModelPath =
16
- "style_transfer_candy_xnnpack.pte";
17
-
18
- // ============================================================================
19
- // Common tests via typed test suite
20
- // ============================================================================
21
- namespace model_tests {
22
- template <> struct ModelTraits<BaseModel> {
23
- using ModelType = BaseModel;
24
-
25
- static ModelType createValid() {
26
- return ModelType(kValidStyleTransferModelPath, nullptr);
27
- }
28
-
29
- static ModelType createInvalid() {
30
- return ModelType("nonexistent.pte", nullptr);
31
- }
32
-
33
- static void callGenerate(ModelType &model) {
34
- std::vector<int32_t> shape = {1, 3, 640, 640};
35
- size_t numElements = 1 * 3 * 640 * 640;
36
- std::vector<float> inputData(numElements, 0.5f);
37
- auto tensorPtr = make_tensor_ptr(shape, inputData.data());
38
- EValue input(*tensorPtr);
39
- (void)model.forward(input);
40
- }
41
- };
42
- } // namespace model_tests
43
-
44
- using BaseModelTypes = ::testing::Types<BaseModel>;
45
- INSTANTIATE_TYPED_TEST_SUITE_P(BaseModel, CommonModelTest, BaseModelTypes);
46
-
47
- // ============================================================================
48
- // BaseModel-specific tests (methods not in all models)
49
- // ============================================================================
50
-
51
- TEST(BaseModelGetInputShapeTests, ValidMethodCorrectShape) {
52
- BaseModel model(kValidStyleTransferModelPath, nullptr);
53
- auto forwardShape = model.getInputShape("forward", 0);
54
- std::vector<int32_t> expectedShape = {1, 3, 640, 640};
55
- EXPECT_EQ(forwardShape, expectedShape);
56
- }
57
-
58
- TEST(BaseModelGetInputShapeTests, InvalidMethodThrows) {
59
- BaseModel model(kValidStyleTransferModelPath, nullptr);
60
- EXPECT_THROW((void)model.getInputShape("this_method_does_not_exist", 0),
61
- RnExecutorchError);
62
- }
63
-
64
- TEST(BaseModelGetInputShapeTests, ValidMethodInvalidIndexThrows) {
65
- BaseModel model(kValidStyleTransferModelPath, nullptr);
66
- EXPECT_THROW(
67
- (void)model.getInputShape("forward", std::numeric_limits<int32_t>::min()),
68
- RnExecutorchError);
69
- }
70
-
71
- TEST(BaseModelGetAllInputShapesTests, ValidMethodReturnsShapes) {
72
- BaseModel model(kValidStyleTransferModelPath, nullptr);
73
- auto allShapes = model.getAllInputShapes("forward");
74
- EXPECT_FALSE(allShapes.empty());
75
- std::vector<int32_t> expectedFirstShape = {1, 3, 640, 640};
76
- EXPECT_EQ(allShapes[0], expectedFirstShape);
77
- }
78
-
79
- TEST(BaseModelGetAllInputShapesTests, InvalidMethodThrows) {
80
- BaseModel model(kValidStyleTransferModelPath, nullptr);
81
- EXPECT_THROW(model.getAllInputShapes("non_existent_method"),
82
- RnExecutorchError);
83
- }
84
-
85
- TEST(BaseModelGetMethodMetaTests, ValidMethodReturnsOk) {
86
- BaseModel model(kValidStyleTransferModelPath, nullptr);
87
- auto result = model.getMethodMeta("forward");
88
- EXPECT_TRUE(result.ok());
89
- }
90
-
91
- TEST(BaseModelGetMethodMetaTests, InvalidMethodReturnsError) {
92
- BaseModel model(kValidStyleTransferModelPath, nullptr);
93
- auto result = model.getMethodMeta("non_existent_method");
94
- EXPECT_FALSE(result.ok());
95
- }
96
-
97
- TEST(BaseModelForwardTests, ForwardWithValidInputReturnsOk) {
98
- BaseModel model(kValidStyleTransferModelPath, nullptr);
99
- std::vector<int32_t> shape = {1, 3, 640, 640};
100
- size_t numElements = 1 * 3 * 640 * 640;
101
- std::vector<float> inputData(numElements, 0.5f);
102
- auto tensorPtr = make_tensor_ptr(shape, inputData.data());
103
- EValue input(*tensorPtr);
104
-
105
- auto result = model.forward(input);
106
- EXPECT_TRUE(result.ok());
107
- }
108
-
109
- TEST(BaseModelForwardTests, ForwardWithVectorInputReturnsOk) {
110
- BaseModel model(kValidStyleTransferModelPath, nullptr);
111
- std::vector<int32_t> shape = {1, 3, 640, 640};
112
- size_t numElements = 1 * 3 * 640 * 640;
113
- std::vector<float> inputData(numElements, 0.5f);
114
- auto tensorPtr = make_tensor_ptr(shape, inputData.data());
115
- std::vector<EValue> inputs;
116
- inputs.emplace_back(*tensorPtr);
117
-
118
- auto result = model.forward(inputs);
119
- EXPECT_TRUE(result.ok());
120
- }
121
-
122
- TEST(BaseModelForwardTests, ForwardReturnsCorrectOutputShape) {
123
- BaseModel model(kValidStyleTransferModelPath, nullptr);
124
- std::vector<int32_t> shape = {1, 3, 640, 640};
125
- size_t numElements = 1 * 3 * 640 * 640;
126
- std::vector<float> inputData(numElements, 0.5f);
127
- auto tensorPtr = make_tensor_ptr(shape, inputData.data());
128
- EValue input(*tensorPtr);
129
-
130
- auto result = model.forward(input);
131
- ASSERT_TRUE(result.ok());
132
- ASSERT_FALSE(result->empty());
133
-
134
- auto &outputTensor = result->at(0).toTensor();
135
- auto outputSizes = outputTensor.sizes();
136
- EXPECT_EQ(outputSizes.size(), 4);
137
- EXPECT_EQ(outputSizes[0], 1);
138
- EXPECT_EQ(outputSizes[1], 3);
139
- EXPECT_EQ(outputSizes[2], 640);
140
- EXPECT_EQ(outputSizes[3], 640);
141
- }
142
-
143
- TEST(BaseModelForwardTests, ForwardAfterUnloadThrows) {
144
- BaseModel model(kValidStyleTransferModelPath, nullptr);
145
- model.unload();
146
-
147
- std::vector<int32_t> shape = {1, 3, 640, 640};
148
- size_t numElements = 1 * 3 * 640 * 640;
149
- std::vector<float> inputData(numElements, 0.5f);
150
- auto tensorPtr = make_tensor_ptr(shape, inputData.data());
151
- EValue input(*tensorPtr);
152
-
153
- EXPECT_THROW(model.forward(input), RnExecutorchError);
154
- }
155
-
156
- TEST(BaseModelForwardJSTests, ForwardJSWithValidInputReturnsOutput) {
157
- BaseModel model(kValidStyleTransferModelPath, nullptr);
158
- std::vector<int32_t> shape = {1, 3, 640, 640};
159
- size_t numElements = 1 * 3 * 640 * 640;
160
- std::vector<float> inputData(numElements, 0.5f);
161
-
162
- JSTensorViewIn tensorView;
163
- tensorView.dataPtr = inputData.data();
164
- tensorView.sizes = shape;
165
- tensorView.scalarType = executorch::aten::ScalarType::Float;
166
-
167
- std::vector<JSTensorViewIn> inputs = {tensorView};
168
- auto outputs = model.forwardJS(inputs);
169
-
170
- EXPECT_FALSE(outputs.empty());
171
- }
172
-
173
- TEST(BaseModelForwardJSTests, ForwardJSReturnsCorrectOutputShape) {
174
- BaseModel model(kValidStyleTransferModelPath, nullptr);
175
- std::vector<int32_t> shape = {1, 3, 640, 640};
176
- size_t numElements = 1 * 3 * 640 * 640;
177
- std::vector<float> inputData(numElements, 0.5f);
178
-
179
- JSTensorViewIn tensorView;
180
- tensorView.dataPtr = inputData.data();
181
- tensorView.sizes = shape;
182
- tensorView.scalarType = executorch::aten::ScalarType::Float;
183
-
184
- std::vector<JSTensorViewIn> inputs = {tensorView};
185
- auto outputs = model.forwardJS(inputs);
186
-
187
- ASSERT_EQ(outputs.size(), 1);
188
- std::vector<int32_t> expectedShape = {1, 3, 640, 640};
189
- EXPECT_EQ(outputs[0].sizes, expectedShape);
190
- }
191
-
192
- TEST(BaseModelForwardJSTests, ForwardJSAfterUnloadThrows) {
193
- BaseModel model(kValidStyleTransferModelPath, nullptr);
194
- model.unload();
195
-
196
- std::vector<int32_t> shape = {1, 3, 640, 640};
197
- size_t numElements = 1 * 3 * 640 * 640;
198
- std::vector<float> inputData(numElements, 0.5f);
199
-
200
- JSTensorViewIn tensorView;
201
- tensorView.dataPtr = inputData.data();
202
- tensorView.sizes = shape;
203
- tensorView.scalarType = executorch::aten::ScalarType::Float;
204
-
205
- std::vector<JSTensorViewIn> inputs = {tensorView};
206
- EXPECT_THROW((void)model.forwardJS(inputs), RnExecutorchError);
207
- }
@@ -1,120 +0,0 @@
1
- #pragma once
2
-
3
- #include "gtest/gtest.h"
4
- #include <rnexecutorch/Error.h>
5
-
6
- namespace facebook::react {
7
- class CallInvoker;
8
- }
9
-
10
- namespace rnexecutorch {
11
- std::shared_ptr<facebook::react::CallInvoker> createMockCallInvoker();
12
- }
13
-
14
- namespace model_tests {
15
-
16
- inline auto getMockInvoker() { return rnexecutorch::createMockCallInvoker(); }
17
-
18
- /// Helper macro to access Traits in typed tests
19
- #define SETUP_TRAITS() using Traits = typename TestFixture::Traits
20
-
21
- /// Trait struct that each model must specialize
22
- /// This defines how to construct and test each model type
23
- template <typename T> struct ModelTraits;
24
-
25
- /// Example of what a specialization looks like:
26
- ///
27
- /// template<>
28
- /// struct ModelTraits<Classification> {
29
- /// using ModelType = Classification;
30
- ///
31
- /// // Create valid model instance
32
- /// static ModelType createValid() {
33
- /// return ModelType("valid_model.pte", nullptr);
34
- /// }
35
- ///
36
- /// // Create invalid model instance (should throw in constructor)
37
- /// static ModelType createInvalid() {
38
- /// return ModelType("nonexistent.pte", nullptr);
39
- /// }
40
- ///
41
- /// // Call the model's generate/forward function with valid input
42
- /// // Used to test that generate throws after unload
43
- /// static void callGenerate(ModelType& model) {
44
- /// (void)model.generate("valid_input.jpg");
45
- /// }
46
- /// };
47
- // Typed test fixture for common model tests
48
- template <typename T> class CommonModelTest : public ::testing::Test {
49
- protected:
50
- using Traits = ModelTraits<T>;
51
- using ModelType = typename Traits::ModelType;
52
- };
53
-
54
- // Define the test suite
55
- TYPED_TEST_SUITE_P(CommonModelTest);
56
-
57
- // Constructor tests
58
- TYPED_TEST_P(CommonModelTest, InvalidPathThrows) {
59
- SETUP_TRAITS();
60
- EXPECT_THROW(Traits::createInvalid(), rnexecutorch::RnExecutorchError);
61
- }
62
-
63
- TYPED_TEST_P(CommonModelTest, ValidPathDoesntThrow) {
64
- SETUP_TRAITS();
65
- EXPECT_NO_THROW(Traits::createValid());
66
- }
67
-
68
- // Memory tests
69
- TYPED_TEST_P(CommonModelTest, GetMemoryLowerBoundValue) {
70
- SETUP_TRAITS();
71
- auto model = Traits::createValid();
72
- EXPECT_GT(model.getMemoryLowerBound(), 0u);
73
- }
74
-
75
- TYPED_TEST_P(CommonModelTest, GetMemoryLowerBoundConsistent) {
76
- SETUP_TRAITS();
77
- auto model = Traits::createValid();
78
- auto bound1 = model.getMemoryLowerBound();
79
- auto bound2 = model.getMemoryLowerBound();
80
- EXPECT_EQ(bound1, bound2);
81
- }
82
-
83
- // Unload tests
84
- TYPED_TEST_P(CommonModelTest, UnloadDoesntThrow) {
85
- SETUP_TRAITS();
86
- auto model = Traits::createValid();
87
- EXPECT_NO_THROW(model.unload());
88
- }
89
-
90
- TYPED_TEST_P(CommonModelTest, MultipleUnloadsSafe) {
91
- SETUP_TRAITS();
92
- auto model = Traits::createValid();
93
- EXPECT_NO_THROW(model.unload());
94
- EXPECT_NO_THROW(model.unload());
95
- EXPECT_NO_THROW(model.unload());
96
- }
97
-
98
- TYPED_TEST_P(CommonModelTest, GenerateAfterUnloadThrows) {
99
- SETUP_TRAITS();
100
- auto model = Traits::createValid();
101
- model.unload();
102
- EXPECT_THROW(Traits::callGenerate(model), rnexecutorch::RnExecutorchError);
103
- }
104
-
105
- TYPED_TEST_P(CommonModelTest, MultipleGeneratesWork) {
106
- SETUP_TRAITS();
107
- auto model = Traits::createValid();
108
- EXPECT_NO_THROW(Traits::callGenerate(model));
109
- EXPECT_NO_THROW(Traits::callGenerate(model));
110
- EXPECT_NO_THROW(Traits::callGenerate(model));
111
- }
112
-
113
- // Register all tests in the suite
114
- REGISTER_TYPED_TEST_SUITE_P(CommonModelTest, InvalidPathThrows,
115
- ValidPathDoesntThrow, GetMemoryLowerBoundValue,
116
- GetMemoryLowerBoundConsistent, UnloadDoesntThrow,
117
- MultipleUnloadsSafe, GenerateAfterUnloadThrows,
118
- MultipleGeneratesWork);
119
-
120
- } // namespace model_tests