react-native-executorch 0.8.0-nightly-48610bf-20260324 → 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. package/android/src/main/cpp/ETInstallerModule.h +1 -2
  2. package/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +115 -43
  3. package/common/rnexecutorch/models/object_detection/ObjectDetection.h +57 -8
  4. package/lib/module/constants/modelUrls.js +227 -55
  5. package/lib/module/constants/modelUrls.js.map +1 -1
  6. package/lib/module/constants/resourceFetcher.js +4 -0
  7. package/lib/module/constants/resourceFetcher.js.map +1 -0
  8. package/lib/module/constants/versions.js +2 -2
  9. package/lib/module/hooks/computer_vision/useInstanceSegmentation.js +2 -2
  10. package/lib/module/hooks/computer_vision/useObjectDetection.js +6 -3
  11. package/lib/module/hooks/computer_vision/useObjectDetection.js.map +1 -1
  12. package/lib/module/modules/computer_vision/InstanceSegmentationModule.js +7 -4
  13. package/lib/module/modules/computer_vision/InstanceSegmentationModule.js.map +1 -1
  14. package/lib/module/modules/computer_vision/ObjectDetectionModule.js +127 -11
  15. package/lib/module/modules/computer_vision/ObjectDetectionModule.js.map +1 -1
  16. package/lib/module/modules/computer_vision/VisionModule.js +3 -2
  17. package/lib/module/modules/computer_vision/VisionModule.js.map +1 -1
  18. package/lib/module/types/objectDetection.js +21 -4
  19. package/lib/module/types/objectDetection.js.map +1 -1
  20. package/lib/module/utils/ResourceFetcher.js +14 -8
  21. package/lib/module/utils/ResourceFetcher.js.map +1 -1
  22. package/lib/module/utils/ResourceFetcherUtils.js +42 -0
  23. package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
  24. package/lib/typescript/constants/modelUrls.d.ts +626 -122
  25. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  26. package/lib/typescript/constants/resourceFetcher.d.ts +2 -0
  27. package/lib/typescript/constants/resourceFetcher.d.ts.map +1 -0
  28. package/lib/typescript/constants/versions.d.ts +2 -2
  29. package/lib/typescript/hooks/computer_vision/useInstanceSegmentation.d.ts +2 -2
  30. package/lib/typescript/hooks/computer_vision/useObjectDetection.d.ts.map +1 -1
  31. package/lib/typescript/modules/computer_vision/InstanceSegmentationModule.d.ts +5 -4
  32. package/lib/typescript/modules/computer_vision/InstanceSegmentationModule.d.ts.map +1 -1
  33. package/lib/typescript/modules/computer_vision/ObjectDetectionModule.d.ts +82 -5
  34. package/lib/typescript/modules/computer_vision/ObjectDetectionModule.d.ts.map +1 -1
  35. package/lib/typescript/modules/computer_vision/VisionModule.d.ts +4 -3
  36. package/lib/typescript/modules/computer_vision/VisionModule.d.ts.map +1 -1
  37. package/lib/typescript/types/instanceSegmentation.d.ts +3 -1
  38. package/lib/typescript/types/instanceSegmentation.d.ts.map +1 -1
  39. package/lib/typescript/types/objectDetection.d.ts +71 -12
  40. package/lib/typescript/types/objectDetection.d.ts.map +1 -1
  41. package/lib/typescript/utils/ResourceFetcher.d.ts +1 -0
  42. package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
  43. package/lib/typescript/utils/ResourceFetcherUtils.d.ts +5 -0
  44. package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
  45. package/package.json +1 -1
  46. package/src/constants/modelUrls.ts +239 -66
  47. package/src/constants/resourceFetcher.ts +2 -0
  48. package/src/constants/versions.ts +2 -2
  49. package/src/hooks/computer_vision/useInstanceSegmentation.ts +2 -2
  50. package/src/hooks/computer_vision/useObjectDetection.ts +10 -2
  51. package/src/modules/computer_vision/InstanceSegmentationModule.ts +15 -11
  52. package/src/modules/computer_vision/ObjectDetectionModule.ts +208 -6
  53. package/src/modules/computer_vision/VisionModule.ts +4 -3
  54. package/src/types/instanceSegmentation.ts +3 -1
  55. package/src/types/objectDetection.ts +67 -13
  56. package/src/utils/ResourceFetcher.ts +14 -8
  57. package/src/utils/ResourceFetcherUtils.ts +40 -0
  58. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  59. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
