react-native-executorch 0.7.0 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. package/common/rnexecutorch/TokenizerModule.cpp +3 -2
  2. package/common/rnexecutorch/TokenizerModule.h +1 -1
  3. package/package.json +2 -1
  4. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  5. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  6. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/bpe_model.h +84 -0
  7. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/bpe_tokenizer_base.h +6 -87
  8. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/hf_tokenizer.h +28 -176
  9. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/map_utils.h +174 -0
  10. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/model.h +151 -0
  11. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/normalizer.h +55 -1
  12. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/padding.h +112 -0
  13. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/post_processor.h +101 -42
  14. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/pre_tokenizer.h +25 -9
  15. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/token_decoder.h +33 -6
  16. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/tokenizer.h +2 -2
  17. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/truncation.h +92 -0
  18. package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/wordpiece_model.h +74 -0
  19. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  20. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  21. package/common/rnexecutorch/tests/CMakeLists.txt +0 -253
  22. package/common/rnexecutorch/tests/README.md +0 -73
  23. package/common/rnexecutorch/tests/integration/BaseModelTest.cpp +0 -207
  24. package/common/rnexecutorch/tests/integration/BaseModelTests.h +0 -120
  25. package/common/rnexecutorch/tests/integration/ClassificationTest.cpp +0 -117
  26. package/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp +0 -122
  27. package/common/rnexecutorch/tests/integration/ImageSegmentationTest.cpp +0 -152
  28. package/common/rnexecutorch/tests/integration/LLMTest.cpp +0 -155
  29. package/common/rnexecutorch/tests/integration/OCRTest.cpp +0 -128
  30. package/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp +0 -135
  31. package/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp +0 -97
  32. package/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +0 -112
  33. package/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp +0 -164
  34. package/common/rnexecutorch/tests/integration/TextToImageTest.cpp +0 -149
  35. package/common/rnexecutorch/tests/integration/TokenizerModuleTest.cpp +0 -98
  36. package/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp +0 -238
  37. package/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp +0 -99
  38. package/common/rnexecutorch/tests/integration/assets/test_audio_float.raw +0 -0
  39. package/common/rnexecutorch/tests/integration/assets/we_are_software_mansion.jpg +0 -0
  40. package/common/rnexecutorch/tests/integration/libs/libfbjni.so +0 -0
  41. package/common/rnexecutorch/tests/integration/stubs/jsi_stubs.cpp +0 -45
  42. package/common/rnexecutorch/tests/integration/utils/TestUtils.h +0 -36
  43. package/common/rnexecutorch/tests/run_tests.sh +0 -333
  44. package/common/rnexecutorch/tests/unit/FileUtilsTest.cpp +0 -32
  45. package/common/rnexecutorch/tests/unit/LogTest.cpp +0 -529
  46. package/common/rnexecutorch/tests/unit/NumericalTest.cpp +0 -107
