react-native-litert-lm 0.3.7 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. package/README.md +153 -135
  2. package/android/build.gradle +12 -0
  3. package/android/src/main/AndroidManifest.xml +5 -0
  4. package/android/src/main/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLM.kt +159 -62
  5. package/android/src/main/java/dev/litert/litertlm/LiteRTLMPackage.kt +19 -2
  6. package/android/src/test/java/com/margelo/nitro/core/Promise.kt +46 -0
  7. package/android/src/test/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMTest.kt +83 -0
  8. package/ios/HybridLiteRTLM.swift +1058 -0
  9. package/ios/Tests/HybridLiteRTLMTests.swift +67 -0
  10. package/lib/__mocks__/react-native-nitro-modules.d.ts +61 -0
  11. package/lib/__mocks__/react-native-nitro-modules.js +50 -0
  12. package/lib/__tests__/hooks.test.d.ts +1 -0
  13. package/lib/__tests__/hooks.test.js +124 -0
  14. package/lib/__tests__/memoryTracker.test.d.ts +1 -0
  15. package/lib/__tests__/memoryTracker.test.js +74 -0
  16. package/lib/__tests__/modelFactory.test.d.ts +1 -0
  17. package/lib/__tests__/modelFactory.test.js +52 -0
  18. package/lib/hooks.js +1 -1
  19. package/lib/index.d.ts +0 -2
  20. package/lib/index.js +1 -5
  21. package/lib/modelFactory.js +62 -63
  22. package/lib/specs/LiteRTLM.nitro.d.ts +71 -2
  23. package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.cpp +62 -7
  24. package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.hpp +3 -1
  25. package/nitrogen/generated/android/c++/JLLMConfig.hpp +40 -3
  26. package/nitrogen/generated/android/c++/JMultimodalPart.hpp +74 -0
  27. package/nitrogen/generated/android/c++/JPartType.hpp +61 -0
  28. package/nitrogen/generated/android/c++/JToolDefinition.hpp +65 -0
  29. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/GenerationStats.kt +23 -0
  30. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMSpec.kt +10 -2
  31. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/LLMConfig.kt +46 -3
  32. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MemoryUsage.kt +19 -0
  33. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/Message.kt +15 -0
  34. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MultimodalPart.kt +66 -0
  35. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/PartType.kt +24 -0
  36. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/ToolDefinition.kt +61 -0
  37. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.cpp +57 -1
  38. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.hpp +414 -3
  39. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Umbrella.hpp +41 -3
  40. package/nitrogen/generated/ios/LiteRTLMAutolinking.mm +4 -6
  41. package/nitrogen/generated/ios/LiteRTLMAutolinking.swift +10 -0
  42. package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.cpp +11 -0
  43. package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.hpp +224 -0
  44. package/nitrogen/generated/ios/swift/Backend.swift +44 -0
  45. package/nitrogen/generated/ios/swift/Func_void.swift +46 -0
  46. package/nitrogen/generated/ios/swift/Func_void_double.swift +46 -0
  47. package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +46 -0
  48. package/nitrogen/generated/ios/swift/Func_void_std__string.swift +46 -0
  49. package/nitrogen/generated/ios/swift/Func_void_std__string_bool.swift +46 -0
  50. package/nitrogen/generated/ios/swift/GenerationStats.swift +54 -0
  51. package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec.swift +69 -0
  52. package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec_cxx.swift +383 -0
  53. package/nitrogen/generated/ios/swift/LLMConfig.swift +203 -0
  54. package/nitrogen/generated/ios/swift/MemoryUsage.swift +44 -0
  55. package/nitrogen/generated/ios/swift/Message.swift +34 -0
  56. package/nitrogen/generated/ios/swift/MultimodalPart.swift +83 -0
  57. package/nitrogen/generated/ios/swift/PartType.swift +44 -0
  58. package/nitrogen/generated/ios/swift/Role.swift +44 -0
  59. package/nitrogen/generated/ios/swift/ToolDefinition.swift +39 -0
  60. package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.cpp +2 -0
  61. package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.hpp +7 -2
  62. package/nitrogen/generated/shared/c++/LLMConfig.hpp +22 -2
  63. package/nitrogen/generated/shared/c++/MultimodalPart.hpp +99 -0
  64. package/nitrogen/generated/shared/c++/PartType.hpp +80 -0
  65. package/nitrogen/generated/shared/c++/ToolDefinition.hpp +91 -0
  66. package/package.json +16 -8
  67. package/react-native-litert-lm.podspec +15 -19
  68. package/scripts/download-ios-frameworks.sh +14 -48
  69. package/scripts/postinstall.js +1 -2
  70. package/src/__mocks__/react-native-nitro-modules.ts +48 -0
  71. package/src/__tests__/hooks.test.ts +153 -0
  72. package/src/__tests__/memoryTracker.test.ts +87 -0
  73. package/src/__tests__/modelFactory.test.ts +68 -0
  74. package/src/hooks.ts +1 -1
  75. package/src/index.ts +0 -7
  76. package/src/modelFactory.ts +82 -80
  77. package/src/specs/LiteRTLM.nitro.ts +80 -2
  78. package/cpp/HybridLiteRTLM.cpp +0 -939
  79. package/cpp/HybridLiteRTLM.hpp +0 -169
  80. package/cpp/IOSDownloadHelper.h +0 -24
  81. package/ios/IOSDownloadHelper.mm +0 -129
  82. package/scripts/build-ios-engine.sh +0 -302
  83. package/scripts/stubs/cxx_bridge_stubs.cc +0 -224
  84. package/scripts/stubs/gemma_model_constraint_provider.cc +0 -46
  85. package/scripts/stubs/llguidance_stubs.c +0 -101
  86. package/src/templates.ts +0 -105
