react-native-litert-lm 0.1.1 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. package/README.md +149 -31
  2. package/android/src/main/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLM.kt +307 -61
  3. package/cpp/HybridLiteRTLM.cpp +85 -31
  4. package/cpp/HybridLiteRTLM.hpp +4 -0
  5. package/cpp/include/stb_image.h +7988 -0
  6. package/lib/hooks.d.ts +16 -0
  7. package/lib/hooks.js +114 -0
  8. package/lib/index.d.ts +27 -2
  9. package/lib/index.js +50 -6
  10. package/lib/modelFactory.d.ts +5 -0
  11. package/lib/modelFactory.js +42 -0
  12. package/lib/specs/LiteRTLM.nitro.d.ts +19 -0
  13. package/lib/templates.d.ts +51 -0
  14. package/lib/templates.js +81 -0
  15. package/nitrogen/generated/android/LiteRTLMOnLoad.cpp +2 -0
  16. package/nitrogen/generated/android/c++/JFunc_void_double.hpp +75 -0
  17. package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.cpp +33 -1
  18. package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.hpp +2 -0
  19. package/nitrogen/generated/android/c++/JLLMConfig.hpp +6 -1
  20. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/Func_void_double.kt +80 -0
  21. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMSpec.kt +13 -0
  22. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/LLMConfig.kt +5 -2
  23. package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.cpp +2 -0
  24. package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.hpp +2 -0
  25. package/nitrogen/generated/shared/c++/LLMConfig.hpp +7 -2
  26. package/package.json +1 -1
  27. package/src/hooks.ts +152 -0
  28. package/src/index.ts +41 -3
  29. package/src/modelFactory.ts +49 -0
  30. package/src/specs/LiteRTLM.nitro.ts +26 -0
  31. package/src/templates.ts +105 -0
package/lib/hooks.d.ts ADDED
@@ -0,0 +1,16 @@
1
+ import { LiteRTLM, LLMConfig } from "./index";
2
/**
 * Options accepted by `useModel`: every native `LLMConfig` field, plus
 * flags that only affect the hook's own behavior.
 */
export interface UseModelConfig extends LLMConfig {
    /**
     * Start loading the model automatically when the hook mounts (and again
     * when the model source or options change). Defaults to `true`.
     */
    autoLoad?: boolean;
}
5
/**
 * Value returned by `useModel`: reactive load/generation state plus the
 * actions that drive the underlying engine.
 */
export interface UseModelResult {
    /** Underlying engine instance; `null` before it has been created. */
    model: LiteRTLM | null;
    /** `true` once `loadModel` has resolved for the current source. */
    isReady: boolean;
    /** Generation-in-progress flag maintained by `generate`. */
    isGenerating: boolean;
    /** Download progress of a URL model source, from 0.0 to 1.0. */
    downloadProgress: number;
    /** Message of the last load/generation failure, or `null`. */
    error: string | null;
    /** Send `prompt` and resolve with the complete streamed response. */
    generate: (prompt: string) => Promise<string>;
    /** Clear the engine's conversation history. */
    reset: () => void;
    /** Delete a downloaded model file (by file name) from app storage. */
    deleteModel: (fileName: string) => Promise<void>;
    /** (Re)load the model, downloading it first when the source is a URL. */
    load: () => Promise<void>;
}
16
/**
 * React hook that owns a LiteRT-LM engine for the lifetime of the calling
 * component.
 *
 * @param pathOrUrl Local model path, or an http(s) URL to download first.
 * @param config Native load options plus `autoLoad` (defaults to `true`).
 * @returns Reactive state and actions for the engine.
 */
export declare function useModel(
    pathOrUrl: string,
    config?: UseModelConfig,
): UseModelResult;
package/lib/hooks.js ADDED
@@ -0,0 +1,114 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.useModel = useModel;
4
+ const react_1 = require("react");
5
+ const modelFactory_1 = require("./modelFactory");
6
/**
 * Best-effort message extraction for values caught from native/async calls.
 * Returns `fallback` when the value is nullish or has no non-empty `message`.
 */
function errorMessage(e, fallback) {
    return (e && e.message) || fallback;
}
/**
 * Derive the local file name for a model URL: the last path segment with any
 * query string or fragment removed (".../model.litertlm?download=true" ->
 * "model.litertlm"). Falls back to "model.bin" when no segment remains.
 */