@@ -2,7 +2,6 @@
2
2
 
3
3
  #include <ReactCommon/CallInvokerHolder.h>
4
4
  #include <fbjni/fbjni.h>
5
- #include <react/jni/CxxModuleWrapper.h>
6
5
  #include <react/jni/JMessageQueueThread.h>
7
6
 
8
7
  #include <memory>
@@ -40,4 +39,4 @@ private:
40
39
  const std::shared_ptr<facebook::react::CallInvoker> &jsCallInvoker);
41
40
  };
42
41
 
43
- } // namespace rnexecutorch
42
+ } // namespace rnexecutorch
@@ -1,6 +1,8 @@
1
1
  #include "ObjectDetection.h"
2
2
  #include "Constants.h"
3
3
 
4
+ #include <set>
5
+
4
6
  #include <rnexecutorch/Error.h>
5
7
  #include <rnexecutorch/ErrorCodes.h>
6
8
  #include <rnexecutorch/Log.h>
@@ -18,21 +20,6 @@ ObjectDetection::ObjectDetection(
18
20
  std::shared_ptr<react::CallInvoker> callInvoker)
19
21
  : VisionModel(modelSource, callInvoker),
20
22
  labelNames_(std::move(labelNames)) {
21
- auto inputTensors = getAllInputShapes();
22
- if (inputTensors.empty()) {
23
- throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
24
- "Model seems to not take any input tensors.");
25
- }
26
- modelInputShape_ = inputTensors[0];
27
- if (modelInputShape_.size() < 2) {
28
- char errorMessage[100];
29
- std::snprintf(errorMessage, sizeof(errorMessage),
30
- "Unexpected model input size, expected at least 2 dimensions "
31
- "but got: %zu.",
32
- modelInputShape_.size());
33
- throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
34
- errorMessage);
35
- }
36
23
  if (normMean.size() == 3) {
37
24
  normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]);
38
25
  } else if (!normMean.empty()) {
@@ -47,14 +34,67 @@ ObjectDetection::ObjectDetection(
47
34
  }
48
35
  }
49
36
 
37
+ cv::Size ObjectDetection::modelInputSize() const {
38
+ if (currentlyLoadedMethod_.empty()) {
39
+ return VisionModel::modelInputSize();
40
+ }
41
+ auto inputShapes = getAllInputShapes(currentlyLoadedMethod_);
42
+ if (inputShapes.empty() || inputShapes[0].size() < 2) {
43
+ throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
44
+ "Could not determine input shape for method: " +
45
+ currentlyLoadedMethod_);
46
+ }
47
+ const auto &shape = inputShapes[0];
48
+ return {static_cast<int32_t>(shape[shape.size() - 2]),
49
+ static_cast<int32_t>(shape[shape.size() - 1])};
50
+ }
51
+
52
+ void ObjectDetection::ensureMethodLoaded(const std::string &methodName) {
53
+ if (methodName.empty()) {
54
+ throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
55
+ "methodName cannot be empty");
56
+ }
57
+ if (currentlyLoadedMethod_ == methodName) {
58
+ return;
59
+ }
60
+ if (!module_) {
61
+ throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded,
62
+ "Model module is not loaded");
63
+ }
64
+ if (!currentlyLoadedMethod_.empty()) {
65
+ module_->unload_method(currentlyLoadedMethod_);
66
+ }
67
+ auto loadResult = module_->load_method(methodName);
68
+ if (loadResult != executorch::runtime::Error::Ok) {
69
+ throw RnExecutorchError(
70
+ loadResult, "Failed to load method '" + methodName +
71
+ "'. Ensure the method exists in the exported model.");
72
+ }
73
+ currentlyLoadedMethod_ = methodName;
74
+ }
75
+
76
+ std::set<int32_t> ObjectDetection::prepareAllowedClasses(
77
+ const std::vector<int32_t> &classIndices) const {
78
+ std::set<int32_t> allowedClasses;
79
+ if (!classIndices.empty()) {
80
+ allowedClasses.insert(classIndices.begin(), classIndices.end());
81
+ }
82
+ return allowedClasses;
83
+ }
84
+
50
85
  std::vector<types::Detection>