@@ -14,6 +14,34 @@ export type Backend = "cpu" | "gpu" | "npu";
14
14
  * Message roles for conversation.
15
15
  */
16
16
  export type Role = "user" | "model" | "system";
17
+ /**
18
+ * Definition for a function/tool that the model can request to execute. Supplied to the engine through `LLMConfig.tools`; all three fields are required strings.
19
+ */
20
+ export interface ToolDefinition {
21
+ /** Name of the function/tool (NOTE(review): presumably the identifier the model emits when requesting a call — confirm whether it must be unique) */
22
+ name: string;
23
+ /** Human-readable description of what the function/tool does (assumed to be shown to the model to guide tool selection — verify against the native side) */
24
+ description: string;
25
+ /** JSON Schema describing the parameter names and types, serialized to a string (e.g. with `JSON.stringify`); it crosses the native bridge as a plain string */
26
+ parametersJson: string;
27
+ }
28
+ /**
29
+ * The part type for a multimodal message content part. Indicates which `MultimodalPart` payload field is expected to be set: `text`, `imageBuffer`, or `audioBuffer`.
30
+ */
31
+ export type PartType = "text" | "image" | "audio";
32
+ /**
33
+ * A part of a unified multimodal message payload, as accepted by `LiteRTLM.sendMultimodalMessage`. Set the payload field that matches `type`; NOTE(review): how the native side treats extra or missing payload fields is not visible here — confirm.
34
+ */
35
+ export interface MultimodalPart {
36
+ /** The part type: 'text', 'image', or 'audio' — indicates which payload field below is expected to be set */
37
+ type: PartType;
38
+ /** The plain text content, if type is 'text' */
39
+ text?: string;
40
+ /** Raw image binary data, if type is 'image' (zero-copy ArrayBuffer mapping) */
41
+ imageBuffer?: ArrayBuffer;
42
+ /** Raw audio binary data, if type is 'audio' (zero-copy ArrayBuffer mapping) */
43
+ audioBuffer?: ArrayBuffer;
44
+ }
17
45
  /**
18
46
  * Configuration options for loading an LLM.
19
47
  */
@@ -60,6 +88,37 @@ export interface LLMConfig {
60
88
  * @default 0.95
61
89
  */
62
90
  topP?: number;
91
+ /**
92
+ * Whether to run engine validation after loading the model.
93
+ * When enabled, sends a quick test inference ("Hi") and waits up to 30s
94
+ * for a response to confirm the backend works. This is useful for GPU/NPU
95
+ * backends that may silently fail during inference (they can initialize
96
+ * without error but produce no tokens).
97
+ *
98
+ * Validation is **always a no-op on CPU** — the CPU backend is inherently
99
+ * reliable and never needs validation.
100
+ *
101
+ * Disabled by default because it adds significant latency (5-30s) to model loading.
102
+ * Enable only to catch GPU/NPU silent failure issues during development.
103
+ *
104
+ * @default false
105
+ */
106
+ validate?: boolean;
107
+ /**
108
+ * Whether this is a multimodal model.
109
+ * When enabled, the engine handles image/audio tokens properly.
110
+ * If not specified, the system will fall back to filename sniffing.
111
+ */
112
+ multimodal?: boolean;
113
+ /**
114
+ * List of tools/functions that the model can call.
115
+ */
116
+ tools?: ToolDefinition[];
117
+ /**
118
+ * Whether to enable speculative decoding (multi-token prediction) if supported by the model.
119
+ * @default false
120
+ */
121
+ enableSpeculativeDecoding?: boolean;
63
122
  }
64
123
  /**
65
124
  * A simple message in the conversation.
@@ -123,7 +182,7 @@ export interface MemoryUsage {
123
182
  * ```
124
183
  */