@@ -1,529 +0,0 @@
1
- #include "../Log.h"
2
- #include <gtest/gtest.h>
3
-
4
- #include <array>
5
- #include <cmath>
6
- #include <complex>
7
- #include <deque>
8
- #include <forward_list>
9
- #include <fstream>
10
- #include <list>
11
- #include <map>
12
- #include <queue>
13
- #include <regex>
14
- #include <set>
15
- #include <stack>
16
- #include <stdexcept>
17
- #include <string_view>
18
- #include <unordered_map>
19
- #include <unordered_set>
20
- #include <vector>
21
-
22
- namespace low_level_log_implementation {
23
-
24
- class TestValue : public ::testing::Test {
25
- protected:
26
- TestValue() { oss << std::boolalpha; }
27
-
28
- template <typename T>
29
- void testValueViaComparison(const T &value,
30
- const std::string &expectedOutput) {
31
- printElement(oss, value);
32
- EXPECT_EQ(oss.str(), expectedOutput);
33
- clearOutputStream(oss);
34
- }
35
-
36
- template <typename T>
37
- void testValueViaRegex(const T &value, const std::string &expectedPattern) {
38
- printElement(oss, value);
39
- const std::regex pattern(expectedPattern);
40
- EXPECT_TRUE(std::regex_search(oss.str(), pattern))
41
- << "Expected pattern not found: " << expectedPattern;
42
- clearOutputStream(oss);
43
- }
44
-
45
- void setOutputStreamPresicion(int precision) noexcept {
46
- oss << std::fixed << std::setprecision(precision);
47
- }
48
-
49
- private:
50
- std::ostringstream oss;
51
- void clearOutputStream(std::ostringstream &os) noexcept {
52
- oss.str("");
53
- oss.clear();
54
- }
55
- };
56
-
57
- class DirectStreamableElementsPrintTest : public TestValue {};
58
-
59
- class ContainerPrintTest : public TestValue {};
60
-
61
- class NestedContainerPrintTest : public TestValue {};
62
-
63
- class EgdeCasesPrintTest : public TestValue {};
64
-
65
- class SmartPointerPrintTest : public TestValue {};
66
-
67
- class OptionalPrintTest : public TestValue {};
68
-
69
- class VariantPrintTest : public TestValue {};
70
-
71
- class ErrorHandlingPrintTest : public TestValue {};
72
-
73
- class FileSystemPrintTest : public TestValue {};
74
-
75
- class UnsupportedLoggingTest : public ::testing::Test {};
76
-
77
- class Point final {
78
- public:
79
- explicit constexpr Point(int x, int y) noexcept : x(x), y(y) {}
80
-
81
- // Overloading the << operator to make Point directly streamable
82
- friend std::ostream &operator<<(std::ostream &os, const Point &pt) noexcept {
83
- os << "Point(" << pt.x << ", " << pt.y << ")";
84
- return os;
85
- }
86
-
87
- private:
88
- int x, y;
89
- };
90
-
91
- TEST_F(DirectStreamableElementsPrintTest, HandlesIntegers) {
92
- testValueViaComparison(123, "123");
93
- }
94
-
95
- TEST_F(DirectStreamableElementsPrintTest, HandlesStrings) {
96
- testValueViaComparison(std::string("Hello World"), "Hello World");
97
- }
98
-
99
- TEST_F(DirectStreamableElementsPrintTest, HandlesStringViews) {
100
- testValueViaComparison(std::string_view("Hello World"), "Hello World");
101
- }
102
-
103
- TEST_F(DirectStreamableElementsPrintTest, HandlesDoubles) {
104
- constexpr double roughlyPi = 3.14159;
105
- testValueViaComparison(roughlyPi, "3.14159");
106
- }
107
-
108
- TEST_F(DirectStreamableElementsPrintTest, HandlesBooleans) {
109
- testValueViaComparison(true, "true");
110
- testValueViaComparison(false, "false");
111
- }
112
-
113
- TEST_F(DirectStreamableElementsPrintTest, HandlesChar) {
114
- testValueViaComparison('a', "a");
115
- }
116
-
117
- TEST_F(DirectStreamableElementsPrintTest, HandlesCharPointer) {
118
- const char *word = "Hello World";
119
- testValueViaComparison(word, "Hello World");
120
- }
121
-
122
- TEST_F(DirectStreamableElementsPrintTest, HandlesComplexNumbers) {
123
- using namespace std::complex_literals;
124
- constexpr int presision = 1;
125
- setOutputStreamPresicion(presision);
126
- const std::complex<double> complexNumber = std::pow(1i, 2);
127
- testValueViaComparison(complexNumber, "(-1.0,0.0)");
128
- }
129
-
130
- TEST_F(DirectStreamableElementsPrintTest, HandlesPoint) {
131
- constexpr Point point(3, 4);
132
- testValueViaComparison(point, "Point(3, 4)");
133
- }
134
-
135
- // log handles operator<<(&ostream) for std::pair
136
- TEST_F(DirectStreamableElementsPrintTest, HandlesStdPair) {
137
- std::pair<int, std::string> pairOfIntAndString = {42, "Hello"};
138
- testValueViaComparison(pairOfIntAndString, "(42, Hello)");
139
-
140
- // Testing edge cases with pairs
141
- const std::pair<std::string, std::string> emptyPair = {"", ""};
142
- testValueViaComparison(emptyPair, "(, )");
143
- }
144
-
145
- TEST_F(DirectStreamableElementsPrintTest, handlesStaticArrayOfChars) {
146
- constexpr char staticCharArray[] = "prompt tokens:";
147
- testValueViaComparison(staticCharArray, "prompt tokens:");
148
- }
149
-
150
- // log handles operator<<(&ostream) for std::tuple
151
- TEST_F(DirectStreamableElementsPrintTest, HandlesStdTuple) {
152
- const std::tuple<int, std::string, double> tupleOfDifferentTypes = {
153
- 42, "Tuple", 3.14};
154
- testValueViaComparison(tupleOfDifferentTypes, "<42, Tuple, 3.14>");
155
-
156
- // All empty or zero-initialized elements of tuple
157
- const std::tuple<std::string, int, float> zeroInitializedTuple = {"", 0,
158
- 0.0f};
159
- testValueViaComparison(zeroInitializedTuple, "<, 0, 0>");
160
-
161
- // Nested tuple
162
- const std::tuple<int, std::pair<std::string, bool>, float> nestedTuple = {
163
- 1, {"nested", true}, 2.5};
164
- testValueViaComparison(nestedTuple, "<1, (nested, true), 2.5>");
165
- }
166
-
167
- TEST_F(ContainerPrintTest, VectorIntTest) {
168
- const std::vector<int> vectorOfInts = {1, 2, 3, 4};
169
- testValueViaComparison(vectorOfInts, "[1, 2, 3, 4]");
170
- }
171
-
172
- TEST_F(ContainerPrintTest, ListDoubleTest) {
173
- const std::list<double> listOfDoubles = {1.1, 2.2, 3.3};
174
- testValueViaComparison(listOfDoubles, "[1.1, 2.2, 3.3]");
175
- }
176
-
177
- TEST_F(ContainerPrintTest, DequeStringTest) {
178
- const std::deque<std::string> dequeOfStrings = {"hello", "world"};
179
- testValueViaComparison(dequeOfStrings, "[hello, world]");
180
- }
181
-
182
- TEST_F(ContainerPrintTest, SetTest) {
183
- const std::set<std::string> setOfStrings = {"apple", "banana", "cherry"};
184
- testValueViaComparison(setOfStrings,
185
- "[apple, banana, cherry]"); // Note: Sets are sorted
186
- }
187
-
188
- TEST_F(ContainerPrintTest, MapTest) {
189
- const std::map<std::string, int> mapStringToInt = {{"one", 1}, {"two", 2}};
190
- testValueViaComparison(mapStringToInt, "[(one, 1), (two, 2)]");
191
- }
192
-
193
- TEST_F(ContainerPrintTest, HandlesUnorderedSet) {
194
- const std::unordered_set<int> unorderedSetOfInts = {4, 3, 2, 1};
195
- // Pattern expects to find each element at least once in any order
196
- testValueViaRegex(unorderedSetOfInts, R"(.*1.*2.*3.*4.*)");
197
- }
198
-
199
- TEST_F(ContainerPrintTest, HandlesUnorderedMultimap) {
200
- const std::unordered_multimap<std::string, int> unorderedMultimapStringToInt =
201
- {{"one", 1}, {"one", 2}, {"two", 2}};
202
- std::string pattern = R"(\[\s*)";
203
- // construct regex by adding each permutation
204
- pattern += R"((?:\(one, 1\),\s*\(one, 2\),\s*\(two, 2\)|)";
205
- pattern += R"(\(one, 1\),\s*\(two, 2\),\s*\(one, 2\)|)";
206
- pattern += R"(\(one, 2\),\s*\(one, 1\),\s*\(two, 2\)|)";
207
- pattern += R"(\(one, 2\),\s*\(two, 2\),\s*\(one, 1\)|)";
208
- pattern += R"(\(two, 2\),\s*\(one, 1\),\s*\(one, 2\)|)";
209
- pattern += R"(\(two, 2\),\s*\(one, 2\),\s*\(one, 1\))\s*\])";
210
-
211
- testValueViaRegex(unorderedMultimapStringToInt, pattern);
212
- }
213
-
214
- TEST_F(ContainerPrintTest, StackTest) {
215
- std::stack<int> stackOfInts;
216
- stackOfInts.push(1);
217
- stackOfInts.push(2);
218
- stackOfInts.push(3);
219
- testValueViaComparison(stackOfInts, "[3, 2, 1]"); // LIFO order
220
- }
221
-
222
- TEST_F(ContainerPrintTest, QueueTest) {
223
- std::queue<int> queueOfInts;
224
- queueOfInts.push(1);
225
- queueOfInts.push(2);
226
- queueOfInts.push(3);
227
- testValueViaComparison(queueOfInts, "[1, 2, 3]"); // FIFO order
228
- }
229
-
230
- TEST_F(ContainerPrintTest, PriorityQueueTest) {
231
- std::priority_queue<int> priorityQueueOfInts;
232
- priorityQueueOfInts.push(3);
233
- priorityQueueOfInts.push(1);
234
- priorityQueueOfInts.push(2);
235
- testValueViaComparison(priorityQueueOfInts,
236
- "[3, 2, 1]"); // Output based on internal max-heap
237
- }
238
-
239
- TEST_F(ContainerPrintTest, HandlesArray) {
240
- constexpr std::array<int, 3> arrayOfInts = {1, 2, 3};
241
- testValueViaComparison(arrayOfInts, "[1, 2, 3]");
242
- }
243
-
244
- TEST_F(ContainerPrintTest, HandlesForwardList) {
245
- const std::forward_list<int> forwardListOfInts = {1, 2, 3};
246
- testValueViaComparison(forwardListOfInts, "[1, 2, 3]");
247
- }
248
-
249
- TEST_F(ContainerPrintTest, HandlesMultiset) {
250
- const std::multiset<int> multisetOfInts = {3, 2, 1, 2};
251
- testValueViaComparison(multisetOfInts,
252
- "[1, 2, 2, 3]"); // Multiset elements are sorted
253
- }
254
-
255
- TEST_F(ContainerPrintTest, HandlesMultimap) {
256
- const std::multimap<std::string, int> multimapStringToInt = {
257
- {"one", 1}, {"one", 2}, {"two", 2}};
258
- testValueViaComparison(multimapStringToInt, "[(one, 1), (one, 2), (two, 2)]");
259
- }
260
-
261
- TEST_F(ContainerPrintTest, HandlesSpan) {
262
- std::vector<int> vectorOfInts = {1, 2, 3, 4};
263
- const std::span<int> spanOnVector(
264
- vectorOfInts.begin(), vectorOfInts.end()); // Create a span from a vector
265
- testValueViaComparison(spanOnVector, "[1, 2, 3, 4]");
266
- }
267
-
268
- TEST_F(ContainerPrintTest, HandlesStaticArray) {
269
- constexpr int staticArray[] = {1, 2, 3, 4, 5};
270
- testValueViaComparison(staticArray, "[1, 2, 3, 4, 5]");
271
- }
272
-
273
- TEST_F(NestedContainerPrintTest, HandlesListOfQueuesOfPoints) {
274
- std::list<std::queue<Point>> listOfQueues = {std::queue<Point>()};
275
- listOfQueues.front().push(Point(1, 1));
276
- listOfQueues.front().push(Point(2, 2));
277
- listOfQueues.front().push(Point(3, 3));
278
- testValueViaComparison(listOfQueues,
279
- "[[Point(1, 1), Point(2, 2), Point(3, 3)]]");
280
- }
281
-
282
- TEST_F(NestedContainerPrintTest, HandlesNestedVectors) {
283
- const std::vector<std::vector<int>> nestedVector = {{1, 2}, {3, 4, 5}};
284
- testValueViaComparison(nestedVector, "[[1, 2], [3, 4, 5]]");
285
- }
286
-
287
- TEST_F(NestedContainerPrintTest, HandlesMapOfVectorOfPoints) {
288
- const std::map<std::string, std::vector<Point>> mapOfVectors = {
289
- {"first", {Point(1, 2)}}, {"second", {Point(3, 4), Point(5, 6)}}};
290
- testValueViaComparison(
291
- mapOfVectors,
292
- "[(first, [Point(1, 2)]), (second, [Point(3, 4), Point(5, 6)])]");
293
- }
294
-
295
- TEST_F(NestedContainerPrintTest, HandlesVectorOfMaps) {
296
- const std::vector<std::map<std::string, int>> vectorOfMaps = {
297
- {{"one", 1}, {"two", 2}}, {{"three", 3}, {"four", 4}}};
298
- // word "three" is lexicographically smaller than "four"
299
- testValueViaComparison(vectorOfMaps,
300
- "[[(one, 1), (two, 2)], [(four, 4), (three, 3)]]");
301
- }
302
-
303
- TEST_F(NestedContainerPrintTest, HandlesComplexNestedStructures) {
304
- const std::vector<std::map<std::string, std::list<std::set<int>>>>
305
- complexNested = {{{"first", {{1, 2}, {3}}}, {"second", {{4}}}}};
306
- testValueViaComparison(complexNested,
307
- "[[(first, [[1, 2], [3]]), (second, [[4]])]]");
308
- }
309
-
310
- TEST_F(EgdeCasesPrintTest, HandleEmptyContainer) {
311
- const std::vector<int> emptyVector{};
312
- testValueViaComparison(emptyVector, "[]");
313
- }
314
-
315
- TEST_F(SmartPointerPrintTest, HandlesSharedPtr) {
316
- const auto sharedPointer = std::make_shared<int>(10);
317
- testValueViaComparison(sharedPointer, "10");
318
- }
319
-
320
- TEST_F(SmartPointerPrintTest, HandlesWeakPtr) {
321
- auto sharedPointer = std::make_shared<int>(20);
322
- std::weak_ptr<int> weakPointer = sharedPointer;
323
- testValueViaComparison(weakPointer, "20");
324
-
325
- sharedPointer.reset(); // Reset shared_ptr to make the weak_ptr expire
326
- testValueViaComparison(weakPointer,
327
- "expired"); // Test after the weak pointer has expired
328
- }
329
-
330
- TEST_F(SmartPointerPrintTest, HandlesUniquePtr) {
331
- const auto uniquePointer = std::make_unique<int>(30);
332
- testValueViaComparison(uniquePointer, "30");
333
- }
334
-
335
- TEST_F(OptionalPrintTest, HandlesOptional) {
336
- std::optional<int> optionalInt{40};
337
- testValueViaComparison(optionalInt, "Optional(40)");
338
- optionalInt.reset();
339
- testValueViaComparison(optionalInt, "nullopt");
340
- }
341
-
342
- TEST_F(VariantPrintTest, HandlesVariant) {
343
- std::variant<int, std::string> variantIntOrString = 10;
344
- testValueViaComparison(variantIntOrString, "Variant(10)");
345
- variantIntOrString = "Hello";
346
- testValueViaComparison(variantIntOrString, "Variant(Hello)");
347
- }
348
-
349
- TEST_F(ErrorHandlingPrintTest, HandlesErrorCode) {
350
- const auto errorCodeValue =
351
- std::make_error_code(std::errc::function_not_supported).value();
352
- const std::error_code errorCode =
353
- make_error_code(std::errc::function_not_supported);
354
- testValueViaComparison(
355
- errorCode, "ErrorCode(" + std::to_string(errorCodeValue) + ", generic)");
356
- }
357
-
358
- TEST_F(ErrorHandlingPrintTest, HandlesExceptionPtr) {
359
- try {
360
- throw std::runtime_error("test error");
361
- } catch (...) {
362
- const std::exception_ptr exceptionPointer = std::current_exception();
363
- testValueViaComparison(exceptionPointer, "ExceptionPtr(\"test error\")");
364
- }
365
- }
366
-
367
- TEST_F(FileSystemPrintTest, HandlesPath) {
368
- const std::filesystem::path filePath = "/path/to/some/file.txt";
369
- testValueViaComparison(filePath, "Path(\"/path/to/some/file.txt\")");
370
- }
371
-
372
- TEST_F(FileSystemPrintTest, HandlesDirectoryIterator) {
373
- // Setup a temporary directory and files within
374
- std::filesystem::path directory =
375
- std::filesystem::temp_directory_path() / "test_dir";
376
- std::filesystem::create_directory(directory);
377
-
378
- std::ofstream(directory / "file1.txt");
379
- std::ofstream(directory / "file2.txt");
380
-
381
- std::filesystem::directory_iterator begin(directory);
382
-
383
- testValueViaRegex(
384
- begin,
385
- R"(Directory\["file1.txt", "file2.txt"\]|Directory\["file2.txt", "file1.txt"\])");
386
-
387
- // Cleanup
388
- std::filesystem::remove_all(directory);
389
- }
390
-
391
- TEST_F(UnsupportedLoggingTest, TestLoggingUnsupportedType) {
392
- std::ostringstream oss;
393
- class UnsupportedClass {};
394
- const auto x = UnsupportedClass();
395
-
396
- ASSERT_THROW({ printElement(oss, x); }, std::runtime_error);
397
- }
398
-
399
- } // namespace low_level_log_implementation
400
-
401
- namespace rnexecutorch {
402
-
403
- namespace high_level_log_implementation {
404
-
405
- class BufferTest : public ::testing::Test {
406
- protected:
407
- // Helper to validate the final output
408
- void validateBuffer(const std::string &result, const std::string &expected,
409
- std::size_t expectedSize) {
410
- EXPECT_EQ(result, expected);
411
- EXPECT_EQ(result.size(), expectedSize);
412
- if (result.size() > expected.size()) {
413
- EXPECT_EQ(result.substr(expected.size()), "...");
414
- }
415
- }
416
- };
417
-
418
- TEST_F(BufferTest, MessageShorterThanLimit) {
419
- constexpr std::size_t smallLogLimit = 20;
420
- const std::string message = "Short message";
421
- auto result = getBuffer(message, smallLogLimit);
422
- validateBuffer(result, message, message.size());
423
- }
424
-
425
- TEST_F(BufferTest, MessageExactlyAtLimit) {
426
- // Creating a string with 1024 'a' characters
427
- constexpr std::size_t defaultLogLimit = 1024;
428
- const std::string message(defaultLogLimit, 'a');
429
- auto result = getBuffer(message, defaultLogLimit);
430
- validateBuffer(result, message, message.size());
431
- }
432
-
433
- TEST_F(BufferTest, MessageLongerThanLimit) {
434
- constexpr std::size_t defaultLogLimit = 1024;
435
- constexpr std::size_t sizeAboveLimit = 1050;
436
- // Creating a string longer than the limit
437
- const std::string message(sizeAboveLimit, 'a');
438
- const auto expected = std::string(defaultLogLimit, 'a') + "...";
439
- const auto result = getBuffer(message, defaultLogLimit);
440
- validateBuffer(result, expected, expected.size());
441
- }
442
-
443
- } // namespace high_level_log_implementation
444
-
445
- class LoggingTest : public ::testing::Test {
446
- protected:
447
- template <typename T>
448
- void testLoggingDoesNotChangeContainer(const T &original) {
449
- const auto copy = original; // Make a copy of the container
450
- log(LOG_LEVEL::Info, original);
451
- ASSERT_TRUE(check_if_same_content(original, copy))
452
- << "Logging modified the content of the container.";
453
- }
454
-
455
- private:
456
- // == op for smart pointers compare addresses, check content maunally
457
- template <typename T>
458
- bool check_if_same_content(const std::shared_ptr<T> &a,
459
- const std::shared_ptr<T> &b) const noexcept {
460
- if (!a || !b) {
461
- return a == b;
462
- }
463
- return *a == *b;
464
- }
465
-
466
- template <typename T>
467
- bool check_if_same_content(const T &original, const T &after) const noexcept {
468
- // Requires that T has an equality operator (operator==)
469
- return original == after;
470
- }
471
- };
472
-
473
- TEST_F(LoggingTest, LoggingDoesNotChangeSharedPtr) {
474
- const auto original = std::make_shared<int>(42);
475
- testLoggingDoesNotChangeContainer(original);
476
- }
477
-
478
- TEST_F(LoggingTest, LoggingDoesNotChangeQueue) {
479
- std::queue<int> original;
480
- original.push(1);
481
- original.push(2);
482
- original.push(3);
483
- testLoggingDoesNotChangeContainer(original);
484
- }
485
-
486
- TEST_F(LoggingTest, LoggingDoesNotChangeVector) {
487
- const std::vector<int> original = {1, 2, 3, 4, 5};
488
- testLoggingDoesNotChangeContainer(original);
489
- }
490
-
491
- TEST(LogFunctionTest, LoggingBasic) {
492
- EXPECT_NO_THROW(log(LOG_LEVEL::Debug, "Test123"));
493
- }
494
-
495
- TEST(LogFunctionTest, LoggingWithNonDefaultLogSize) {
496
- constexpr std::size_t sizeBiggerThanDefault = 2048;
497
- const auto testString = std::string(sizeBiggerThanDefault, 'a');
498
- EXPECT_NO_THROW(log<sizeBiggerThanDefault>(LOG_LEVEL::Info, testString));
499
- }
500
-
501
- TEST(LogFunctionTest, LoggingMoreThanOneElement) {
502
- constexpr auto testStringLiteral = "Test123";
503
- const auto testVector = std::vector<int>{1, 2, 3, 4};
504
- const auto testPair = std::pair<int, double>(1, 2.0);
505
- EXPECT_NO_THROW(
506
- log(LOG_LEVEL::Debug, testStringLiteral, testVector, testPair));
507
- }
508
-
509
- TEST(MovingSequencable, MovingSequencableTest) {
510
- std::priority_queue<int> q;
511
- q.push(1);
512
- q.push(2);
513
- q.push(3);
514
-
515
- log(LOG_LEVEL::Debug, q);
516
- ASSERT_EQ(q.size(), 3);
517
- const auto &cq = q;
518
- log(LOG_LEVEL::Debug, cq);
519
- ASSERT_EQ(cq.size(), 3);
520
- log(LOG_LEVEL::Debug, std::move(q));
521
- ASSERT_EQ(q.size(), 0);
522
- }
523
-
524
- } // namespace rnexecutorch
525
-
526
- int main(int argc, char **argv) {
527
- ::testing::InitGoogleTest(&argc, argv);
528
- return RUN_ALL_TESTS();
529
- }
@@ -1,107 +0,0 @@
1
#include "../data_processing/Numerical.h"