51
86
  ObjectDetection::postprocess(const std::vector<EValue> &tensors,
52
- cv::Size originalSize, double detectionThreshold) {
87
+ cv::Size originalSize, double detectionThreshold,
88
+ double iouThreshold,
89
+ const std::vector<int32_t> &classIndices) {
53
90
  const cv::Size inputSize = modelInputSize();
54
91
  float widthRatio = static_cast<float>(originalSize.width) / inputSize.width;
55
92
  float heightRatio =
56
93
  static_cast<float>(originalSize.height) / inputSize.height;
57
94
 
95
+ // Prepare allowed classes set for filtering
96
+ auto allowedClasses = prepareAllowedClasses(classIndices);
97
+
58
98
  std::vector<types::Detection> detections;
59
99
  auto bboxTensor = tensors.at(0).toTensor();
60
100
  std::span<const float> bboxes(
@@ -75,12 +115,21 @@ ObjectDetection::postprocess(const std::vector<EValue> &tensors,
75
115
  if (scores[i] < detectionThreshold) {
76
116
  continue;
77
117
  }
118
+
119
+ auto labelIdx = static_cast<int32_t>(labels[i]);
120
+
121
+ // Filter by class if classesOfInterest is specified
122
+ if (!allowedClasses.empty() &&
123
+ allowedClasses.find(labelIdx) == allowedClasses.end()) {
124
+ continue;
125
+ }
126
+
78
127
  float x1 = bboxes[i * 4] * widthRatio;
79
128
  float y1 = bboxes[i * 4 + 1] * heightRatio;
80
129
  float x2 = bboxes[i * 4 + 2] * widthRatio;
81
130
  float y2 = bboxes[i * 4 + 3] * heightRatio;
82
- auto labelIdx = static_cast<std::size_t>(labels[i]);
83
- if (labelIdx >= labelNames_.size()) {
131
+
132
+ if (std::cmp_greater_equal(labelIdx, labelNames_.size())) {
84
133
  throw RnExecutorchError(
85
134
  RnExecutorchErrorCode::InvalidConfig,
86
135
  "Model output class index " + std::to_string(labelIdx) +
@@ -88,23 +137,40 @@ ObjectDetection::postprocess(const std::vector<EValue> &tensors,
88
137
  ". Ensure the labelMap covers all model output classes.");
89
138
  }
90
139
  detections.emplace_back(utils::computer_vision::BBox{x1, y1, x2, y2},
91
- labelNames_[labelIdx],
92
- static_cast<int32_t>(labelIdx), scores[i]);
140
+ labelNames_[labelIdx], labelIdx, scores[i]);
93
141
  }
94
142
 
95
- return utils::computer_vision::nonMaxSuppression(detections,
96
- constants::IOU_THRESHOLD);
143
+ return utils::computer_vision::nonMaxSuppression(detections, iouThreshold);
97
144
  }
98
145
 
99
- std::vector<types::Detection>
100
- ObjectDetection::runInference(cv::Mat image, double detectionThreshold) {
146
+ std::vector<types::Detection> ObjectDetection::runInference(
147
+ cv::Mat image, double detectionThreshold, double iouThreshold,
148
+ const std::vector<int32_t> &classIndices, const std::string &methodName) {
101
149
  if (detectionThreshold < 0.0 || detectionThreshold > 1.0) {
102
150
  throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
103
151
  "detectionThreshold must be in range [0, 1]");
104
152
  }
153
+ if (iouThreshold < 0.0 || iouThreshold > 1.0) {
154
+ throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
155
+ "iouThreshold must be in range [0, 1]");
156
+ }
157
+
105
158
  std::scoped_lock lock(inference_mutex_);
106
159
 
160
+ // Ensure the correct method is loaded
161
+ ensureMethodLoaded(methodName);
162
+
107
163
  cv::Size originalSize = image.size();
164
+
165
+ // Query input shapes for the currently loaded method
166
+ auto inputShapes = getAllInputShapes(methodName);
167
+ if (inputShapes.empty() || inputShapes[0].size() < 2) {
168
+ throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
169
+ "Could not determine input shape for method: " +
170
+ methodName);
171
+ }
172
+ modelInputShape_ = inputShapes[0];
173
+
108
174
  cv::Mat preprocessed = preprocess(image);
109
175
 
110
176
  auto inputTensor =
@@ -114,46 +180,52 @@ ObjectDetection::runInference(cv::Mat image, double detectionThreshold) {
114
180
  : image_processing::getTensorFromMatrix(modelInputShape_,
115
181
  preprocessed);
116
182
 
117
- auto forwardResult = BaseModel::forward(inputTensor);
118
- if (!forwardResult.ok()) {
119
- throw RnExecutorchError(forwardResult.error(),
120
- "The model's forward function did not succeed. "
121
- "Ensure the model input is correct.");
183
+ auto executeResult = execute(methodName, {inputTensor});
184
+ if (!executeResult.ok()) {
185
+ throw RnExecutorchError(executeResult.error(),
186
+ "The model's " + methodName +
187
+ " method did not succeed. "
188
+ "Ensure the model input is correct.");
122
189
  }
123
190
 
124
- return postprocess(forwardResult.get(), originalSize, detectionThreshold);
191
+ return postprocess(executeResult.get(), originalSize, detectionThreshold,
192
+ iouThreshold, classIndices);
125
193
  }
126
194
 
127
- std::vector<types::Detection>
128
- ObjectDetection::generateFromString(std::string imageSource,
129
- double detectionThreshold) {
195
+ std::vector<types::Detection> ObjectDetection::generateFromString(
196
+ std::string imageSource, double detectionThreshold, double iouThreshold,
197
+ std::vector<int32_t> classIndices, std::string methodName) {
130
198
  cv::Mat imageBGR = image_processing::readImage(imageSource);
131
199
 
132
200
  cv::Mat imageRGB;
133
201
  cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB);
134
202
 
135
- return runInference(imageRGB, detectionThreshold);
203
+ return runInference(imageRGB, detectionThreshold, iouThreshold, classIndices,
204
+ methodName);
136
205
  }
137
206
 
138
- std::vector<types::Detection>
139
- ObjectDetection::generateFromFrame(jsi::Runtime &runtime,
140
- const jsi::Value &frameData,
141
- double detectionThreshold) {
207
+ std::vector<types::Detection> ObjectDetection::generateFromFrame(
208
+ jsi::Runtime &runtime, const jsi::Value &frameData,
209
+ double detectionThreshold, double iouThreshold,
210
+ std::vector<int32_t> classIndices, std::string methodName) {
142
211
  auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData);
143
212
  cv::Mat frame = extractFromFrame(runtime, frameData);
144
213
  cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel(frame, orient);
145
- auto detections = runInference(rotated, detectionThreshold);
214
+ auto detections = runInference(rotated, detectionThreshold, iouThreshold,
215
+ classIndices, methodName);
216
+
146
217
  for (auto &det : detections) {
147
218
  ::rnexecutorch::utils::inverseRotateBbox(det.bbox, orient, rotated.size());
148
219
  }
149
220
  return detections;
150
221
  }
151
222
 
152
- std::vector<types::Detection>
153
- ObjectDetection::generateFromPixels(JSTensorViewIn pixelData,
154
- double detectionThreshold) {
223
+ std::vector<types::Detection> ObjectDetection::generateFromPixels(
224
+ JSTensorViewIn pixelData, double detectionThreshold, double iouThreshold,
225
+ std::vector<int32_t> classIndices, std::string methodName) {
155
226
  cv::Mat image = extractFromPixels(pixelData);
156
227
 
157
- return runInference(image, detectionThreshold);
228
+ return runInference(image, detectionThreshold, iouThreshold, classIndices,
229
+ methodName);
158
230
  }
159
231
  } // namespace rnexecutorch::models::object_detection
@@ -57,6 +57,13 @@ public:
57
57
  * @param imageSource URI or file path of the input image.
58
58
  * @param detectionThreshold Minimum confidence score in (0, 1] for a
59
59
  * detection to be included in the output.
60
+ * @param iouThreshold IoU threshold for non-maximum suppression.
61
+ * @param classIndices Optional list of class indices to filter results.
62
+ * Only detections matching these classes will be
63
+ * returned. Pass empty vector to include all
64
+ * classes.
65
+ * @param methodName Name of the method to execute (e.g., "forward",
66
+ * "forward_384", "forward_512", "forward_640").
60
67
  *
61
68
  * @return A vector of @ref types::Detection objects with bounding boxes,
62
69
  * label strings (resolved via the label names passed to the
@@ -66,16 +73,33 @@ public:
66
73
  * fails.
67
74
  */
68
75
  [[nodiscard("Registered non-void function")]] std::vector<types::Detection>
69
- generateFromString(std::string imageSource, double detectionThreshold);
76
+ generateFromString(std::string imageSource, double detectionThreshold,
77
+ double iouThreshold, std::vector<int32_t> classIndices,
78
+ std::string methodName);
70
79
  [[nodiscard("Registered non-void function")]] std::vector<types::Detection>
71
80
  generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData,
72
- double detectionThreshold);
81
+ double detectionThreshold, double iouThreshold,
82
+ std::vector<int32_t> classIndices, std::string methodName);
73
83
  [[nodiscard("Registered non-void function")]] std::vector<types::Detection>
74
- generateFromPixels(JSTensorViewIn pixelData, double detectionThreshold);
84
+ generateFromPixels(JSTensorViewIn pixelData, double detectionThreshold,
85
+ double iouThreshold, std::vector<int32_t> classIndices,
86
+ std::string methodName);
75
87
 
76
88
  protected:
77
- std::vector<types::Detection> runInference(cv::Mat image,
78
- double detectionThreshold);
89
+ /**
90
+ * @brief Returns the model input size based on the currently loaded method.
91
+ *
92
+ * Overrides VisionModel::modelInputSize() to support multi-method models
93
+ * where each method may have different input dimensions.
94
+ *
95
+ * @return The expected input size for the currently loaded method.
96
+ */
97
+ cv::Size modelInputSize() const override;
98
+
99
+ std::vector<types::Detection>
100
+ runInference(cv::Mat image, double detectionThreshold, double iouThreshold,
101
+ const std::vector<int32_t> &classIndices,
102
+ const std::string &methodName);
79
103
 
80
104
  private:
81
105
  /**
@@ -88,15 +112,37 @@ private:
88
112
  * bounding boxes back to input coordinates.
89
113
  * @param detectionThreshold Confidence threshold below which detections
90
114
  * are discarded.
115
+ * @param iouThreshold IoU threshold for non-maximum suppression.
116
+ * @param classIndices Optional list of class indices to filter results.
91
117
  *
92
118
  * @return Non-max-suppressed detections above the threshold.
93
119
  *
94
120
  * @throws RnExecutorchError if the model outputs a class index that exceeds
95
121
  * the size of @ref labelNames_.
96
122
  */
97
- std::vector<types::Detection> postprocess(const std::vector<EValue> &tensors,
98
- cv::Size originalSize,
99
- double detectionThreshold);
123
+ std::vector<types::Detection>
124
+ postprocess(const std::vector<EValue> &tensors, cv::Size originalSize,
125
+ double detectionThreshold, double iouThreshold,
126
+ const std::vector<int32_t> &classIndices);
127
+
128
+ /**
129
+ * @brief Ensures the specified method is loaded, unloading any previous
130
+ * method if necessary.
131
+ *
132
+ * @param methodName Name of the method to load (e.g., "forward",
133
+ * "forward_384").
134
+ * @throws RnExecutorchError if the method cannot be loaded.
135
+ */
136
+ void ensureMethodLoaded(const std::string &methodName);
137
+
138
+ /**
139
+ * @brief Prepares a set of allowed class indices for filtering detections.
140
+ *
141
+ * @param classIndices Vector of class indices to allow.
142
+ * @return A set containing the allowed class indices.
143
+ */
144
+ std::set<int32_t>
145
+ prepareAllowedClasses(const std::vector<int32_t> &classIndices) const;
100
146
 
101
147
  /// Optional per-channel mean for input normalisation (set in constructor).
102
148
  std::optional<cv::Scalar> normMean_;
@@ -106,6 +152,9 @@ private:
106
152
 
107
153
  /// Ordered label strings mapping class indices to human-readable names.
108
154
  std::vector<std::string> labelNames_;
155
+
156
+ /// Name of the currently loaded method (for multi-method models).
157
+ std::string currentlyLoadedMethod_;
109
158
  };
110
159
  } // namespace models::object_detection
111
160