react-native-litert-lm 0.3.7 → 0.4.1

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (88)
  1. package/README.md +153 -135
  2. package/android/build.gradle +12 -0
  3. package/android/src/main/AndroidManifest.xml +8 -0
  4. package/android/src/main/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLM.kt +276 -62
  5. package/android/src/main/java/dev/litert/litertlm/LiteRTLMPackage.kt +19 -2
  6. package/android/src/test/java/com/margelo/nitro/core/Promise.kt +46 -0
  7. package/android/src/test/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMTest.kt +105 -0
  8. package/ios/HybridLiteRTLM.swift +1344 -0
  9. package/ios/Tests/HybridLiteRTLMTests.swift +113 -0
  10. package/lib/__mocks__/react-native-nitro-modules.d.ts +65 -0
  11. package/lib/__mocks__/react-native-nitro-modules.js +60 -0
  12. package/lib/__tests__/hooks.test.d.ts +1 -0
  13. package/lib/__tests__/hooks.test.js +124 -0
  14. package/lib/__tests__/memoryTracker.test.d.ts +1 -0
  15. package/lib/__tests__/memoryTracker.test.js +74 -0
  16. package/lib/__tests__/modelFactory.test.d.ts +1 -0
  17. package/lib/__tests__/modelFactory.test.js +68 -0
  18. package/lib/hooks.js +27 -3
  19. package/lib/index.d.ts +6 -2
  20. package/lib/index.js +8 -8
  21. package/lib/modelFactory.js +82 -63
  22. package/lib/specs/LiteRTLM.nitro.d.ts +87 -2
  23. package/nitrogen/generated/android/LiteRTLMOnLoad.cpp +2 -2
  24. package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.cpp +94 -9
  25. package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.hpp +5 -1
  26. package/nitrogen/generated/android/c++/JLLMConfig.hpp +40 -3
  27. package/nitrogen/generated/android/c++/JMultimodalPart.hpp +74 -0
  28. package/nitrogen/generated/android/c++/JPartType.hpp +61 -0
  29. package/nitrogen/generated/android/c++/JToolDefinition.hpp +65 -0
  30. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/GenerationStats.kt +23 -0
  31. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMSpec.kt +28 -2
  32. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/LLMConfig.kt +46 -3
  33. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MemoryUsage.kt +19 -0
  34. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/Message.kt +15 -0
  35. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MultimodalPart.kt +66 -0
  36. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/PartType.kt +24 -0
  37. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/ToolDefinition.kt +61 -0
  38. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.cpp +57 -1
  39. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.hpp +414 -3
  40. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Umbrella.hpp +41 -3
  41. package/nitrogen/generated/ios/LiteRTLMAutolinking.mm +4 -6
  42. package/nitrogen/generated/ios/LiteRTLMAutolinking.swift +10 -0
  43. package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.cpp +11 -0
  44. package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.hpp +240 -0
  45. package/nitrogen/generated/ios/swift/Backend.swift +44 -0
  46. package/nitrogen/generated/ios/swift/Func_void.swift +46 -0
  47. package/nitrogen/generated/ios/swift/Func_void_double.swift +46 -0
  48. package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +46 -0
  49. package/nitrogen/generated/ios/swift/Func_void_std__string.swift +46 -0
  50. package/nitrogen/generated/ios/swift/Func_void_std__string_bool.swift +46 -0
  51. package/nitrogen/generated/ios/swift/GenerationStats.swift +54 -0
  52. package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec.swift +71 -0
  53. package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec_cxx.swift +431 -0
  54. package/nitrogen/generated/ios/swift/LLMConfig.swift +203 -0
  55. package/nitrogen/generated/ios/swift/MemoryUsage.swift +44 -0
  56. package/nitrogen/generated/ios/swift/Message.swift +34 -0
  57. package/nitrogen/generated/ios/swift/MultimodalPart.swift +83 -0
  58. package/nitrogen/generated/ios/swift/PartType.swift +44 -0
  59. package/nitrogen/generated/ios/swift/Role.swift +44 -0
  60. package/nitrogen/generated/ios/swift/ToolDefinition.swift +39 -0
  61. package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.cpp +4 -0
  62. package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.hpp +9 -2
  63. package/nitrogen/generated/shared/c++/LLMConfig.hpp +22 -2
  64. package/nitrogen/generated/shared/c++/MultimodalPart.hpp +99 -0
  65. package/nitrogen/generated/shared/c++/PartType.hpp +80 -0
  66. package/nitrogen/generated/shared/c++/ToolDefinition.hpp +91 -0
  67. package/package.json +22 -11
  68. package/react-native-litert-lm.podspec +17 -19
  69. package/scripts/download-ios-frameworks.sh +17 -50
  70. package/scripts/framework-source.js +46 -0
  71. package/scripts/postinstall.js +40 -18
  72. package/src/__mocks__/react-native-nitro-modules.ts +58 -0
  73. package/src/__tests__/hooks.test.ts +153 -0
  74. package/src/__tests__/memoryTracker.test.ts +87 -0
  75. package/src/__tests__/modelFactory.test.ts +96 -0
  76. package/src/hooks.ts +29 -7
  77. package/src/index.ts +7 -10
  78. package/src/modelFactory.ts +104 -80
  79. package/src/specs/LiteRTLM.nitro.ts +106 -2
  80. package/cpp/HybridLiteRTLM.cpp +0 -939
  81. package/cpp/HybridLiteRTLM.hpp +0 -169
  82. package/cpp/IOSDownloadHelper.h +0 -24
  83. package/ios/IOSDownloadHelper.mm +0 -129
  84. package/scripts/build-ios-engine.sh +0 -302
  85. package/scripts/stubs/cxx_bridge_stubs.cc +0 -224
  86. package/scripts/stubs/gemma_model_constraint_provider.cc +0 -46
  87. package/scripts/stubs/llguidance_stubs.c +0 -101
  88. package/src/templates.ts +0 -105