#include <cmath>
#include <cstdint>
#include <limits>
#include <span>
#include <vector>

#include <gtest/gtest.h>
#include <rnexecutorch/Error.h>
7
-
8
- namespace rnexecutorch::numerical {
9
-
10
- // Helper function to check if two float vectors are approximately equal
11
- void expect_vectors_eq(const std::vector<float> &vector1,
12
- const std::vector<float> &vector2,
13
- float atol = 1.0e-6F) {
14
- ASSERT_EQ(vector1.size(), vector2.size());
15
- for (size_t i = 0; i < vector1.size(); i++) {
16
- EXPECT_NEAR(vector1[i], vector2[i], atol);
17
- }
18
- }
19
-
20
- TEST(SoftmaxTests, SoftmaxBasic) {
21
- std::vector<float> input = {1.0F, 2.0F, 3.0F};
22
- softmax(input);
23
- const std::vector<float> expected = {0.09003057F, 0.24472847F, 0.66524095F};
24
- expect_vectors_eq(input, expected);
25
- }
26
-
27
- TEST(SoftmaxTests, SoftmaxWithBigValues) {
28
- std::vector<float> input = {100000.0F, 100000.0F, 100000.0F};
29
- softmax(input);
30
- const std::vector<float> expected = {0.3333333F, 0.3333333F, 0.3333333F};
31
- expect_vectors_eq(input, expected);
32
- }
33
-
34
- TEST(SoftmaxTests, SoftmaxOfEmptyVector) {
35
- std::vector<float> emptyVector{};
36
- EXPECT_NO_THROW(softmax(emptyVector));
37
- }
38
-
39
- TEST(NormalizeTests, NormalizeBasic) {
40
- std::vector<float> input = {1.0F, 2.0F, 3.0F};
41
- normalize(input);
42
- const auto normOfInput = std::sqrtf(14.0F);
43
- const std::vector<float> expected = {1.0F / normOfInput, 2.0F / normOfInput,
44
- 3.0F / normOfInput};
45
- expect_vectors_eq(input, expected);
46
- }
47
-
48
- TEST(NormalizeTests, NormalizationOfExtremelySmallValues) {
49
- constexpr auto epsilon = std::numeric_limits<float>::epsilon();
50
- std::vector<float> input(3, epsilon);
51
- const auto normOfInput = std::sqrtf(3.0F);
52
- const std::vector<float> expected(3, 1.0F / normOfInput);
53
- normalize(input);
54
- expect_vectors_eq(input, expected);
55
- }
56
-
57
- TEST(NormalizeTests, NormalizationOfZeroVector) {
58
- std::vector<float> zeroVector(3, 0.0F);
59
- EXPECT_NO_THROW(normalize(zeroVector));
60
- }
61
-
62
- TEST(NormalizeTests, NormalizationOfEmptyVector) {
63
- std::vector<float> emptyVector{};
64
- EXPECT_NO_THROW(normalize(emptyVector));
65
- }
66
-
67
- TEST(MeanPoolingTests, MeanPoolingBasic) {
68
- const std::vector<float> modelOutputVec = {1.0F, 2.0F, 3.0F,
69
- 4.0F, 5.0F, 6.0F};
70
- const std::vector<int64_t> attnMaskVec = {1, 1, 0};
71
-
72
- std::span<const float> modelOutput(modelOutputVec);
73
- std::span<const int64_t> attnMask(attnMaskVec);
74
-
75
- const auto result = meanPooling(modelOutput, attnMask);
76
- const std::vector<float> expected = {2.0F, 3.0F};
77
- expect_vectors_eq(result, expected);
78
- }
79
-
80
- TEST(MeanPoolingTests, MeanPoolingWithZeroAttentionMask) {
81
- const std::vector<float> modelOutputVec = {1.0F, 2.0F, 3.0F,
82
- 4.0F, 5.0F, 6.0F};
83
- const std::vector<int64_t> attnMaskVec = {0, 0, 0};
84
-
85
- std::span<const float> modelOutput(modelOutputVec);
86
- std::span<const int64_t> attnMask(attnMaskVec);
87
-
88
- const auto result = meanPooling(modelOutput, attnMask);
89
- const std::vector<float> expected = {0.0F, 0.0F};
90
- expect_vectors_eq(result, expected);
91
- }
92
-
93
- TEST(MeanPoolingTests, InvalidDimensionSize) {
94
- const std::vector<float> modelOutput = {1.0F, 2.0F, 3.0F, 4.0F};
95
- const std::vector<int64_t> attnMask = {1, 1, 1};
96
-
97
- EXPECT_THROW({ meanPooling(modelOutput, attnMask); }, RnExecutorchError);
98
- }
99
-
100
- TEST(MeanPoolingTests, EmptyAttentionMask) {
101
- const std::vector<float> modelOutput = {1.0F, 2.0F, 3.0F, 4.0F};
102
- const std::vector<int64_t> attnMask = {};
103
-
104
- EXPECT_THROW({ meanPooling(modelOutput, attnMask); }, RnExecutorchError);
105
- }
106
-
107
- } // namespace rnexecutorch::numerical