125
184
  export interface LiteRTLM extends HybridObject<{
126
- ios: "c++";
185
+ ios: "swift";
127
186
  android: "kotlin";
128
187
  }> {
129
188
  /**
@@ -165,13 +224,19 @@ export interface LiteRTLM extends HybridObject<{
165
224
  * @returns The model's response text.
166
225
  */
167
226
  sendMessageWithAudio(message: string, audioPath: string): Promise<string>;
227
+ /**
228
+ * Send a unified multimodal message containing text and/or zero-copy binary buffers.
229
+ * @param parts The message content parts (text, image, and/or audio).
230
+ * @returns The model's response text.
231
+ */
232
+ sendMultimodalMessage(parts: MultimodalPart[]): Promise<string>;
168
233
  /**
169
234
  * Send a message with streaming response.
170
235
  * Tokens are delivered via callback as they are generated.
171
236
  * @param message User message text.
172
237
  * @param onToken Callback invoked for each token (token, isDone).
173
238
  */
174
- sendMessageAsync(message: string, onToken: (token: string, done: boolean) => void): void;
239
+ sendMessageAsync(message: string, onToken: (token: string, done: boolean) => void): Promise<void>;
175
240
  /**
176
241
  * Get the current conversation history.
177
242
  * @returns Array of messages in the conversation.
@@ -189,6 +254,10 @@ export interface LiteRTLM extends HybridObject<{
189
254
  * Get the last generation statistics.
190
255
  */
191
256
  getStats(): GenerationStats;
257
+ /**
258
+ * Count tokens in a text string. Returns -1 if unavailable.
259
+ */
260
+ countTokens(text: string): number;
192
261
  /**
193
262
  * Get real memory usage from the native runtime.
194
263
  * Uses OS-level APIs to report actual memory consumption.
@@ -19,6 +19,12 @@ namespace margelo::nitro::litertlm { struct MemoryUsage; }
19
19
  namespace margelo::nitro::litertlm { struct LLMConfig; }
20
20
  // Forward declaration of `Backend` to properly resolve imports.
21
21
  namespace margelo::nitro::litertlm { enum class Backend; }
22
+ // Forward declaration of `ToolDefinition` to properly resolve imports.
23
+ namespace margelo::nitro::litertlm { struct ToolDefinition; }
24
+ // Forward declaration of `MultimodalPart` to properly resolve imports.
25
+ namespace margelo::nitro::litertlm { struct MultimodalPart; }
26
+ // Forward declaration of `PartType` to properly resolve imports.
27
+ namespace margelo::nitro::litertlm { enum class PartType; }
22
28
 
23
29
  #include <NitroModules/Promise.hpp>
24
30
  #include <NitroModules/JPromise.hpp>
@@ -38,9 +44,17 @@ namespace margelo::nitro::litertlm { enum class Backend; }
38
44
  #include "JLLMConfig.hpp"
39
45
  #include "Backend.hpp"
40
46
  #include "JBackend.hpp"
47
+ #include "ToolDefinition.hpp"
48
+ #include "JToolDefinition.hpp"
41
49
  #include <functional>
42
50
  #include "JFunc_void_double.hpp"
43
51
  #include <NitroModules/JNICallable.hpp>
52
+ #include "MultimodalPart.hpp"
53
+ #include "JMultimodalPart.hpp"
54
+ #include "PartType.hpp"
55
+ #include "JPartType.hpp"
56
+ #include <NitroModules/ArrayBuffer.hpp>
57
+ #include <NitroModules/JArrayBuffer.hpp>
44
58
  #include "JFunc_void_std__string_bool.hpp"
45
59
 
46
60
  namespace margelo::nitro::litertlm {
@@ -170,23 +184,59 @@ namespace margelo::nitro::litertlm {
170
184
  return __promise;
171
185
  }();
172
186
  }
173
- void JHybridLiteRTLMSpec::sendMessageAsync(const std::string& message, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) {
174
- static const auto method = _javaPart->javaClassStatic()->getMethod<void(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<JFunc_void_std__string_bool::javaobject> /* onToken */)>("sendMessageAsync_cxx");
175
- method(_javaPart, jni::make_jstring(message), JFunc_void_std__string_bool_cxx::fromCpp(onToken));
187
+ std::shared_ptr<Promise<std::string>> JHybridLiteRTLMSpec::sendMultimodalMessage(const std::vector<MultimodalPart>& parts) { // forwards to Kotlin HybridLiteRTLMSpec.sendMultimodalMessage(parts: Array<MultimodalPart>): Promise<String>
188
+ static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JArrayClass<JMultimodalPart>> /* parts */)>("sendMultimodalMessage"); // method ID resolved once (function-local static)
189
+ auto __result = method(_javaPart, [&](auto&& __input) { // immediately-invoked lambda: std::vector<MultimodalPart> -> Java MultimodalPart[]
190
+ size_t __size = __input.size();
191
+ jni::local_ref<jni::JArrayClass<JMultimodalPart>> __array = jni::JArrayClass<JMultimodalPart>::newArray(__size);
192
+ for (size_t __i = 0; __i < __size; __i++) {
193
+ const auto& __element = __input[__i];
194
+ auto __elementJni = JMultimodalPart::fromCpp(__element); // per-element C++ -> Kotlin struct conversion
195
+ __array->setElement(__i, *__elementJni);
196
+ }
197
+ return __array;
198
+ }(parts));
199
+ return [&]() { // adapt the returned Java Promise<String> into a C++ Promise<std::string>
200
+ auto __promise = Promise<std::string>::create();
201
+ __result->cthis()->addOnResolvedListener([=](const jni::alias_ref<jni::JObject>& __boxedResult) {
202
+ auto __result = jni::static_ref_cast<jni::JString>(__boxedResult); // shadows the outer __result; the resolved value is a JString
203
+ __promise->resolve(__result->toStdString());
204
+ });
205
+ __result->cthis()->addOnRejectedListener([=](const jni::alias_ref<jni::JThrowable>& __throwable) { // Java rejection -> JniException -> C++ promise rejection
206
+ jni::JniException __jniError(__throwable);
207
+ __promise->reject(std::make_exception_ptr(__jniError));
208
+ });
209
+ return __promise;
210
+ }();
211
+ }
212
+ std::shared_ptr<Promise<void>> JHybridLiteRTLMSpec::sendMessageAsync(const std::string& message, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) {
213
+ static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<JFunc_void_std__string_bool::javaobject> /* onToken */)>("sendMessageAsync_cxx");
214
+ auto __result = method(_javaPart, jni::make_jstring(message), JFunc_void_std__string_bool_cxx::fromCpp(onToken));
215
+ return [&]() {
216
+ auto __promise = Promise<void>::create();
217
+ __result->cthis()->addOnResolvedListener([=](const jni::alias_ref<jni::JObject>& /* unit */) {
218
+ __promise->resolve();
219
+ });
220
+ __result->cthis()->addOnRejectedListener([=](const jni::alias_ref<jni::JThrowable>& __throwable) {
221
+ jni::JniException __jniError(__throwable);
222
+ __promise->reject(std::make_exception_ptr(__jniError));
223
+ });
224
+ return __promise;
225
+ }();
176
226
  }