@@ -38,72 +38,91 @@ function createLLM(options) {
38
38
  // Ignore errors during memory tracking - it's non-critical
39
39
  }
40
40
  };
41
- return {
42
- ...native,
43
- memoryTracker: tracker,
44
- loadModel: async (pathOrUrl, config, onDownloadProgress) => {
45
- let modelPath = pathOrUrl;
46
- // Check if it's a URL enforce HTTPS for model downloads
47
- if (pathOrUrl.startsWith("http://") || pathOrUrl.startsWith("https://")) {
48
- if (pathOrUrl.startsWith("http://")) {
49
- throw new Error("Insecure HTTP URLs are not allowed for model downloads. " +
50
- "Use HTTPS instead: " +
51
- pathOrUrl.replace("http://", "https://"));
52
- }
53
- // Extract filename from URL
54
- const fileName = pathOrUrl.split("/").pop();
55
- if (!fileName) {
56
- throw new Error(`Invalid model URL: ${pathOrUrl}`);
57
- }
58
- console.log(`Checking model at ${pathOrUrl}...`);
59
- modelPath = await native.downloadModel(pathOrUrl, fileName, (progress) => {
60
- onDownloadProgress?.(progress);
61
- });
62
- console.log(`Model downloaded to: ${modelPath}`);
41
+ const augmentedLoadModel = async (pathOrUrl, config, onDownloadProgress) => {
42
+ let modelPath = pathOrUrl;
43
+ // Check if it's a URL — enforce HTTPS for model downloads
44
+ if (pathOrUrl.startsWith("http://") || pathOrUrl.startsWith("https://")) {
45
+ if (pathOrUrl.startsWith("http://")) {
46
+ throw new Error("Insecure HTTP URLs are not allowed for model downloads. " +
47
+ "Use HTTPS instead: " +
48
+ pathOrUrl.replace("http://", "https://"));
63
49
  }
64
- const result = await native.loadModel(modelPath, config);
65
- // Record initial memory snapshot after model load
66
- if (tracker) {
67
- tracker.reset();
68
- recordMemorySnapshot();
50
+ // Extract filename from URL, stripping query parameters
51
+ const urlWithoutQuery = pathOrUrl.split("?")[0];
52
+ const fileName = urlWithoutQuery.split("/").pop();
53
+ if (!fileName) {
54
+ throw new Error(`Invalid model URL: ${pathOrUrl}`);
69
55
  }
70
- return result;
71
- },
72
- sendMessage: async (...args) => {
73
- const result = await native.sendMessage(...args);
74
- recordMemorySnapshot();
75
- return result;
76
- },
77
- sendMessageAsync: (...args) => {
78
- const [message, onToken] = args;
79
- native.sendMessageAsync(message, (token, done) => {
80
- onToken(token, done);
81
- if (done) {
82
- recordMemorySnapshot();
83
- }
56
+ console.log(`Checking model at ${pathOrUrl}...`);
57
+ modelPath = await native.downloadModel(pathOrUrl, fileName, (progress) => {
58
+ onDownloadProgress?.(progress);
84
59
  });
85
- },
86
- sendMessageWithImage: async (...args) => {
87
- const result = await native.sendMessageWithImage(...args);
88
- recordMemorySnapshot();
89
- return result;
90
- },
91
- sendMessageWithAudio: async (...args) => {
92
- const result = await native.sendMessageWithAudio(...args);
93
- recordMemorySnapshot();
94
- return result;
95
- },
96
- getHistory: native.getHistory.bind(native),
97
- resetConversation: () => {
98
- native.resetConversation();
99
- // KV cache is cleared on reset, record the drop
60
+ console.log(`Model downloaded to: ${modelPath}`);
61
+ }
62
+ const result = await native.loadModel(modelPath, config);
63
+ // Record initial memory snapshot after model load
64
+ if (tracker) {
65
+ tracker.reset();
100
66
  recordMemorySnapshot();
101
- },
102
- isReady: native.isReady.bind(native),
103
- getStats: native.getStats.bind(native),
104
- getMemoryUsage: native.getMemoryUsage.bind(native),
105
- close: native.close.bind(native),
106
- downloadModel: native.downloadModel.bind(native),
107
- deleteModel: native.deleteModel.bind(native),
67
+ }
68
+ return result;
108
69
  };
70
+ const SNAPSHOT_TRIGGERS = new Set([
71
+ "sendMessage",
72
+ "sendMessageWithImage",
73
+ "sendMessageWithAudio",
74
+ "resetConversation",
75
+ ]);
76
+ return new Proxy(native, {
77
+ get(target, prop, receiver) {
78
+ if (prop === "memoryTracker") {
79
+ return tracker;
80
+ }
81
+ if (prop === "loadModel") {
82
+ return augmentedLoadModel;
83
+ }
84
+ const original = Reflect.get(target, prop, receiver);
85
+ if (typeof original !== "function") {
86
+ return original;
87
+ }
88
+ if (prop === "sendMessageAsync") {
89
+ return (message, onToken) => {
90
+ return original.call(target, message, (token, done) => {
91
+ onToken(token, done);
92
+ if (done) {
93
+ recordMemorySnapshot();
94
+ }
95
+ });
96
+ };
97
+ }
98
+ if (prop === "sendMessageWithImageAsync") {
99
+ return (message, imagePath, onToken) => {
100
+ return original.call(target, message, imagePath, (token, done) => {
101
+ onToken(token, done);
102
+ if (done) {
103
+ recordMemorySnapshot();
104
+ }
105
+ });
106
+ };
107
+ }
108
+ if (prop === "sendMessageWithAudioAsync") {
109
+ return (message, audioPath, onToken) => {
110
+ return original.call(target, message, audioPath, (token, done) => {
111
+ onToken(token, done);
112
+ if (done) {
113
+ recordMemorySnapshot();
114
+ }
115
+ });
116
+ };
117
+ }
118
+ if (SNAPSHOT_TRIGGERS.has(prop)) {
119
+ return async (...args) => {
120
+ const result = await original.apply(target, args);
121
+ recordMemorySnapshot();
122
+ return result;
123
+ };
124
+ }
125
+ return original.bind(target);
126
+ },
127
+ });
109
128
  }
@@ -14,6 +14,34 @@ export type Backend = "cpu" | "gpu" | "npu";
14
14
  * Message roles for conversation.
15
15
  */
16
16
  export type Role = "user" | "model" | "system";
17
+ /**
18
+ * Definition for a function/tool that the model can request to execute.
19
+ */
20
+ export interface ToolDefinition {
21
+ /** Name of the function/tool */
22
+ name: string;
23
+ /** Human-readable description of what the function/tool does */
24
+ description: string;
25
+ /** JSON schema defining parameter names and types (stringified) */
26
+ parametersJson: string;
27
+ }
28
+ /**
29
+ * The part type for a multimodal message content part.
30
+ */
31
+ export type PartType = "text" | "image" | "audio";
32
+ /**
33
+ * A part of a unified multimodal message payload.
34
+ */
35
+ export interface MultimodalPart {
36
+ /** The part type: 'text', 'image', or 'audio' */
37
+ type: PartType;
38
+ /** The plain text content, if type is 'text' */
39
+ text?: string;
40
+ /** Raw image binary data, if type is 'image' (zero-copy ArrayBuffer mapping) */
41
+ imageBuffer?: ArrayBuffer;
42
+ /** Raw audio binary data, if type is 'audio' (zero-copy ArrayBuffer mapping) */
43
+ audioBuffer?: ArrayBuffer;
44
+ }
17
45
  /**
18
46
  * Configuration options for loading an LLM.
19
47
  */
@@ -60,6 +88,37 @@ export interface LLMConfig {
60
88
  * @default 0.95
61
89
  */
62
90
  topP?: number;
91
+ /**
92
+ * Whether to run engine validation after loading the model.
93
+ * When enabled, sends a quick test inference ("Hi") and waits up to 30s
94
+ * for a response to confirm the backend works. This is useful for GPU/NPU
95
+ * backends that may silently fail during inference (they can initialize
96
+ * without error but produce no tokens).
97
+ *
98
+ * Validation is **always a no-op on CPU** — the CPU backend is inherently
99
+ * reliable and never needs validation.
100
+ *
101
+ * Disabled by default because it adds significant latency (5-30s) to model loading.
102
+ * Enable only to catch GPU/NPU silent failure issues during development.
103
+ *
104
+ * @default false
105
+ */
106
+ validate?: boolean;
107
+ /**
108
+ * Whether this is a multimodal model.
109
+ * When enabled, the engine handles image/audio tokens properly.
110
+ * If not specified, the system will fall back to filename sniffing.
111
+ */
112
+ multimodal?: boolean;
113
+ /**
114
+ * List of tools/functions that the model can call.
115
+ */
116
+ tools?: ToolDefinition[];
117
+ /**
118
+ * Whether to enable speculative decoding (multi-token prediction) if supported by the model.
119
+ * @default false
120
+ */
121
+ enableSpeculativeDecoding?: boolean;
63
122
  }
64
123
  /**
65
124
  * A simple message in the conversation.
@@ -123,7 +182,7 @@ export interface MemoryUsage {
123
182
  * ```
124
183
  */
125
184
  export interface LiteRTLM extends HybridObject<{
126
- ios: "c++";
185
+ ios: "swift";
127
186
  android: "kotlin";
128
187
  }> {
129
188
  /**
@@ -145,6 +204,14 @@ export interface LiteRTLM extends HybridObject<{
145
204
  * @returns The model's response text.
146
205
  */
147
206
  sendMessageWithImage(message: string, imagePath: string): Promise<string>;
207
+ /**
208
+ * Send a text message with an image and get a streaming response.
209
+ * Tokens are delivered via callback as they are generated.
210
+ * @param message User message text.
211
+ * @param imagePath Absolute path to an image file.
212
+ * @param onToken Callback invoked for each token (token, isDone).
213
+ */
214
+ sendMessageWithImageAsync(message: string, imagePath: string, onToken: (token: string, done: boolean) => void): Promise<void>;
148
215
  /**
149
216
  * Download a model file from a URL.
150
217
  * @param url URL to download from.
@@ -165,13 +232,27 @@ export interface LiteRTLM extends HybridObject<{
165
232
  * @returns The model's response text.
166
233
  */
167
234
  sendMessageWithAudio(message: string, audioPath: string): Promise<string>;
235
+ /**
236
+ * Send a text message with audio and get a streaming response.
237
+ * Tokens are delivered via callback as they are generated.
238
+ * @param message User message text.
239
+ * @param audioPath Absolute path to an audio file (WAV).
240
+ * @param onToken Callback invoked for each token (token, isDone).
241
+ */
242
+ sendMessageWithAudioAsync(message: string, audioPath: string, onToken: (token: string, done: boolean) => void): Promise<void>;
243
+ /**
244
+ * Send a unified multimodal message containing text and/or zero-copy binary buffers.
245
+ * @param parts The message content parts (text, image, and/or audio).
246
+ * @returns The model's response text.
247
+ */
248
+ sendMultimodalMessage(parts: MultimodalPart[]): Promise<string>;
168
249
  /**
169
250
  * Send a message with streaming response.
170
251
  * Tokens are delivered via callback as they are generated.
171
252
  * @param message User message text.
172
253
  * @param onToken Callback invoked for each token (token, isDone).
173
254
  */
174
- sendMessageAsync(message: string, onToken: (token: string, done: boolean) => void): void;
255
+ sendMessageAsync(message: string, onToken: (token: string, done: boolean) => void): Promise<void>;
175
256
  /**
176
257
  * Get the current conversation history.
177
258
  * @returns Array of messages in the conversation.
@@ -189,6 +270,10 @@ export interface LiteRTLM extends HybridObject<{
189
270
  * Get the last generation statistics.
190
271
  */
191
272
  getStats(): GenerationStats;
273
+ /**
274
+ * Count tokens in a text string. Returns -1 if unavailable.
275
+ */
276
+ countTokens(text: string): number;
192
277
  /**
193
278
  * Get real memory usage from the native runtime.
194
279
  * Uses OS-level APIs to report actual memory consumption.
@@ -16,8 +16,8 @@
16
16
  #include <NitroModules/HybridObjectRegistry.hpp>
17
17
 
18
18
  #include "JHybridLiteRTLMSpec.hpp"
19
- #include "JFunc_void_double.hpp"
20
19
  #include "JFunc_void_std__string_bool.hpp"
20
+ #include "JFunc_void_double.hpp"
21
21
  #include <NitroModules/DefaultConstructableObject.hpp>
22
22
 
23
23
  namespace margelo::nitro::litertlm {
@@ -43,8 +43,8 @@ void registerAllNatives() {
43
43
 
44
44
  // Register native JNI methods
45
45
  margelo::nitro::litertlm::JHybridLiteRTLMSpec::CxxPart::registerNatives();
46
- margelo::nitro::litertlm::JFunc_void_double_cxx::registerNatives();
47
46
  margelo::nitro::litertlm::JFunc_void_std__string_bool_cxx::registerNatives();
47
+ margelo::nitro::litertlm::JFunc_void_double_cxx::registerNatives();
48
48
 
49
49
  // Register Nitro Hybrid Objects
50
50
  HybridObjectRegistry::registerHybridObjectConstructor(
@@ -19,6 +19,12 @@ namespace margelo::nitro::litertlm { struct MemoryUsage; }
19
19
  namespace margelo::nitro::litertlm { struct LLMConfig; }
20
20
  // Forward declaration of `Backend` to properly resolve imports.
21
21
  namespace margelo::nitro::litertlm { enum class Backend; }
22
+ // Forward declaration of `ToolDefinition` to properly resolve imports.
23
+ namespace margelo::nitro::litertlm { struct ToolDefinition; }
24
+ // Forward declaration of `MultimodalPart` to properly resolve imports.
25
+ namespace margelo::nitro::litertlm { struct MultimodalPart; }
26
+ // Forward declaration of `PartType` to properly resolve imports.
27
+ namespace margelo::nitro::litertlm { enum class PartType; }
22
28
 
23
29
  #include <NitroModules/Promise.hpp>
24
30
  #include <NitroModules/JPromise.hpp>
@@ -38,10 +44,18 @@ namespace margelo::nitro::litertlm { enum class Backend; }
38
44
  #include "JLLMConfig.hpp"
39
45
  #include "Backend.hpp"
40
46
  #include "JBackend.hpp"
47
+ #include "ToolDefinition.hpp"
48
+ #include "JToolDefinition.hpp"
41
49
  #include <functional>
42
- #include "JFunc_void_double.hpp"
43
- #include <NitroModules/JNICallable.hpp>
44
50
  #include "JFunc_void_std__string_bool.hpp"
51
+ #include <NitroModules/JNICallable.hpp>
52
+ #include "JFunc_void_double.hpp"
53
+ #include "MultimodalPart.hpp"
54
+ #include "JMultimodalPart.hpp"
55
+ #include "PartType.hpp"
56
+ #include "JPartType.hpp"
57
+ #include <NitroModules/ArrayBuffer.hpp>
58
+ #include <NitroModules/JArrayBuffer.hpp>
45
59
 
46
60
  namespace margelo::nitro::litertlm {
47
61
 
@@ -123,6 +137,21 @@ namespace margelo::nitro::litertlm {
123
137
  return __promise;
124
138
  }();
125
139
  }
140
+ std::shared_ptr<Promise<void>> JHybridLiteRTLMSpec::sendMessageWithImageAsync(const std::string& message, const std::string& imagePath, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) {
141
+ static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<jni::JString> /* imagePath */, jni::alias_ref<JFunc_void_std__string_bool::javaobject> /* onToken */)>("sendMessageWithImageAsync_cxx");
142
+ auto __result = method(_javaPart, jni::make_jstring(message), jni::make_jstring(imagePath), JFunc_void_std__string_bool_cxx::fromCpp(onToken));
143
+ return [&]() {
144
+ auto __promise = Promise<void>::create();
145
+ __result->cthis()->addOnResolvedListener([=](const jni::alias_ref<jni::JObject>& /* unit */) {
146
+ __promise->resolve();
147
+ });
148
+ __result->cthis()->addOnRejectedListener([=](const jni::alias_ref<jni::JThrowable>& __throwable) {
149
+ jni::JniException __jniError(__throwable);
150
+ __promise->reject(std::make_exception_ptr(__jniError));
151
+ });
152
+ return __promise;
153
+ }();
154
+ }
126
155
  std::shared_ptr<Promise<std::string>> JHybridLiteRTLMSpec::downloadModel(const std::string& url, const std::string& fileName, const std::optional<std::function<void(double /* progress */)>>& onProgress) {
127
156
  static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* url */, jni::alias_ref<jni::JString> /* fileName */, jni::alias_ref<JFunc_void_double::javaobject> /* onProgress */)>("downloadModel_cxx");
128
157
  auto __result = method(_javaPart, jni::make_jstring(url), jni::make_jstring(fileName), onProgress.has_value() ? JFunc_void_double_cxx::fromCpp(onProgress.value()) : nullptr);
@@ -170,23 +199,74 @@ namespace margelo::nitro::litertlm {
170
199
  return __promise;
171
200
  }();
172
201
  }
173
- void JHybridLiteRTLMSpec::sendMessageAsync(const std::string& message, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) {
174
- static const auto method = _javaPart->javaClassStatic()->getMethod<void(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<JFunc_void_std__string_bool::javaobject> /* onToken */)>("sendMessageAsync_cxx");
175
- method(_javaPart, jni::make_jstring(message), JFunc_void_std__string_bool_cxx::fromCpp(onToken));
202
+ std::shared_ptr<Promise<void>> JHybridLiteRTLMSpec::sendMessageWithAudioAsync(const std::string& message, const std::string& audioPath, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) {
203
+ static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<jni::JString> /* audioPath */, jni::alias_ref<JFunc_void_std__string_bool::javaobject> /* onToken */)>("sendMessageWithAudioAsync_cxx");
204
+ auto __result = method(_javaPart, jni::make_jstring(message), jni::make_jstring(audioPath), JFunc_void_std__string_bool_cxx::fromCpp(onToken));
205
+ return [&]() {
206
+ auto __promise = Promise<void>::create();
207
+ __result->cthis()->addOnResolvedListener([=](const jni::alias_ref<jni::JObject>& /* unit */) {
208
+ __promise->resolve();
209
+ });
210
+ __result->cthis()->addOnRejectedListener([=](const jni::alias_ref<jni::JThrowable>& __throwable) {
211
+ jni::JniException __jniError(__throwable);
212
+ __promise->reject(std::make_exception_ptr(__jniError));
213
+ });
214
+ return __promise;
215
+ }();
216
+ }
217
+ std::shared_ptr<Promise<std::string>> JHybridLiteRTLMSpec::sendMultimodalMessage(const std::vector<MultimodalPart>& parts) {
218
+ static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JArrayClass<JMultimodalPart>> /* parts */)>("sendMultimodalMessage");
219
+ auto __result = method(_javaPart, [&](auto&& __input) {
220
+ size_t __size = __input.size();
221
+ jni::local_ref<jni::JArrayClass<JMultimodalPart>> __array = jni::JArrayClass<JMultimodalPart>::newArray(__size);
222
+ for (size_t __i = 0; __i < __size; __i++) {
223
+ const auto& __element = __input[__i];
224
+ auto __elementJni = JMultimodalPart::fromCpp(__element);
225
+ __array->setElement(__i, *__elementJni);
226
+ }
227
+ return __array;
228
+ }(parts));
229
+ return [&]() {
230
+ auto __promise = Promise<std::string>::create();
231
+ __result->cthis()->addOnResolvedListener([=](const jni::alias_ref<jni::JObject>& __boxedResult) {
232
+ auto __result = jni::static_ref_cast<jni::JString>(__boxedResult);
233
+ __promise->resolve(__result->toStdString());
234
+ });
235
+ __result->cthis()->addOnRejectedListener([=](const jni::alias_ref<jni::JThrowable>& __throwable) {
236
+ jni::JniException __jniError(__throwable);
237
+ __promise->reject(std::make_exception_ptr(__jniError));
238
+ });
239
+ return __promise;
240
+ }();
241
+ }
242
+ std::shared_ptr<Promise<void>> JHybridLiteRTLMSpec::sendMessageAsync(const std::string& message, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) {
243
+ static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<JFunc_void_std__string_bool::javaobject> /* onToken */)>("sendMessageAsync_cxx");
244
+ auto __result = method(_javaPart, jni::make_jstring(message), JFunc_void_std__string_bool_cxx::fromCpp(onToken));
245
+ return [&]() {
246
+ auto __promise = Promise<void>::create();
247
+ __result->cthis()->addOnResolvedListener([=](const jni::alias_ref<jni::JObject>& /* unit */) {
248
+ __promise->resolve();
249
+ });
250
+ __result->cthis()->addOnRejectedListener([=](const jni::alias_ref<jni::JThrowable>& __throwable) {
251
+ jni::JniException __jniError(__throwable);
252
+ __promise->reject(std::make_exception_ptr(__jniError));
253
+ });
254
+ return __promise;
255
+ }();
176
256
  }