function fileNameFromUrl(url) {
    const pathPart = url.split(/[?#]/, 1)[0];
    return pathPart.split("/").pop() || "model.bin";
}
/**
 * React hook that owns one LiteRT-LM engine for the lifetime of the calling
 * component and exposes its load/generation state.
 *
 * @param pathOrUrl Local model file path, or an http(s) URL that is downloaded
 *   first (progress is reported through `downloadProgress`).
 * @param config Native `loadModel` options plus `autoLoad` (default `true`).
 *   Compared by value, so passing an inline object literal is safe.
 * @returns Load/generation state and actions (`UseModelResult`).
 */
function useModel(pathOrUrl, config) {
    const modelRef = (0, react_1.useRef)(null);
    // False before mount and after unmount; async work checks it before setting state.
    const mountedRef = (0, react_1.useRef)(false);
    // Mirrors modelRef as state so consumers re-render once the instance exists.
    const [model, setModel] = (0, react_1.useState)(null);
    const [isReady, setIsReady] = (0, react_1.useState)(false);
    const [isGenerating, setIsGenerating] = (0, react_1.useState)(false);
    const [downloadProgress, setDownloadProgress] = (0, react_1.useState)(0);
    const [error, setError] = (0, react_1.useState)(null);
    // Extract autoLoad (default true)
    const autoLoad = config?.autoLoad ?? true;
    // An inline `config` literal has a new identity on every render. Depending on the
    // object itself made `load` (and therefore the autoLoad effect) change on every
    // render, reloading the model in a loop. Key on the serialized contents instead.
    const configKey = JSON.stringify(config ?? {});
    const stableConfig = (0, react_1.useMemo)(() => config, [configKey]);
    // Create the engine on mount and close it on unmount.
    (0, react_1.useEffect)(() => {
        const instance = (0, modelFactory_1.createLLM)();
        modelRef.current = instance;
        mountedRef.current = true;
        setModel(instance);
        return () => {
            mountedRef.current = false;
            modelRef.current = null;
            try {
                instance.close();
            }
            catch (e) {
                console.warn("Failed to close model", e);
            }
        };
    }, []);
    const load = (0, react_1.useCallback)(async () => {
        setIsReady(false);
        setError(null);
        setDownloadProgress(0);
        const instance = modelRef.current;
        if (!instance) {
            return;
        }
        // Ignore results that arrive after unmount or after the instance was replaced.
        const isCurrent = () => mountedRef.current && modelRef.current === instance;
        try {
            let modelPath = pathOrUrl;
            // Handle URL download here (not in loadModel) to capture progress.
            if (pathOrUrl.startsWith("http://") || pathOrUrl.startsWith("https://")) {
                modelPath = await instance.downloadModel(pathOrUrl, fileNameFromUrl(pathOrUrl), (progress) => {
                    if (isCurrent()) {
                        setDownloadProgress(progress);
                    }
                });
            }
            // The native loadModel does not know about the hook-only `autoLoad` flag.
            const { autoLoad: _hookOnly, ...nativeConfig } = stableConfig ?? {};
            await instance.loadModel(modelPath, nativeConfig);
            if (isCurrent()) {
                setIsReady(true);
            }
        }
        catch (e) {
            if (isCurrent()) {
                setError(errorMessage(e, "Failed to load model"));
            }
            console.error(e);
        }
    }, [pathOrUrl, stableConfig]);
    (0, react_1.useEffect)(() => {
        if (autoLoad) {
            // load() catches its own errors, so the promise never rejects.
            void load();
        }
    }, [autoLoad, load]);
    const generate = (0, react_1.useCallback)(async (prompt) => {
        const instance = modelRef.current;
        if (!instance || !isReady) {
            throw new Error("Model not ready");
        }
        setIsGenerating(true);
        try {
            // `return await` (not a bare `return`) keeps `finally` from resetting
            // isGenerating before the last token arrives, and lets `catch` see failures.
            return await new Promise((resolve, reject) => {
                let fullResponse = "";
                try {
                    instance.sendMessageAsync(prompt, (token, done) => {
                        fullResponse += token;
                        if (done) {
                            resolve(fullResponse);
                        }
                    });
                }
                catch (e) {
                    reject(e);
                }
            });
        }
        catch (e) {
            if (mountedRef.current) {
                setError(errorMessage(e, "Generation failed"));
            }
            throw e;
        }
        finally {
            if (mountedRef.current) {
                setIsGenerating(false);
            }
        }
    }, [isReady]);
    const reset = (0, react_1.useCallback)(() => {
        if (modelRef.current) {
            modelRef.current.resetConversation();
        }
    }, []);
    const deleteModel = (0, react_1.useCallback)(async (fileName) => {
        const instance = modelRef.current;
        if (!instance) {
            return;
        }
        await instance.deleteModel(fileName);
        if (mountedRef.current) {
            setIsReady(false);
            setDownloadProgress(0);
        }
    }, []);
    return {
        model,
        isReady,
        isGenerating,
        downloadProgress,
        error,
        generate,
        reset,
        deleteModel,
        load,
    };
}
package/lib/index.d.ts CHANGED
@@ -1,5 +1,8 @@
1
- import type { LiteRTLM, Backend } from "./specs/LiteRTLM.nitro";
1
+ import type { Backend } from "./specs/LiteRTLM.nitro";
2
2
  export type { LiteRTLM, LLMConfig, Message, Backend, Role, GenerationStats, } from "./specs/LiteRTLM.nitro";
3
+ export type { ChatMessage } from "./templates";
4
+ export { applyGemmaTemplate, applyPhiTemplate, applyLlamaTemplate, } from "./templates";
5
+ export * from "./hooks";
3
6
  /**
4
7
  * Creates a new LiteRT-LM inference engine instance.
5
8
  *
@@ -33,7 +36,7 @@ export type { LiteRTLM, LLMConfig, Message, Backend, Role, GenerationStats, } fr
33
36
  * llm.close();
34
37
  * ```
35
38
  */
36
- export declare function createLLM(): LiteRTLM;
39
+ export { createLLM } from "./modelFactory";
37
40
  /**
38
41
  * Pre-defined model identifiers for common models.
39
42
  * Use with model download utilities or as reference.
@@ -80,3 +83,25 @@ export declare function getRecommendedBackend(): Backend;
80
83
  * ```
81
84
  */
82
85
  export declare function checkBackendSupport(backend: Backend): string | undefined;
86
/**
 * Report whether multimodal input (images/audio) can be used on the current
 * platform.
 *
 * @returns A human-readable reason when multimodal input is unavailable, or
 *          `undefined` when it is supported.
 *
 * @example
 * ```typescript
 * const reason = checkMultimodalSupport();
 * if (reason !== undefined) {
 *   console.warn(reason); // fall back to text-only prompts
 * } else {
 *   llm.sendMessageWithImage('Describe this', imagePath);
 * }
 * ```
 */
export declare function checkMultimodalSupport(): undefined | string;
104
/**
 * Remote URL of the Gemma 3n E2B instruction-tuned INT4 model (`.litertlm`).
 * URL sources passed to `loadModel`/`useModel` are downloaded before loading.
 */
export declare const GEMMA_3N_E2B_IT_INT4: "https://litert.dev/gemma-3n-E2B-it-int4.litertlm";
package/lib/index.js CHANGED
@@ -1,11 +1,29 @@
1
1
  "use strict";
2
// TypeScript CommonJS interop helpers (emitted for `export * from`).
// Reuse an existing global/module-scoped helper when one is already defined.
var __createBinding = (this && this.__createBinding) || (Object.create
    ? function (target, source, key, alias) {
        // Re-export `source[key]` on `target` as `alias` (defaults to `key`).
        const exportedName = alias === undefined ? key : alias;
        let descriptor = Object.getOwnPropertyDescriptor(source, key);
        // Plain data properties (and accessors on non-ES modules) become a live getter
        // so later updates to `source` remain visible through `target`.
        const needsLiveGetter = !descriptor ||
            ("get" in descriptor ? !source.__esModule : descriptor.writable || descriptor.configurable);
        if (needsLiveGetter) {
            descriptor = { enumerable: true, get: function () { return source[key]; } };
        }
        Object.defineProperty(target, exportedName, descriptor);
    }
    : function (target, source, key, alias) {
        // Legacy engines without Object.create: copy the current value once.
        const exportedName = alias === undefined ? key : alias;
        target[exportedName] = source[key];
    });
// Copy every named export of `source` onto `target`, skipping `default` and any
// name `target` already defines itself.
var __exportStar = (this && this.__exportStar) || function (source, target) {
    for (var prop in source) {
        if (prop === "default") continue;
        if (Object.prototype.hasOwnProperty.call(target, prop)) continue;
        __createBinding(target, source, prop);
    }
};
2
16
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Models = void 0;
4
- exports.createLLM = createLLM;
17
+ exports.GEMMA_3N_E2B_IT_INT4 = exports.Models = exports.createLLM = exports.applyLlamaTemplate = exports.applyPhiTemplate = exports.applyGemmaTemplate = void 0;
5
18
  exports.getRecommendedBackend = getRecommendedBackend;
6
19
  exports.checkBackendSupport = checkBackendSupport;
7
- const react_native_nitro_modules_1 = require("react-native-nitro-modules");
20
+ exports.checkMultimodalSupport = checkMultimodalSupport;
8
21
  const react_native_1 = require("react-native");
22
+ var templates_1 = require("./templates");
23
+ Object.defineProperty(exports, "applyGemmaTemplate", { enumerable: true, get: function () { return templates_1.applyGemmaTemplate; } });
24
+ Object.defineProperty(exports, "applyPhiTemplate", { enumerable: true, get: function () { return templates_1.applyPhiTemplate; } });
25
+ Object.defineProperty(exports, "applyLlamaTemplate", { enumerable: true, get: function () { return templates_1.applyLlamaTemplate; } });
26
+ __exportStar(require("./hooks"), exports);
9
27
  /**
10
28
  * Creates a new LiteRT-LM inference engine instance.
11
29
  *
@@ -39,9 +57,8 @@ const react_native_1 = require("react-native");
39
57
  * llm.close();
40
58
  * ```
41
59
  */
42
- function createLLM() {
43
- return react_native_nitro_modules_1.NitroModules.createHybridObject("LiteRTLM");
44
- }
60
+ var modelFactory_1 = require("./modelFactory");
61
+ Object.defineProperty(exports, "createLLM", { enumerable: true, get: function () { return modelFactory_1.createLLM; } });
45
62
  /**
46
63
  * Pre-defined model identifiers for common models.
47
64
  * Use with model download utilities or as reference.
@@ -104,3 +121,30 @@ function checkBackendSupport(backend) {
104
121
  }
105
122
  return undefined;
106
123
  }
124
/**
 * Report whether multimodal input (images/audio) can be used on the current
 * platform.
 *
 * @returns A human-readable reason when multimodal input is unavailable (iOS
 *          at the moment), otherwise `undefined`.
 *
 * @example
 * ```typescript
 * const reason = checkMultimodalSupport();
 * if (reason !== undefined) {
 *   console.warn(reason); // fall back to text-only prompts
 * } else {
 *   llm.sendMessageWithImage('Describe this', imagePath);
 * }
 * ```
 */
function checkMultimodalSupport() {
    const runningOnIos = react_native_1.Platform.OS === "ios";
    return runningOnIos
        ? "Multimodal (image/audio) is not yet supported on iOS. LiteRT-LM iOS SDK is pending."
        : undefined;
}
147
/**
 * Remote URL of the Gemma 3n E2B instruction-tuned INT4 model (`.litertlm`).
 * URL sources passed to `loadModel`/`useModel` are downloaded before loading.
 */
exports.GEMMA_3N_E2B_IT_INT4 =
    "https://litert.dev/gemma-3n-E2B-it-int4.litertlm";
package/lib/modelFactory.d.ts ADDED
@@ -0,0 +1,5 @@
1
+ import { LiteRTLM } from "./specs/LiteRTLM.nitro";
2
/**
 * Create a new LiteRT-LM inference engine.
 *
 * The returned object's `loadModel` accepts either a local file path or an
 * http(s) URL (downloaded first). Release the engine with `close()` when done.
 *
 * @returns A fresh `LiteRTLM` instance.
 */
export declare function createLLM(): LiteRTLM;
package/lib/modelFactory.js ADDED
@@ -0,0 +1,42 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.createLLM = createLLM;
4
+ const react_native_nitro_modules_1 = require("react-native-nitro-modules");
5
/**
 * Extract the file name a downloaded model is stored under: the last path
 * segment of `url`, ignoring any query string or fragment, so that
 * "https://host/m.litertlm?download=true" is saved as "m.litertlm".
 * Returns "" when the URL has no usable final segment.
 */
function modelFileNameFromUrl(url) {
    const withoutQuery = url.split(/[?#]/, 1)[0];
    return withoutQuery.split("/").pop() ?? "";
}
/**
 * Creates a new LiteRT-LM inference engine instance.
 *
 * The returned object forwards each method to the native hybrid object (bound
 * so `this` stays the native instance). The exception is `loadModel`, which
 * also accepts an http(s) URL: the file is downloaded via `downloadModel`
 * first, and the returned local path is then loaded.
 *
 * @returns A `LiteRTLM` facade over a newly created native instance.
 */
function createLLM() {
    const native = react_native_nitro_modules_1.NitroModules.createHybridObject("LiteRTLM");
    return {
        // NOTE(review): a spread copies only own enumerable properties; the
        // methods below are bound explicitly rather than relying on it.
        ...native,
        loadModel: async (pathOrUrl, config) => {
            let modelPath = pathOrUrl;
            if (pathOrUrl.startsWith("http://") || pathOrUrl.startsWith("https://")) {
                // Previously the raw last segment was used, so "?download=true" style
                // query strings leaked into the on-device file name.
                const fileName = modelFileNameFromUrl(pathOrUrl);
                if (!fileName) {
                    throw new Error(`Invalid model URL: ${pathOrUrl}`);
                }
                console.log(`Checking model at ${pathOrUrl}...`);
                modelPath = await native.downloadModel(pathOrUrl, fileName, (progress) => {
                    console.log(`Download progress: ${progress}`);
                });
                console.log(`Model downloaded to: ${modelPath}`);
            }
            return native.loadModel(modelPath, config);
        },
        // Bind valid methods to native instance
        sendMessage: native.sendMessage.bind(native),
        sendMessageAsync: native.sendMessageAsync.bind(native),
        sendMessageWithImage: native.sendMessageWithImage.bind(native),
        sendMessageWithAudio: native.sendMessageWithAudio.bind(native),
        getHistory: native.getHistory.bind(native),
        resetConversation: native.resetConversation.bind(native),
        isReady: native.isReady.bind(native),
        getStats: native.getStats.bind(native),
        close: native.close.bind(native),
        downloadModel: native.downloadModel.bind(native),
        deleteModel: native.deleteModel.bind(native),
    };
}
package/lib/specs/LiteRTLM.nitro.d.ts CHANGED
@@ -18,6 +18,12 @@ export type Role = "user" | "model" | "system";
18
18
  * Configuration options for loading an LLM.
19
19
  */
20
20
  export interface LLMConfig {
21
+ /**
22
+ * System prompt to set the model's behavior.
23
+ * This is prepended to the conversation to guide model responses.
24
+ * @example "You are a helpful coding assistant."
25
+ */
26
+ systemPrompt?: string;
21
27
  /**
22
28
  * Primary compute backend for text generation.
23
29
  * - 'cpu': CPU inference (slower but always available)
@@ -125,6 +131,19 @@ export interface LiteRTLM extends HybridObject<{
125
131
  * @returns The model's response text.
126
132
  */
127
133
  sendMessageWithImage(message: string, imagePath: string): Promise<string>;
134
+ /**
135
+ * Download a model file from a URL.
136
+ * @param url URL to download from.
137
+ * @param fileName Filename to save as (in app's files directory).
138
+ * @param onProgress Callback for download progress (0.0 - 1.0).
139
+ * @returns Absolute path to the downloaded file.
140
+ */
141
+ downloadModel(url: string, fileName: string, onProgress?: (progress: number) => void): Promise<string>;
142
+ /**
143
+ * Delete a downloaded model file.
144
+ * @param fileName Filename to delete (in app's files directory).
145
+ */
146
+ deleteModel(fileName: string): Promise<void>;
128
147
  /**
129
148
  * Send a text message with audio (multimodal).
130
149
  * @param message User message text.
package/lib/templates.d.ts ADDED
@@ -0,0 +1,51 @@
1
+ /**
2
+ * Prompt template utilities for different LLM families.
3
+ *
4
+ * LiteRT-LM's Conversation API may handle templates internally for some models,
5
+ * but these utilities give developers explicit control for custom workflows
6
+ * or when using models with different template formats.
7
+ *
8
+ * @example
9
+ * ```typescript
10
+ * import { applyGemmaTemplate, ChatMessage } from 'react-native-litert-lm';
11
+ *
12
+ * const history: ChatMessage[] = [
13
+ * { role: 'user', content: 'What is React Native?' },
14
+ * { role: 'model', content: 'React Native is a framework for building...' },
15
+ * { role: 'user', content: 'How do I use hooks?' }
16
+ * ];
17
+ *
18
+ * const prompt = applyGemmaTemplate(history, 'You are a helpful coding assistant.');
19
+ * ```
20
+ */
21
/**
 * One turn of a conversation, as consumed by the `apply*Template` helpers.
 */
export type ChatMessage = {
    /** Author of the turn; `model` is the assistant side. */
    role: "user" | "model" | "system";
    /** Raw turn text, interpolated into the template without escaping. */
    content: string;
};
28
/**
 * Format a conversation with the Gemma chat template (Gemma 2, 3 and 3n).
 *
 * @param history Previous turns, oldest first.
 * @param systemPrompt Optional system instructions.
 * @returns The prompt string, ending with an open `model` turn.
 */
export declare function applyGemmaTemplate(
    history: ChatMessage[],
    systemPrompt?: string,
): string;
36
/**
 * Format a conversation with the Phi chat template (`<|role|>` … `<|end|>`).
 *
 * NOTE(review): this is the Phi-3 / Phi-3.5 layout; some Phi-4 checkpoints use a
 * different chat format — verify against the model's tokenizer config.
 *
 * @param history Previous turns, oldest first (`model` is emitted as `assistant`).
 * @param systemPrompt Optional system instructions.
 * @returns The prompt string, ending with an open `assistant` turn.
 */
export declare function applyPhiTemplate(
    history: ChatMessage[],
    systemPrompt?: string,
): string;
44
/**
 * Format a conversation with the Llama 3 chat template (header/`<|eot_id|>` layout).
 *
 * @param history Previous turns, oldest first (`model` is emitted as `assistant`).
 * @param systemPrompt Optional system instructions.
 * @returns The prompt string, starting with `<|begin_of_text|>` and ending
 *          with an open `assistant` header.
 */
export declare function applyLlamaTemplate(
    history: ChatMessage[],
    systemPrompt?: string,
): string;
package/lib/templates.js ADDED
@@ -0,0 +1,81 @@
1
+ "use strict";
2
+ /**
3
+ * Prompt template utilities for different LLM families.
4
+ *
5
+ * LiteRT-LM's Conversation API may handle templates internally for some models,
6
+ * but these utilities give developers explicit control for custom workflows
7
+ * or when using models with different template formats.
8
+ *
9
+ * @example
10
+ * ```typescript
11
+ * import { applyGemmaTemplate, ChatMessage } from 'react-native-litert-lm';
12
+ *
13
+ * const history: ChatMessage[] = [
14
+ * { role: 'user', content: 'What is React Native?' },
15
+ * { role: 'model', content: 'React Native is a framework for building...' },
16
+ * { role: 'user', content: 'How do I use hooks?' }
17
+ * ];
18
+ *
19
+ * const prompt = applyGemmaTemplate(history, 'You are a helpful coding assistant.');
20
+ * ```
21
+ */
22
+ Object.defineProperty(exports, "__esModule", { value: true });
23
+ exports.applyGemmaTemplate = applyGemmaTemplate;
24
+ exports.applyPhiTemplate = applyPhiTemplate;
25
+ exports.applyLlamaTemplate = applyLlamaTemplate;
26
/**
 * Apply Gemma chat template (Gemma 2, Gemma 3, Gemma 3n).
 *
 * Gemma's instruction-tuned models only use the `user` and `model` turn roles;
 * there is no system turn. System instructions (`systemPrompt` and any
 * `system`-role entries in `history`) are therefore prepended, separated by a
 * blank line, to the next user turn. If no user turn follows, they are emitted
 * as a user turn of their own, so no instructions are dropped.
 *
 * @param history Array of previous messages
 * @param systemPrompt Optional system prompt
 * @returns Formatted prompt string, ending with an open `model` turn
 */
function applyGemmaTemplate(history, systemPrompt) {
    const turn = (role, content) => `<start_of_turn>${role}\n${content}<end_of_turn>\n`;
    // System text waiting to be folded into the next user turn.
    let pendingSystem = systemPrompt ? [systemPrompt] : [];
    let result = "";
    for (const m of history) {
        if (m.role === "system") {
            if (m.content) {
                pendingSystem.push(m.content);
            }
            continue;
        }
        if (m.role === "user" && pendingSystem.length > 0) {
            result += turn("user", `${pendingSystem.join("\n\n")}\n\n${m.content}`);
            pendingSystem = [];
            continue;
        }
        result += turn(m.role, m.content);
    }
    if (pendingSystem.length > 0) {
        result += turn("user", pendingSystem.join("\n\n"));
    }
    result += "<start_of_turn>model\n";
    return result;
}
44
/**
 * Format a conversation with the Phi chat template: each turn is
 * `<|role|>\n…<|end|>\n`, and `model` turns are written as `assistant`.
 *
 * NOTE(review): this is the Phi-3 / Phi-3.5 layout; some Phi-4 checkpoints use a
 * different chat format — verify against the model's tokenizer config.
 *
 * @param history Array of previous messages
 * @param systemPrompt Optional system prompt (skipped when empty)
 * @returns Formatted prompt string, ending with an open `assistant` turn
 */
function applyPhiTemplate(history, systemPrompt) {
    const block = (role, text) => `<|${role}|>\n${text}<|end|>\n`;
    const parts = systemPrompt ? [block("system", systemPrompt)] : [];
    for (const { role, content } of history) {
        parts.push(block(role === "model" ? "assistant" : role, content));
    }
    parts.push("<|assistant|>\n");
    return parts.join("");
}
63
/**
 * Format a conversation with the Llama 3 chat template: the prompt starts with
 * `<|begin_of_text|>`, and each turn is a role header followed by the content and
 * `<|eot_id|>`. `model` turns are written as `assistant`.
 *
 * @param history Array of previous messages
 * @param systemPrompt Optional system prompt (skipped when empty)
 * @returns Formatted prompt string, ending with an open `assistant` header
 */
function applyLlamaTemplate(history, systemPrompt) {
    const header = (role) => `<|start_header_id|>${role}<|end_header_id|>\n\n`;
    const segments = ["<|begin_of_text|>"];
    if (systemPrompt) {
        segments.push(`${header("system")}${systemPrompt}<|eot_id|>`);
    }
    for (const entry of history) {
        const speaker = entry.role === "model" ? "assistant" : entry.role;
        segments.push(`${header(speaker)}${entry.content}<|eot_id|>`);
    }
    segments.push(header("assistant"));
    return segments.join("");
}
package/nitrogen/generated/android/LiteRTLMOnLoad.cpp CHANGED
@@ -16,6 +16,7 @@
16
16
  #include <NitroModules/HybridObjectRegistry.hpp>
17
17
 
18
18
  #include "JHybridLiteRTLMSpec.hpp"
19
+ #include "JFunc_void_double.hpp"
19
20
  #include "JFunc_void_std__string_bool.hpp"
20
21
  #include <NitroModules/DefaultConstructableObject.hpp>
21
22
 
@@ -29,6 +30,7 @@ int initialize(JavaVM* vm) {
29
30
  return facebook::jni::initialize(vm, [] {
30
31
  // Register native JNI methods
31
32
  margelo::nitro::litertlm::JHybridLiteRTLMSpec::registerNatives();
33
+ margelo::nitro::litertlm::JFunc_void_double_cxx::registerNatives();
32
34
  margelo::nitro::litertlm::JFunc_void_std__string_bool_cxx::registerNatives();
33
35
 
34
36
  // Register Nitro Hybrid Objects
package/nitrogen/generated/android/c++/JFunc_void_double.hpp ADDED
@@ -0,0 +1,75 @@
1
+ ///
2
+ /// JFunc_void_double.hpp
3
+ /// This file was generated by nitrogen. DO NOT MODIFY THIS FILE.
4
+ /// https://github.com/mrousavy/nitro
5
+ /// Copyright © Marc Rousavy @ Margelo
6
+ ///
7
+
8
+ #pragma once
9
+
10
+ #include <fbjni/fbjni.h>
11
+ #include <functional>
12
+
13
+ #include <functional>
14
+ #include <NitroModules/JNICallable.hpp>
15
+
16
+ namespace margelo::nitro::litertlm {
17
+
18
+ using namespace facebook;
19
+
20
/**
 * Represents the Java/Kotlin callback `(progress: Double) -> Unit`.
 * This can be passed around between C++ and Java/Kotlin.
 *
 * NOTE: this file is generated by nitrogen; edits (including comments) are lost
 * on regeneration, so behavioral changes belong in the TypeScript spec.
 */
struct JFunc_void_double: public jni::JavaClass<JFunc_void_double> {
public:
  // JNI type descriptor of the Kotlin `Func_void_double` type this wrapper binds to.
  static auto constexpr kJavaDescriptor = "Lcom/margelo/nitro/dev/litert/litertlm/Func_void_double;";

public:
  /**
   * Invokes the function this `JFunc_void_double` instance holds through JNI.
   */
  void invoke(double progress) const {
    // Function-local static: the `invoke(double)` method lookup runs only once.
    static const auto method = javaClassStatic()->getMethod<void(double /* progress */)>("invoke");
    method(self(), progress);
  }
};
37
+
38
/**
 * An implementation of Func_void_double that is backed by a C++ implementation (using `std::function<...>`)
 *
 * Used to hand C++ callbacks (e.g. the optional `onProgress` of `downloadModel`)
 * to Kotlin; Kotlin calls back into C++ through the native `invoke_cxx` method.
 */
class JFunc_void_double_cxx final: public jni::HybridClass<JFunc_void_double_cxx, JFunc_void_double> {
public:
  // Wraps `func` in a new Java-side Func_void_double_cxx object that owns a copy of it.
  static jni::local_ref<JFunc_void_double::javaobject> fromCpp(const std::function<void(double /* progress */)>& func) {
    return JFunc_void_double_cxx::newObjectCxxArgs(func);
  }

public:
  /**
   * Invokes the C++ `std::function<...>` this `JFunc_void_double_cxx` instance holds.
   */
  void invoke_cxx(double progress) {
    _func(progress);
  }

public:
  [[nodiscard]]
  inline const std::function<void(double /* progress */)>& getFunction() const {
    return _func;
  }

public:
  static auto constexpr kJavaDescriptor = "Lcom/margelo/nitro/dev/litert/litertlm/Func_void_double_cxx;";
  // Binds the Java native method `invoke_cxx` to invoke_cxx above; must run at
  // JNI load time (LiteRTLMOnLoad.cpp calls it inside `initialize`).
  static void registerNatives() {
    registerHybrid({makeNativeMethod("invoke_cxx", JFunc_void_double_cxx::invoke_cxx)});
  }

private:
  // Private: instances are created through fromCpp/newObjectCxxArgs only.
  explicit JFunc_void_double_cxx(const std::function<void(double /* progress */)>& func): _func(func) { }

private:
  // Lets the fbjni HybridClass base construct this class via the private constructor.
  friend HybridBase;
  std::function<void(double /* progress */)> _func;
};
74
+
75
+ } // namespace margelo::nitro::litertlm
package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.cpp CHANGED
@@ -35,8 +35,9 @@ namespace margelo::nitro::litertlm { enum class Backend; }
35
35
  #include "Backend.hpp"
36
36
  #include "JBackend.hpp"
37
37
  #include <functional>
38
- #include "JFunc_void_std__string_bool.hpp"
38
+ #include "JFunc_void_double.hpp"
39
39
  #include <NitroModules/JNICallable.hpp>
40
+ #include "JFunc_void_std__string_bool.hpp"
40
41
 
41
42
  namespace margelo::nitro::litertlm {
42
43
 
@@ -124,6 +125,37 @@ namespace margelo::nitro::litertlm {
124
125
  return __promise;
125
126
  }();
126
127
  }
128
// Generated bridge: forwards `downloadModel` to the Kotlin `downloadModel_cxx`
// method and adapts the returned Java promise into a C++ Promise<std::string>.
std::shared_ptr<Promise<std::string>> JHybridLiteRTLMSpec::downloadModel(const std::string& url, const std::string& fileName, const std::optional<std::function<void(double /* progress */)>>& onProgress) {
  // Cached (function-local static) lookup of downloadModel_cxx(String, String, Func_void_double).
  static const auto method = javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* url */, jni::alias_ref<jni::JString> /* fileName */, jni::alias_ref<JFunc_void_double::javaobject> /* onProgress */)>("downloadModel_cxx");
  // A missing onProgress is passed to Kotlin as null; a present one is wrapped via JFunc_void_double_cxx.
  auto __result = method(_javaPart, jni::make_jstring(url), jni::make_jstring(fileName), onProgress.has_value() ? JFunc_void_double_cxx::fromCpp(onProgress.value()) : nullptr);
  return [&]() {
    auto __promise = Promise<std::string>::create();
    // Resolution value arrives boxed; it is cast to JString and converted to std::string.
    __result->cthis()->addOnResolvedListener([=](const jni::alias_ref<jni::JObject>& __boxedResult) {
      auto __result = jni::static_ref_cast<jni::JString>(__boxedResult);
      __promise->resolve(__result->toStdString());
    });
    // Java throwables are rethrown to C++ callers as jni::JniException.
    __result->cthis()->addOnRejectedListener([=](const jni::alias_ref<jni::JThrowable>& __throwable) {
      jni::JniException __jniError(__throwable);
      __promise->reject(std::make_exception_ptr(__jniError));
    });
    return __promise;
  }();
}
144
// Generated bridge: forwards `deleteModel` to the Kotlin `deleteModel` method and
// adapts the returned Java promise into a C++ Promise<void>.
std::shared_ptr<Promise<void>> JHybridLiteRTLMSpec::deleteModel(const std::string& fileName) {
  // Cached (function-local static) lookup of deleteModel(String) -> Promise.
  static const auto method = javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* fileName */)>("deleteModel");
  auto __result = method(_javaPart, jni::make_jstring(fileName));
  return [&]() {
    auto __promise = Promise<void>::create();
    // The resolved value (Kotlin Unit) is ignored.
    __result->cthis()->addOnResolvedListener([=](const jni::alias_ref<jni::JObject>& /* unit */) {
      __promise->resolve();
    });
    // Java throwables are rethrown to C++ callers as jni::JniException.
    __result->cthis()->addOnRejectedListener([=](const jni::alias_ref<jni::JThrowable>& __throwable) {
      jni::JniException __jniError(__throwable);
      __promise->reject(std::make_exception_ptr(__jniError));
    });
    return __promise;
  }();
}
127
159
  std::shared_ptr<Promise<std::string>> JHybridLiteRTLMSpec::sendMessageWithAudio(const std::string& message, const std::string& audioPath) {
128
160
  static const auto method = javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<jni::JString> /* audioPath */)>("sendMessageWithAudio");
129
161
  auto __result = method(_javaPart, jni::make_jstring(message), jni::make_jstring(audioPath));
package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.hpp CHANGED
@@ -58,6 +58,8 @@ namespace margelo::nitro::litertlm {
58
58
  std::shared_ptr<Promise<void>> loadModel(const std::string& modelPath, const std::optional<LLMConfig>& config) override;
59
59
  std::shared_ptr<Promise<std::string>> sendMessage(const std::string& message) override;
60
60
  std::shared_ptr<Promise<std::string>> sendMessageWithImage(const std::string& message, const std::string& imagePath) override;
61
+ std::shared_ptr<Promise<std::string>> downloadModel(const std::string& url, const std::string& fileName, const std::optional<std::function<void(double /* progress */)>>& onProgress) override;
62
+ std::shared_ptr<Promise<void>> deleteModel(const std::string& fileName) override;
61
63
  std::shared_ptr<Promise<std::string>> sendMessageWithAudio(const std::string& message, const std::string& audioPath) override;
62
64
  void sendMessageAsync(const std::string& message, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) override;
63
65
  std::vector<Message> getHistory() override;
package/nitrogen/generated/android/c++/JLLMConfig.hpp CHANGED
@@ -13,6 +13,7 @@
13
13
  #include "Backend.hpp"
14
14
  #include "JBackend.hpp"
15
15
  #include <optional>
16
+ #include <string>
16
17
 
17
18
  namespace margelo::nitro::litertlm {
18
19
 
@@ -33,6 +34,8 @@ namespace margelo::nitro::litertlm {
33
34
  [[nodiscard]]
34
35
  LLMConfig toCpp() const {
35
36
  static const auto clazz = javaClassStatic();
37
+ static const auto fieldSystemPrompt = clazz->getField<jni::JString>("systemPrompt");
38
+ jni::local_ref<jni::JString> systemPrompt = this->getFieldValue(fieldSystemPrompt);
36
39
  static const auto fieldBackend = clazz->getField<JBackend>("backend");
37
40
  jni::local_ref<JBackend> backend = this->getFieldValue(fieldBackend);
38
41
  static const auto fieldMaxTokens = clazz->getField<jni::JDouble>("maxTokens");
@@ -44,6 +47,7 @@ namespace margelo::nitro::litertlm {
44
47
  static const auto fieldTopP = clazz->getField<jni::JDouble>("topP");
45
48
  jni::local_ref<jni::JDouble> topP = this->getFieldValue(fieldTopP);
46
49
  return LLMConfig(
50
+ systemPrompt != nullptr ? std::make_optional(systemPrompt->toStdString()) : std::nullopt,
47
51
  backend != nullptr ? std::make_optional(backend->toCpp()) : std::nullopt,
48
52
  maxTokens != nullptr ? std::make_optional(maxTokens->value()) : std::nullopt,
49
53
  temperature != nullptr ? std::make_optional(temperature->value()) : std::nullopt,
@@ -58,11 +62,12 @@ namespace margelo::nitro::litertlm {
58
62
  */
59
63
  [[maybe_unused]]
60
64
  static jni::local_ref<JLLMConfig::javaobject> fromCpp(const LLMConfig& value) {
61
- using JSignature = JLLMConfig(jni::alias_ref<JBackend>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>);
65
+ using JSignature = JLLMConfig(jni::alias_ref<jni::JString>, jni::alias_ref<JBackend>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>, jni::alias_ref<jni::JDouble>);
62
66
  static const auto clazz = javaClassStatic();
63
67
  static const auto create = clazz->getStaticMethod<JSignature>("fromCpp");
64
68
  return create(
65
69
  clazz,
70
+ value.systemPrompt.has_value() ? jni::make_jstring(value.systemPrompt.value()) : nullptr,
66
71
  value.backend.has_value() ? JBackend::fromCpp(value.backend.value()) : nullptr,
67
72
  value.maxTokens.has_value() ? jni::JDouble::valueOf(value.maxTokens.value()) : nullptr,
68
73
  value.temperature.has_value() ? jni::JDouble::valueOf(value.temperature.value()) : nullptr,