177
227
  std::vector<Message> JHybridLiteRTLMSpec::getHistory() {
178
228
  static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<jni::JArrayClass<JMessage>>()>("getHistory");
179
229
  auto __result = method(_javaPart);
180
- return [&]() {
181
- size_t __size = __result->size();
230
+ return [&](auto&& __input) {
231
+ size_t __size = __input->size();
182
232
  std::vector<Message> __vector;
183
233
  __vector.reserve(__size);
184
234
  for (size_t __i = 0; __i < __size; __i++) {
185
- auto __element = __result->getElement(__i);
235
+ auto __element = __input->getElement(__i);
186
236
  __vector.push_back(__element->toCpp());
187
237
  }
188
238
  return __vector;
189
- }();
239
+ }(__result);
190
240
  }
191
241
  void JHybridLiteRTLMSpec::resetConversation() {
192
242
  static const auto method = _javaPart->javaClassStatic()->getMethod<void()>("resetConversation");
@@ -202,6 +252,11 @@ namespace margelo::nitro::litertlm {
202
252
  auto __result = method(_javaPart);
203
253
  return __result->toCpp();
204
254
  }
255
+ double JHybridLiteRTLMSpec::countTokens(const std::string& text) { // forwards to Kotlin countTokens(text: String): Double; the TS spec documents -1 when counting is unavailable
256
+ static const auto method = _javaPart->javaClassStatic()->getMethod<double(jni::alias_ref<jni::JString> /* text */)>("countTokens"); // method ID resolved once (function-local static)
257
+ auto __result = method(_javaPart, jni::make_jstring(text)); // synchronous JNI call; the primitive double is returned unboxed
258
+ return __result;
259
+ }
205
260
  MemoryUsage JHybridLiteRTLMSpec::getMemoryUsage() {
206
261
  static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JMemoryUsage>()>("getMemoryUsage");
207
262
  auto __result = method(_javaPart);
@@ -60,11 +60,13 @@ namespace margelo::nitro::litertlm {
60
60
  std::shared_ptr<Promise<std::string>> downloadModel(const std::string& url, const std::string& fileName, const std::optional<std::function<void(double /* progress */)>>& onProgress) override;
61
61
  std::shared_ptr<Promise<void>> deleteModel(const std::string& fileName) override;
62
62
  std::shared_ptr<Promise<std::string>> sendMessageWithAudio(const std::string& message, const std::string& audioPath) override;
63
- void sendMessageAsync(const std::string& message, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) override;
63
+ std::shared_ptr<Promise<std::string>> sendMultimodalMessage(const std::vector<MultimodalPart>& parts) override;
64
+ std::shared_ptr<Promise<void>> sendMessageAsync(const std::string& message, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) override;
64
65
  std::vector<Message> getHistory() override;
65
66
  void resetConversation() override;
66
67
  bool isReady() override;
67
68
  GenerationStats getStats() override;
69
+ double countTokens(const std::string& text) override;
68
70
  MemoryUsage getMemoryUsage() override;
69
71
  void close() override;
70
72
 
@@ -12,8 +12,11 @@
12
12
 
13
13
  #include "Backend.hpp"
14
14
  #include "JBackend.hpp"
15
+ #include "JToolDefinition.hpp"
16
+ #include "ToolDefinition.hpp"
15
17
  #include <optional>
16
18
  #include <string>
19
+ #include <vector>
17
20
 
18
21
  namespace margelo::nitro::litertlm {
19
22
 
@@ -46,13 +49,34 @@ namespace margelo::nitro::litertlm {
46
49
  jni::local_ref<jni::JDouble> topK = this->getFieldValue(fieldTopK);
47
50
  static const auto fieldTopP = clazz->getField<jni::JDouble>("topP");
48
51
  jni::local_ref<jni::JDouble> topP = this->getFieldValue(fieldTopP);
52
+ static const auto fieldValidate = clazz->getField<jni::JBoolean>("validate");
53
+ jni::local_ref<jni::JBoolean> validate = this->getFieldValue(fieldValidate);
54
+ static const auto fieldMultimodal = clazz->getField<jni::JBoolean>("multimodal");
55
+ jni::local_ref<jni::JBoolean> multimodal = this->getFieldValue(fieldMultimodal);
56
+ static const auto fieldTools = clazz->getField<jni::JArrayClass<JToolDefinition>>("tools");
57
+ jni::local_ref<jni::JArrayClass<JToolDefinition>> tools = this->getFieldValue(fieldTools);
58
+ static const auto fieldEnableSpeculativeDecoding = clazz->getField<jni::JBoolean>("enableSpeculativeDecoding");
59
+ jni::local_ref<jni::JBoolean> enableSpeculativeDecoding = this->getFieldValue(fieldEnableSpeculativeDecoding);
49
60
  return LLMConfig(
50
61
  systemPrompt != nullptr ? std::make_optional(systemPrompt->toStdString()) : std::nullopt,
51
62
  backend != nullptr ? std::make_optional(backend->toCpp()) : std::nullopt,
52
63
  maxTokens != nullptr ? std::make_optional(maxTokens->value()) : std::nullopt,
53
64
  temperature != nullptr ? std::make_optional(temperature->value()) : std::nullopt,
54
65
  topK != nullptr ? std::make_optional(topK->value()) : std::nullopt,
55
- topP != nullptr ? std::make_optional(topP->value()) : std::nullopt
66
+ topP != nullptr ? std::make_optional(topP->value()) : std::nullopt,
67
+ validate != nullptr ? std::make_optional(static_cast<bool>(validate->value())) : std::nullopt,
68
+ multimodal != nullptr ? std::make_optional(static_cast<bool>(multimodal->value())) : std::nullopt,
69
+ tools != nullptr ? std::make_optional([&](auto&& __input) {
70
+ size_t __size = __input->size();
71
+ std::vector<ToolDefinition> __vector;
72
+ __vector.reserve(__size);
73
+ for (size_t __i = 0; __i < __size; __i++) {
74
+ auto __element = __input->getElement(__i);
75
+ __vector.push_back(__element->toCpp());
76
+ }
77
+ return __vector;
78
+ }(tools)) : std::nullopt,
79
+ enableSpeculativeDecoding != nullptr ? std::make_optional(static_cast<bool>(enableSpeculativeDecoding->value())) : std::nullopt
56
80
  );
57
81
  }
58
82
 
@@ -62,7 +86,7 @@ namespace margelo::nitro::litertlm {
62
86
  */
63
87
  [[maybe_unused]]
64
88
  static jni::local_ref<JLLMConfig::javaobject> fromCpp(const LLMConfig& value) {
65
- using JSignature = JLLMConfig(jni::alias_ref<jni::JString>, jni::alias_ref<JBackend>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>);
89
+ using JSignature = JLLMConfig(jni::alias_ref<jni::JString>, jni::alias_ref<JBackend>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JBoolean>, jni::alias_ref<jni::JBoolean>, jni::alias_ref<jni::JArrayClass<JToolDefinition>>, jni::alias_ref<jni::JBoolean>);
66
90
  static const auto clazz = javaClassStatic();
67
91
  static const auto create = clazz->getStaticMethod<JSignature>("fromCpp");
68
92
  return create(
@@ -72,7 +96,20 @@ namespace margelo::nitro::litertlm {
72
96
  value.maxTokens.has_value() ? jni::JDouble::valueOf(value.maxTokens.value()) : nullptr,
73
97
  value.temperature.has_value() ? jni::JDouble::valueOf(value.temperature.value()) : nullptr,
74
98
  value.topK.has_value() ? jni::JDouble::valueOf(value.topK.value()) : nullptr,
75
- value.topP.has_value() ? jni::JDouble::valueOf(value.topP.value()) : nullptr
99
+ value.topP.has_value() ? jni::JDouble::valueOf(value.topP.value()) : nullptr,
100
+ value.validate.has_value() ? jni::JBoolean::valueOf(value.validate.value()) : nullptr,
101
+ value.multimodal.has_value() ? jni::JBoolean::valueOf(value.multimodal.value()) : nullptr,
102
+ value.tools.has_value() ? [&](auto&& __input) {
103
+ size_t __size = __input.size();
104
+ jni::local_ref<jni::JArrayClass<JToolDefinition>> __array = jni::JArrayClass<JToolDefinition>::newArray(__size);
105
+ for (size_t __i = 0; __i < __size; __i++) {
106
+ const auto& __element = __input[__i];
107
+ auto __elementJni = JToolDefinition::fromCpp(__element);
108
+ __array->setElement(__i, *__elementJni);
109
+ }
110
+ return __array;
111
+ }(value.tools.value()) : nullptr,
112
+ value.enableSpeculativeDecoding.has_value() ? jni::JBoolean::valueOf(value.enableSpeculativeDecoding.value()) : nullptr
76
113
  );
77
114
  }
78
115
  };
@@ -0,0 +1,74 @@
1
+ ///
2
+ /// JMultimodalPart.hpp
3
+ /// This file was generated by nitrogen. DO NOT MODIFY THIS FILE.
4
+ /// https://github.com/mrousavy/nitro
5
+ /// Copyright © Marc Rousavy @ Margelo
6
+ ///
7
+
8
+ #pragma once
9
+
10
+ #include <fbjni/fbjni.h>
11
+ #include "MultimodalPart.hpp"
12
+
13
+ #include "JPartType.hpp"
14
+ #include "PartType.hpp"
15
+ #include <NitroModules/ArrayBuffer.hpp>
16
+ #include <NitroModules/JArrayBuffer.hpp>
17
+ #include <optional>
18
+ #include <string>
19
+
20
+ namespace margelo::nitro::litertlm {
21
+
22
+ using namespace facebook;
23
+
24
+ /**
25
+ * The C++ JNI bridge between the C++ struct "MultimodalPart" and the Kotlin data class "MultimodalPart".
26
+ */
27
+ struct JMultimodalPart final: public jni::JavaClass<JMultimodalPart> {
28
+ public:
29
+ static constexpr auto kJavaDescriptor = "Lcom/margelo/nitro/dev/litert/litertlm/MultimodalPart;";
30
+
31
+ public:
32
+ /**
33
+ * Convert this Java/Kotlin-based struct to the C++ struct MultimodalPart by copying all values to C++. `type` is dereferenced without a null check (the Kotlin field is assumed non-nullable); null `text`/`imageBuffer`/`audioBuffer` fields become std::nullopt. NOTE(review): whether getArrayBuffer() shares or copies the underlying bytes is not visible here.
34
+ */
35
+ [[maybe_unused]]
36
+ [[nodiscard]]
37
+ MultimodalPart toCpp() const {
38
+ static const auto clazz = javaClassStatic(); // class and field IDs below are resolved once (function-local statics)
39
+ static const auto fieldType = clazz->getField<JPartType>("type");
40
+ jni::local_ref<JPartType> type = this->getFieldValue(fieldType);
41
+ static const auto fieldText = clazz->getField<jni::JString>("text");
42
+ jni::local_ref<jni::JString> text = this->getFieldValue(fieldText);
43
+ static const auto fieldImageBuffer = clazz->getField<JArrayBuffer::javaobject>("imageBuffer");
44
+ jni::local_ref<JArrayBuffer::javaobject> imageBuffer = this->getFieldValue(fieldImageBuffer);
45
+ static const auto fieldAudioBuffer = clazz->getField<JArrayBuffer::javaobject>("audioBuffer");
46
+ jni::local_ref<JArrayBuffer::javaobject> audioBuffer = this->getFieldValue(fieldAudioBuffer);
47
+ return MultimodalPart(
48
+ type->toCpp(), // no null check: relies on the Kotlin `type` field being non-null
49
+ text != nullptr ? std::make_optional(text->toStdString()) : std::nullopt,
50
+ imageBuffer != nullptr ? std::make_optional(imageBuffer->cthis()->getArrayBuffer()) : std::nullopt,
51
+ audioBuffer != nullptr ? std::make_optional(audioBuffer->cthis()->getArrayBuffer()) : std::nullopt
52
+ );
53
+ }
54
+
55
+ public:
56
+ /**
57
+ * Create a Java/Kotlin-based struct by copying all values from the given C++ struct to Java. Empty optionals are passed as null; ArrayBuffers are wrapped with JArrayBuffer::wrap. Calls the Kotlin companion's static `fromCpp` factory.
58
+ */
59
+ [[maybe_unused]]
60
+ static jni::local_ref<JMultimodalPart::javaobject> fromCpp(const MultimodalPart& value) {
61
+ using JSignature = JMultimodalPart(jni::alias_ref<JPartType>, jni::alias_ref<jni::JString>, jni::alias_ref<JArrayBuffer::javaobject>, jni::alias_ref<JArrayBuffer::javaobject>);
62
+ static const auto clazz = javaClassStatic();
63
+ static const auto create = clazz->getStaticMethod<JSignature>("fromCpp");
64
+ return create(
65
+ clazz,
66
+ JPartType::fromCpp(value.type),
67
+ value.text.has_value() ? jni::make_jstring(value.text.value()) : nullptr,
68
+ value.imageBuffer.has_value() ? JArrayBuffer::wrap(value.imageBuffer.value()) : nullptr,
69
+ value.audioBuffer.has_value() ? JArrayBuffer::wrap(value.audioBuffer.value()) : nullptr
70
+ );
71
+ }
72
+ };
73
+
74
+ } // namespace margelo::nitro::litertlm
@@ -0,0 +1,61 @@
1
+ ///
2
+ /// JPartType.hpp
3
+ /// This file was generated by nitrogen. DO NOT MODIFY THIS FILE.
4
+ /// https://github.com/mrousavy/nitro
5
+ /// Copyright © Marc Rousavy @ Margelo
6
+ ///
7
+
8
+ #pragma once
9
+
10
+ #include <fbjni/fbjni.h>
11
+ #include "PartType.hpp"
12
+
13
+ namespace margelo::nitro::litertlm {
14
+
15
+ using namespace facebook;
16
+
17
+ /**
18
+ * The C++ JNI bridge between the C++ enum "PartType" and the Kotlin enum "PartType".
19
+ */
20
+ struct JPartType final: public jni::JavaClass<JPartType> {
21
+ public:
22
+ static constexpr auto kJavaDescriptor = "Lcom/margelo/nitro/dev/litert/litertlm/PartType;";
23
+
24
+ public:
25
+ /**
26
+ * Convert this Java/Kotlin-based enum to the C++ enum PartType. Reads the Kotlin enum's Int field named "value" and casts it directly; NOTE(review): assumes those values line up with the C++ enumerators — verify against PartType.kt.
27
+ */
28
+ [[maybe_unused]]
29
+ [[nodiscard]]
30
+ PartType toCpp() const {
31
+ static const auto clazz = javaClassStatic();
32
+ static const auto fieldOrdinal = clazz->getField<int>("value");
33
+ int ordinal = this->getFieldValue(fieldOrdinal);
34
+ return static_cast<PartType>(ordinal);
35
+ }
36
+
37
+ public:
38
+ /**
39
+ * Create a Java/Kotlin-based enum with the given C++ enum's value. Returns the matching static enum constant (TEXT, IMAGE or AUDIO); throws std::invalid_argument for any other value. Each constant's field ID is cached in a function-local static on first use.
40
+ */
41
+ [[maybe_unused]]
42
+ static jni::alias_ref<JPartType> fromCpp(PartType value) {
43
+ static const auto clazz = javaClassStatic();
44
+ switch (value) {
45
+ case PartType::TEXT:
46
+ static const auto fieldTEXT = clazz->getStaticField<JPartType>("TEXT");
47
+ return clazz->getStaticFieldValue(fieldTEXT);
48
+ case PartType::IMAGE:
49
+ static const auto fieldIMAGE = clazz->getStaticField<JPartType>("IMAGE");
50
+ return clazz->getStaticFieldValue(fieldIMAGE);
51
+ case PartType::AUDIO:
52
+ static const auto fieldAUDIO = clazz->getStaticField<JPartType>("AUDIO");
53
+ return clazz->getStaticFieldValue(fieldAUDIO);
54
+ default:
55
+ std::string stringValue = std::to_string(static_cast<int>(value));
56
+ throw std::invalid_argument("Invalid enum value (" + stringValue + ")!"); // closing ')' was missing from the message; file is nitrogen-generated, so fix upstream too or it will be regenerated
57
+ }
58
+ }
59
+ };
60
+
61
+ } // namespace margelo::nitro::litertlm
@@ -0,0 +1,65 @@
1
+ ///
2
+ /// JToolDefinition.hpp
3
+ /// This file was generated by nitrogen. DO NOT MODIFY THIS FILE.
4
+ /// https://github.com/mrousavy/nitro
5
+ /// Copyright © Marc Rousavy @ Margelo
6
+ ///
7
+
8
+ #pragma once
9
+
10
+ #include <fbjni/fbjni.h>
11
+ #include "ToolDefinition.hpp"
12
+
13
+ #include <string>
14
+
15
+ namespace margelo::nitro::litertlm {
16
+
17
+ using namespace facebook;
18
+
19
+ /**
20
+ * The C++ JNI bridge between the C++ struct "ToolDefinition" and the Kotlin data class "ToolDefinition".
21
+ */
22
+ struct JToolDefinition final: public jni::JavaClass<JToolDefinition> {
23
+ public:
24
+ static constexpr auto kJavaDescriptor = "Lcom/margelo/nitro/dev/litert/litertlm/ToolDefinition;";
25
+
26
+ public:
27
+ /**
28
+ * Convert this Java/Kotlin-based struct to the C++ struct ToolDefinition by copying all values to C++. All three fields are read as JStrings and dereferenced without null checks, so the Kotlin fields are assumed non-nullable.
29
+ */
30
+ [[maybe_unused]]
31
+ [[nodiscard]]
32
+ ToolDefinition toCpp() const {
33
+ static const auto clazz = javaClassStatic(); // class and field IDs below are resolved once (function-local statics)
34
+ static const auto fieldName = clazz->getField<jni::JString>("name");
35
+ jni::local_ref<jni::JString> name = this->getFieldValue(fieldName);
36
+ static const auto fieldDescription = clazz->getField<jni::JString>("description");
37
+ jni::local_ref<jni::JString> description = this->getFieldValue(fieldDescription);
38
+ static const auto fieldParametersJson = clazz->getField<jni::JString>("parametersJson");
39
+ jni::local_ref<jni::JString> parametersJson = this->getFieldValue(fieldParametersJson);
40
+ return ToolDefinition(
41
+ name->toStdString(),
42
+ description->toStdString(),
43
+ parametersJson->toStdString() // stays a serialized JSON string; not parsed here
44
+ );
45
+ }
46
+
47
+ public:
48
+ /**
49
+ * Create a Java/Kotlin-based struct by copying all values from the given C++ struct to Java, through the Kotlin companion's static `fromCpp` factory.
50
+ */
51
+ [[maybe_unused]]
52
+ static jni::local_ref<JToolDefinition::javaobject> fromCpp(const ToolDefinition& value) {
53
+ using JSignature = JToolDefinition(jni::alias_ref<jni::JString>, jni::alias_ref<jni::JString>, jni::alias_ref<jni::JString>);
54
+ static const auto clazz = javaClassStatic();
55
+ static const auto create = clazz->getStaticMethod<JSignature>("fromCpp");
56
+ return create(
57
+ clazz,
58
+ jni::make_jstring(value.name),
59
+ jni::make_jstring(value.description),
60
+ jni::make_jstring(value.parametersJson)
61
+ );
62
+ }
63
+ };
64
+
65
+ } // namespace margelo::nitro::litertlm
@@ -9,6 +9,7 @@ package com.margelo.nitro.dev.litert.litertlm
9
9
 
10
10
  import androidx.annotation.Keep
11
11
  import com.facebook.proguard.annotations.DoNotStrip
12
+ import java.util.Objects
12
13
 
13
14
 
14
15
  /**
@@ -38,6 +39,28 @@ data class GenerationStats(
38
39
  ) {
39
40
  /* primary constructor */
40
41
 
42
+ override fun equals(other: Any?): Boolean { // generated structural equality; Objects.deepEquals compares arrays by content and falls back to equals() otherwise
43
+ if (this === other) return true
44
+ if (other !is GenerationStats) return false
45
+ return Objects.deepEquals(this.promptTokens, other.promptTokens)
46
+ && Objects.deepEquals(this.completionTokens, other.completionTokens)
47
+ && Objects.deepEquals(this.totalTokens, other.totalTokens)
48
+ && Objects.deepEquals(this.timeToFirstToken, other.timeToFirstToken)
49
+ && Objects.deepEquals(this.totalTime, other.totalTime)
50
+ && Objects.deepEquals(this.tokensPerSecond, other.tokensPerSecond)
51
+ }
52
+
53
+ override fun hashCode(): Int { // kept consistent with equals(): fields are hashed by deep content
54
+ return arrayOf<Any?>(
55
+ promptTokens,
56
+ completionTokens,
57
+ totalTokens,
58
+ timeToFirstToken,
59
+ totalTime,
60
+ tokensPerSecond
61
+ ).contentDeepHashCode()
62
+ }
63
+
41
64
  companion object {
42
65
  /**
43
66
  * Constructor called from C++
@@ -58,11 +58,15 @@ abstract class HybridLiteRTLMSpec: HybridObject() {
58
58
  @Keep
59
59
  abstract fun sendMessageWithAudio(message: String, audioPath: String): Promise<String>
60
60
 
61
- abstract fun sendMessageAsync(message: String, onToken: (token: String, done: Boolean) -> Unit): Unit
61
+ @DoNotStrip
62
+ @Keep
63
+ abstract fun sendMultimodalMessage(parts: Array<MultimodalPart>): Promise<String>
64
+
65
+ abstract fun sendMessageAsync(message: String, onToken: (token: String, done: Boolean) -> Unit): Promise<Unit>
62
66
 
63
67
  @DoNotStrip
64
68
  @Keep
65
- private fun sendMessageAsync_cxx(message: String, onToken: Func_void_std__string_bool): Unit {
69
+ private fun sendMessageAsync_cxx(message: String, onToken: Func_void_std__string_bool): Promise<Unit> {
66
70
  val __result = sendMessageAsync(message, onToken)
67
71
  return __result
68
72
  }
@@ -83,6 +87,10 @@ abstract class HybridLiteRTLMSpec: HybridObject() {
83
87
  @Keep
84
88
  abstract fun getStats(): GenerationStats
85
89
 
90
+ @DoNotStrip
91
+ @Keep
92
+ abstract fun countTokens(text: String): Double
93
+
86
94
  @DoNotStrip
87
95
  @Keep
88
96
  abstract fun getMemoryUsage(): MemoryUsage
@@ -9,6 +9,7 @@ package com.margelo.nitro.dev.litert.litertlm
9
9
 
10
10
  import androidx.annotation.Keep
11
11
  import com.facebook.proguard.annotations.DoNotStrip
12
+ import java.util.Objects
12
13
 
13
14
 
14
15
  /**
@@ -34,10 +35,52 @@ data class LLMConfig(
34
35
  val topK: Double?,
35
36
  @DoNotStrip
36
37
  @Keep
37
- val topP: Double?
38
+ val topP: Double?,
39
+ @DoNotStrip
40
+ @Keep
41
+ val validate: Boolean?,
42
+ @DoNotStrip
43
+ @Keep
44
+ val multimodal: Boolean?,
45
+ @DoNotStrip
46
+ @Keep
47
+ val tools: Array<ToolDefinition>?,
48
+ @DoNotStrip
49
+ @Keep
50
+ val enableSpeculativeDecoding: Boolean?
38
51
  ) {
39
52
  /* primary constructor */
40
53
 
54
+ override fun equals(other: Any?): Boolean { // hand-written because `tools` is an Array, which the data-class default would compare by reference
55
+ if (this === other) return true
56
+ if (other !is LLMConfig) return false
57
+ return Objects.deepEquals(this.systemPrompt, other.systemPrompt)
58
+ && Objects.deepEquals(this.backend, other.backend)
59
+ && Objects.deepEquals(this.maxTokens, other.maxTokens)
60
+ && Objects.deepEquals(this.temperature, other.temperature)
61
+ && Objects.deepEquals(this.topK, other.topK)
62
+ && Objects.deepEquals(this.topP, other.topP)
63
+ && Objects.deepEquals(this.validate, other.validate)
64
+ && Objects.deepEquals(this.multimodal, other.multimodal)
65
+ && Objects.deepEquals(this.tools, other.tools) // Array<ToolDefinition>?: compared element-wise
66
+ && Objects.deepEquals(this.enableSpeculativeDecoding, other.enableSpeculativeDecoding)
67
+ }
68
+
69
+ override fun hashCode(): Int { // consistent with equals(): contentDeepHashCode descends into the nested `tools` array
70
+ return arrayOf<Any?>(
71
+ systemPrompt,
72
+ backend,
73
+ maxTokens,
74
+ temperature,
75
+ topK,
76
+ topP,
77
+ validate,
78
+ multimodal,
79
+ tools, // nested Array<ToolDefinition>? hashed by content, not identity
80
+ enableSpeculativeDecoding
81
+ ).contentDeepHashCode()
82
+ }
83
+
41
84
  companion object {
42
85
  /**
43
86
  * Constructor called from C++
@@ -46,8 +89,8 @@ data class LLMConfig(
46
89
  @Keep
47
90
  @Suppress("unused")
48
91
  @JvmStatic
49
- private fun fromCpp(systemPrompt: String?, backend: Backend?, maxTokens: Double?, temperature: Double?, topK: Double?, topP: Double?): LLMConfig {
50
- return LLMConfig(systemPrompt, backend, maxTokens, temperature, topK, topP)
92
+ private fun fromCpp(systemPrompt: String?, backend: Backend?, maxTokens: Double?, temperature: Double?, topK: Double?, topP: Double?, validate: Boolean?, multimodal: Boolean?, tools: Array<ToolDefinition>?, enableSpeculativeDecoding: Boolean?): LLMConfig {
93
+ return LLMConfig(systemPrompt, backend, maxTokens, temperature, topK, topP, validate, multimodal, tools, enableSpeculativeDecoding)
51
94
  }
52
95
  }
53
96
  }