177
257
  std::vector<Message> JHybridLiteRTLMSpec::getHistory() {
178
258
  static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<jni::JArrayClass<JMessage>>()>("getHistory");
179
259
  auto __result = method(_javaPart);
180
- return [&]() {
181
- size_t __size = __result->size();
260
+ return [&](auto&& __input) {
261
+ size_t __size = __input->size();
182
262
  std::vector<Message> __vector;
183
263
  __vector.reserve(__size);
184
264
  for (size_t __i = 0; __i < __size; __i++) {
185
- auto __element = __result->getElement(__i);
265
+ auto __element = __input->getElement(__i);
186
266
  __vector.push_back(__element->toCpp());
187
267
  }
188
268
  return __vector;
189
- }();
269
+ }(__result);
190
270
  }
191
271
  void JHybridLiteRTLMSpec::resetConversation() {
192
272
  static const auto method = _javaPart->javaClassStatic()->getMethod<void()>("resetConversation");
@@ -202,6 +282,11 @@ namespace margelo::nitro::litertlm {
202
282
  auto __result = method(_javaPart);
203
283
  return __result->toCpp();
204
284
  }
285
+ double JHybridLiteRTLMSpec::countTokens(const std::string& text) {
286
+ static const auto method = _javaPart->javaClassStatic()->getMethod<double(jni::alias_ref<jni::JString> /* text */)>("countTokens");
287
+ auto __result = method(_javaPart, jni::make_jstring(text));
288
+ return __result;
289
+ }
205
290
  MemoryUsage JHybridLiteRTLMSpec::getMemoryUsage() {
206
291
  static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JMemoryUsage>()>("getMemoryUsage");
207
292
  auto __result = method(_javaPart);
@@ -57,14 +57,18 @@ namespace margelo::nitro::litertlm {
57
57
  std::shared_ptr<Promise<void>> loadModel(const std::string& modelPath, const std::optional<LLMConfig>& config) override;
58
58
  std::shared_ptr<Promise<std::string>> sendMessage(const std::string& message) override;
59
59
  std::shared_ptr<Promise<std::string>> sendMessageWithImage(const std::string& message, const std::string& imagePath) override;
60
+ std::shared_ptr<Promise<void>> sendMessageWithImageAsync(const std::string& message, const std::string& imagePath, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) override;
60
61
  std::shared_ptr<Promise<std::string>> downloadModel(const std::string& url, const std::string& fileName, const std::optional<std::function<void(double /* progress */)>>& onProgress) override;
61
62
  std::shared_ptr<Promise<void>> deleteModel(const std::string& fileName) override;
62
63
  std::shared_ptr<Promise<std::string>> sendMessageWithAudio(const std::string& message, const std::string& audioPath) override;
63
- void sendMessageAsync(const std::string& message, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) override;
64
+ std::shared_ptr<Promise<void>> sendMessageWithAudioAsync(const std::string& message, const std::string& audioPath, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) override;
65
+ std::shared_ptr<Promise<std::string>> sendMultimodalMessage(const std::vector<MultimodalPart>& parts) override;
66
+ std::shared_ptr<Promise<void>> sendMessageAsync(const std::string& message, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) override;
64
67
  std::vector<Message> getHistory() override;
65
68
  void resetConversation() override;
66
69
  bool isReady() override;
67
70
  GenerationStats getStats() override;
71
+ double countTokens(const std::string& text) override;
68
72
  MemoryUsage getMemoryUsage() override;
69
73
  void close() override;
70
74
 
@@ -12,8 +12,11 @@
12
12
 
13
13
  #include "Backend.hpp"
14
14
  #include "JBackend.hpp"
15
+ #include "JToolDefinition.hpp"
16
+ #include "ToolDefinition.hpp"
15
17
  #include <optional>
16
18
  #include <string>
19
+ #include <vector>
17
20
 
18
21
  namespace margelo::nitro::litertlm {
19
22
 
@@ -46,13 +49,34 @@ namespace margelo::nitro::litertlm {
46
49
  jni::local_ref<jni::JDouble> topK = this->getFieldValue(fieldTopK);
47
50
  static const auto fieldTopP = clazz->getField<jni::JDouble>("topP");
48
51
  jni::local_ref<jni::JDouble> topP = this->getFieldValue(fieldTopP);
52
+ static const auto fieldValidate = clazz->getField<jni::JBoolean>("validate");
53
+ jni::local_ref<jni::JBoolean> validate = this->getFieldValue(fieldValidate);
54
+ static const auto fieldMultimodal = clazz->getField<jni::JBoolean>("multimodal");
55
+ jni::local_ref<jni::JBoolean> multimodal = this->getFieldValue(fieldMultimodal);
56
+ static const auto fieldTools = clazz->getField<jni::JArrayClass<JToolDefinition>>("tools");
57
+ jni::local_ref<jni::JArrayClass<JToolDefinition>> tools = this->getFieldValue(fieldTools);
58
+ static const auto fieldEnableSpeculativeDecoding = clazz->getField<jni::JBoolean>("enableSpeculativeDecoding");
59
+ jni::local_ref<jni::JBoolean> enableSpeculativeDecoding = this->getFieldValue(fieldEnableSpeculativeDecoding);
49
60
  return LLMConfig(
50
61
  systemPrompt != nullptr ? std::make_optional(systemPrompt->toStdString()) : std::nullopt,
51
62
  backend != nullptr ? std::make_optional(backend->toCpp()) : std::nullopt,
52
63
  maxTokens != nullptr ? std::make_optional(maxTokens->value()) : std::nullopt,
53
64
  temperature != nullptr ? std::make_optional(temperature->value()) : std::nullopt,
54
65
  topK != nullptr ? std::make_optional(topK->value()) : std::nullopt,
55
- topP != nullptr ? std::make_optional(topP->value()) : std::nullopt
66
+ topP != nullptr ? std::make_optional(topP->value()) : std::nullopt,
67
+ validate != nullptr ? std::make_optional(static_cast<bool>(validate->value())) : std::nullopt,
68
+ multimodal != nullptr ? std::make_optional(static_cast<bool>(multimodal->value())) : std::nullopt,
69
+ tools != nullptr ? std::make_optional([&](auto&& __input) {
70
+ size_t __size = __input->size();
71
+ std::vector<ToolDefinition> __vector;
72
+ __vector.reserve(__size);
73
+ for (size_t __i = 0; __i < __size; __i++) {
74
+ auto __element = __input->getElement(__i);
75
+ __vector.push_back(__element->toCpp());
76
+ }
77
+ return __vector;
78
+ }(tools)) : std::nullopt,
79
+ enableSpeculativeDecoding != nullptr ? std::make_optional(static_cast<bool>(enableSpeculativeDecoding->value())) : std::nullopt
56
80
  );
57
81
  }
58
82
 
@@ -62,7 +86,7 @@ namespace margelo::nitro::litertlm {
62
86
  */
63
87
  [[maybe_unused]]
64
88
  static jni::local_ref<JLLMConfig::javaobject> fromCpp(const LLMConfig& value) {
65
- using JSignature = JLLMConfig(jni::alias_ref<jni::JString>, jni::alias_ref<JBackend>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>);
89
+ using JSignature = JLLMConfig(jni::alias_ref<jni::JString>, jni::alias_ref<JBackend>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JBoolean>, jni::alias_ref<jni::JBoolean>, jni::alias_ref<jni::JArrayClass<JToolDefinition>>, jni::alias_ref<jni::JBoolean>);
66
90
  static const auto clazz = javaClassStatic();
67
91
  static const auto create = clazz->getStaticMethod<JSignature>("fromCpp");
68
92
  return create(
@@ -72,7 +96,20 @@ namespace margelo::nitro::litertlm {
72
96
  value.maxTokens.has_value() ? jni::JDouble::valueOf(value.maxTokens.value()) : nullptr,
73
97
  value.temperature.has_value() ? jni::JDouble::valueOf(value.temperature.value()) : nullptr,
74
98
  value.topK.has_value() ? jni::JDouble::valueOf(value.topK.value()) : nullptr,
75
- value.topP.has_value() ? jni::JDouble::valueOf(value.topP.value()) : nullptr
99
+ value.topP.has_value() ? jni::JDouble::valueOf(value.topP.value()) : nullptr,
100
+ value.validate.has_value() ? jni::JBoolean::valueOf(value.validate.value()) : nullptr,
101
+ value.multimodal.has_value() ? jni::JBoolean::valueOf(value.multimodal.value()) : nullptr,
102
+ value.tools.has_value() ? [&](auto&& __input) {
103
+ size_t __size = __input.size();
104
+ jni::local_ref<jni::JArrayClass<JToolDefinition>> __array = jni::JArrayClass<JToolDefinition>::newArray(__size);
105
+ for (size_t __i = 0; __i < __size; __i++) {
106
+ const auto& __element = __input[__i];
107
+ auto __elementJni = JToolDefinition::fromCpp(__element);
108
+ __array->setElement(__i, *__elementJni);
109
+ }
110
+ return __array;
111
+ }(value.tools.value()) : nullptr,
112
+ value.enableSpeculativeDecoding.has_value() ? jni::JBoolean::valueOf(value.enableSpeculativeDecoding.value()) : nullptr
76
113
  );
77
114
  }
78
115
  };
@@ -0,0 +1,74 @@
1
+ ///
2
+ /// JMultimodalPart.hpp
3
+ /// This file was generated by nitrogen. DO NOT MODIFY THIS FILE.
4
+ /// https://github.com/mrousavy/nitro
5
+ /// Copyright © Marc Rousavy @ Margelo
6
+ ///
7
+
8
+ #pragma once
9
+
10
+ #include <fbjni/fbjni.h>
11
+ #include "MultimodalPart.hpp"
12
+
13
+ #include "JPartType.hpp"
14
+ #include "PartType.hpp"
15
+ #include <NitroModules/ArrayBuffer.hpp>
16
+ #include <NitroModules/JArrayBuffer.hpp>
17
+ #include <optional>
18
+ #include <string>
19
+
20
+ namespace margelo::nitro::litertlm {
21
+
22
+ using namespace facebook;
23
+
24
+ /**
25
+ * The C++ JNI bridge between the C++ struct "MultimodalPart" and the Kotlin data class "MultimodalPart".
26
+ */
27
+ struct JMultimodalPart final: public jni::JavaClass<JMultimodalPart> {
28
+ public:
29
+ static constexpr auto kJavaDescriptor = "Lcom/margelo/nitro/dev/litert/litertlm/MultimodalPart;"; // JNI type descriptor of the Kotlin class this struct mirrors
30
+
31
+ public:
32
+ /**
33
+ * Convert this Java/Kotlin-based struct to the C++ struct MultimodalPart by copying all values to C++. Reads the fields "type", "text", "imageBuffer" and "audioBuffer" from the Java object; a null "text", "imageBuffer" or "audioBuffer" becomes std::nullopt.
34
+ */
35
+ [[maybe_unused]]
36
+ [[nodiscard]]
37
+ MultimodalPart toCpp() const {
38
+ static const auto clazz = javaClassStatic(); // class and field handles are function-local statics, so each is looked up only on the first call
39
+ static const auto fieldType = clazz->getField<JPartType>("type");
40
+ jni::local_ref<JPartType> type = this->getFieldValue(fieldType); // used below without a null check, so "type" is expected to be non-null
41
+ static const auto fieldText = clazz->getField<jni::JString>("text");
42
+ jni::local_ref<jni::JString> text = this->getFieldValue(fieldText);
43
+ static const auto fieldImageBuffer = clazz->getField<JArrayBuffer::javaobject>("imageBuffer");
44
+ jni::local_ref<JArrayBuffer::javaobject> imageBuffer = this->getFieldValue(fieldImageBuffer);
45
+ static const auto fieldAudioBuffer = clazz->getField<JArrayBuffer::javaobject>("audioBuffer");
46
+ jni::local_ref<JArrayBuffer::javaobject> audioBuffer = this->getFieldValue(fieldAudioBuffer);
47
+ return MultimodalPart(
48
+ type->toCpp(), // dereferenced unconditionally (the only field without a null guard)
49
+ text != nullptr ? std::make_optional(text->toStdString()) : std::nullopt, // null Java reference -> std::nullopt
50
+ imageBuffer != nullptr ? std::make_optional(imageBuffer->cthis()->getArrayBuffer()) : std::nullopt, // NOTE(review): getArrayBuffer() presumably hands over the underlying buffer rather than copying its bytes - confirm
51
+ audioBuffer != nullptr ? std::make_optional(audioBuffer->cthis()->getArrayBuffer()) : std::nullopt // same handling as imageBuffer
52
+ );
53
+ }
54
+
55
+ public:
56
+ /**
57
+ * Create a Java/Kotlin-based struct by copying all values from the given C++ struct to Java. Calls the static method "fromCpp" of the Kotlin class; optionals without a value are passed as null references.
58
+ */
59
+ [[maybe_unused]]
60
+ static jni::local_ref<JMultimodalPart::javaobject> fromCpp(const MultimodalPart& value) {
61
+ using JSignature = JMultimodalPart(jni::alias_ref<JPartType>, jni::alias_ref<jni::JString>, jni::alias_ref<JArrayBuffer::javaobject>, jni::alias_ref<JArrayBuffer::javaobject>); // parameter order: type, text, imageBuffer, audioBuffer
62
+ static const auto clazz = javaClassStatic();
63
+ static const auto create = clazz->getStaticMethod<JSignature>("fromCpp"); // static factory handle, resolved once and reused on later calls
64
+ return create(
65
+ clazz, // the static method is invoked on the class object itself
66
+ JPartType::fromCpp(value.type),
67
+ value.text.has_value() ? jni::make_jstring(value.text.value()) : nullptr, // absent optional -> null reference
68
+ value.imageBuffer.has_value() ? JArrayBuffer::wrap(value.imageBuffer.value()) : nullptr, // NOTE(review): wrap() presumably shares the native buffer instead of copying it - confirm
69
+ value.audioBuffer.has_value() ? JArrayBuffer::wrap(value.audioBuffer.value()) : nullptr // same handling as imageBuffer
70
+ );
71
+ }
72
+ };
73
+
74
+ } // namespace margelo::nitro::litertlm