react-native-litert-lm 0.2.2 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +269 -186
- package/android/build.gradle +1 -1
- package/android/src/main/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLM.kt +93 -37
- package/app.plugin.js +33 -0
- package/cpp/HybridLiteRTLM.cpp +604 -450
- package/cpp/HybridLiteRTLM.hpp +54 -23
- package/cpp/IOSDownloadHelper.h +24 -0
- package/cpp/cpp-adapter.cpp +2 -2
- package/cpp/include/litert_lm_engine.h +509 -0
- package/ios/IOSDownloadHelper.mm +129 -0
- package/ios/LiteRTLMAutolinking.mm +30 -0
- package/lib/hooks.d.ts +9 -4
- package/lib/hooks.js +34 -20
- package/lib/index.d.ts +1 -0
- package/lib/index.js +2 -5
- package/lib/memoryTracker.d.ts +1 -1
- package/lib/memoryTracker.js +1 -1
- package/lib/modelFactory.d.ts +11 -5
- package/lib/modelFactory.js +9 -4
- package/nitrogen/generated/android/LiteRTLMOnLoad.cpp +11 -4
- package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.cpp +31 -37
- package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.hpp +19 -22
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMSpec.kt +15 -18
- package/package.json +12 -5
- package/react-native-litert-lm.podspec +20 -7
- package/scripts/build-ios-engine.sh +302 -0
- package/scripts/download-ios-frameworks.sh +72 -0
- package/scripts/postinstall.js +116 -0
- package/scripts/stubs/cxx_bridge_stubs.cc +224 -0
- package/scripts/stubs/gemma_model_constraint_provider.cc +46 -0
- package/scripts/stubs/llguidance_stubs.c +101 -0
- package/src/hooks.ts +62 -39
- package/src/index.ts +4 -7
- package/src/memoryTracker.ts +1 -1
- package/src/modelFactory.ts +30 -5
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
#import <Foundation/Foundation.h>
|
|
2
|
+
#include "../cpp/IOSDownloadHelper.h"
|
|
3
|
+
#include <stdexcept>
|
|
4
|
+
|
|
5
|
+
namespace litert_lm {
|
|
6
|
+
|
|
7
|
+
/// Converts an NSError into a std::string suitable for an exception message.
/// Falls back to a fixed text when the error (or its description) is nil so
/// that std::string is never constructed from a NULL pointer.
static std::string describeError(NSError* error) {
  const char* text = [[error localizedDescription] UTF8String];
  return text != nullptr ? std::string(text) : std::string("unknown error");
}

/// Downloads a model file into `<Caches>/litert_models/<fileName>` and returns
/// the absolute path of the stored file.
///
/// The call blocks the calling thread until the transfer has finished, so it
/// must not be invoked from a thread that is not allowed to block.
///
/// @param url        Remote location of the model.
/// @param fileName   Name of the file inside the models directory. Must be a
///                   single path component (no "/", not "." or "..").
/// @param onProgress Optional callback receiving a fraction in [0, 1]. It is
///                   polled roughly every 500 ms while the transfer runs and
///                   is always invoked with 1.0 on success (including the
///                   "already cached" fast path).
/// @return Absolute path of the cached model file.
/// @throws std::runtime_error if the file name or URL is invalid, the models
///         directory cannot be created, the transfer fails, the server answers
///         with a non-2xx status, or the downloaded file cannot be moved.
std::string downloadModelFile(
    const std::string& url,
    const std::string& fileName,
    const std::optional<std::function<void(double)>>& onProgress) {
  @autoreleasepool {
    // stringWithUTF8String: returns nil for byte sequences that are not valid UTF-8.
    NSString* nsUrl = [NSString stringWithUTF8String:url.c_str()];
    NSString* nsFileName = [NSString stringWithUTF8String:fileName.c_str()];

    // Reject anything that is not a plain file name. This keeps the destination
    // inside the models directory (no "../" traversal, no nested paths) and
    // avoids treating the models directory itself as a cached model when the
    // name is empty.
    if (nsFileName.length == 0 ||
        [nsFileName rangeOfString:@"/"].location != NSNotFound ||
        [nsFileName isEqualToString:@"."] ||
        [nsFileName isEqualToString:@".."]) {
      throw std::runtime_error("Invalid model file name: " + fileName);
    }

    // Use Caches directory — survives app relaunch but can be
    // reclaimed by the system under storage pressure.
    NSString* cachesDir = NSSearchPathForDirectoriesInDomains(
        NSCachesDirectory, NSUserDomainMask, YES).firstObject;
    if (cachesDir == nil) {
      throw std::runtime_error("Unable to locate the Caches directory");
    }
    NSString* modelsDir = [cachesDir stringByAppendingPathComponent:@"litert_models"];

    // Create the models subdirectory. With intermediate directories enabled the
    // call also succeeds when the directory already exists, so no prior
    // existence check is needed; a failure is reported instead of being ignored.
    NSFileManager* fm = [NSFileManager defaultManager];
    NSError* directoryError = nil;
    if (![fm createDirectoryAtPath:modelsDir
        withIntermediateDirectories:YES
                         attributes:nil
                              error:&directoryError]) {
      throw std::runtime_error("Failed to create models directory: " +
                               describeError(directoryError));
    }

    NSString* destPath = [modelsDir stringByAppendingPathComponent:nsFileName];

    // If the file already exists and has content, skip download
    if ([fm fileExistsAtPath:destPath]) {
      NSDictionary* attrs = [fm attributesOfItemAtPath:destPath error:nil];
      unsigned long long fileSize = [attrs fileSize];
      if (fileSize > 0) {
        NSLog(@"[LiteRTLM] Model already cached at %@ (%llu bytes), skipping download",
              destPath, fileSize);
        if (onProgress.has_value()) {
          onProgress.value()(1.0);
        }
        return std::string([destPath UTF8String]);
      }
    }

    NSLog(@"[LiteRTLM] Downloading model from %@ to %@", nsUrl, destPath);

    // nsUrl is nil when the input was not valid UTF-8; never hand nil to NSURL.
    NSURL* requestUrl = (nsUrl != nil) ? [NSURL URLWithString:nsUrl] : nil;
    if (!requestUrl) {
      throw std::runtime_error("Invalid download URL: " + url);
    }

    // Synchronous download using NSURLSession on this background thread.
    // The completion handler runs on the session's own queue and hands its
    // result back through `downloadError` + the semaphore.
    __block NSError* downloadError = nil;
    dispatch_semaphore_t semaphore = dispatch_semaphore_create(0);

    NSURLSessionConfiguration* config = [NSURLSessionConfiguration defaultSessionConfiguration];
    config.timeoutIntervalForRequest = 30;
    config.timeoutIntervalForResource = 3600;  // 1 hour for large models

    NSURLSession* session = [NSURLSession sessionWithConfiguration:config];
    NSMutableURLRequest* request = [NSMutableURLRequest requestWithURL:requestUrl];
    [request setHTTPMethod:@"GET"];

    // Use downloadTask for proper progress tracking and disk-efficient downloads
    NSURLSessionDownloadTask* task = [session downloadTaskWithRequest:request
        completionHandler:^(NSURL* location, NSURLResponse* response, NSError* error) {
          if (error) {
            downloadError = error;
            dispatch_semaphore_signal(semaphore);
            return;
          }

          // Only a 2xx answer carries the model; any other status would
          // otherwise be stored on disk and later served from the cache.
          if ([response isKindOfClass:[NSHTTPURLResponse class]]) {
            NSInteger statusCode = [(NSHTTPURLResponse*)response statusCode];
            if (statusCode < 200 || statusCode >= 300) {
              downloadError = [NSError errorWithDomain:@"LiteRTLM"
                                                  code:statusCode
                                              userInfo:@{NSLocalizedDescriptionKey:
                  [NSString stringWithFormat:@"HTTP %ld", (long)statusCode]}];
              dispatch_semaphore_signal(semaphore);
              return;
            }
          }

          // Move the downloaded file to its destination. This has to happen
          // inside the handler, while the temporary file is still available.
          NSError* moveError = nil;
          [fm removeItemAtPath:destPath error:nil];  // Remove any partial file
          if (![fm moveItemAtURL:location
                           toURL:[NSURL fileURLWithPath:destPath]
                           error:&moveError]) {
            downloadError = moveError;
          }

          dispatch_semaphore_signal(semaphore);
        }];

    [task resume];

    // Poll for progress while waiting for completion. The expected size is
    // only used when it is known (> 0), which also avoids dividing by zero.
    while (dispatch_semaphore_wait(semaphore,
               dispatch_time(DISPATCH_TIME_NOW, 500 * NSEC_PER_MSEC)) != 0) {
      if (onProgress.has_value() && task.countOfBytesExpectedToReceive > 0) {
        double progress = (double)task.countOfBytesReceived /
                          (double)task.countOfBytesExpectedToReceive;
        onProgress.value()(progress);
      }
    }

    // Release the session before reporting the outcome (success or failure).
    [session finishTasksAndInvalidate];

    if (downloadError) {
      throw std::runtime_error("Download failed: " + describeError(downloadError));
    }

    // Final progress callback
    if (onProgress.has_value()) {
      onProgress.value()(1.0);
    }

    NSLog(@"[LiteRTLM] Model downloaded successfully to %@", destPath);
    return std::string([destPath UTF8String]);
  }
}
|
|
128
|
+
|
|
129
|
+
} // namespace litert_lm
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
///
|
|
2
|
+
/// LiteRTLMAutolinking.mm
|
|
3
|
+
/// Registers the C++ HybridLiteRTLM implementation with NitroModules on iOS.
|
|
4
|
+
///
|
|
5
|
+
/// On iOS, there's no JNI_OnLoad equivalent, so we use ObjC +load to register
|
|
6
|
+
/// the HybridObject factory before JS tries to create it.
|
|
7
|
+
///
|
|
8
|
+
|
|
9
|
+
#import <Foundation/Foundation.h>
|
|
10
|
+
#include <NitroModules/HybridObjectRegistry.hpp>
|
|
11
|
+
#include "HybridLiteRTLM.hpp"
|
|
12
|
+
|
|
13
|
+
/// Empty NSObject subclass that exists only so its +load hook can register
/// the LiteRTLM hybrid object with NitroModules.
@interface LiteRTLMAutolinking : NSObject
@end

@implementation LiteRTLMAutolinking

/// Registers a constructor for the "LiteRTLM" hybrid object. Each call of the
/// registered factory produces a fresh HybridLiteRTLM instance.
+ (void)load {
  using namespace margelo::nitro;
  using namespace margelo::nitro::litertlm;

  // Factory stored by the registry under the name "LiteRTLM".
  auto makeLiteRTLM = []() -> std::shared_ptr<HybridObject> {
    return std::make_shared<HybridLiteRTLM>();
  };

  HybridObjectRegistry::registerHybridObjectConstructor("LiteRTLM", makeLiteRTLM);
}

@end
|
package/lib/hooks.d.ts
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { LLMConfig } from "./index";
|
|
2
|
+
import type { LiteRTLMInstance } from "./modelFactory";
|
|
2
3
|
import type { MemoryTracker, MemoryTrackerSummary } from "./memoryTracker";
|
|
3
4
|
export interface UseModelConfig extends LLMConfig {
|
|
4
5
|
autoLoad?: boolean;
|
|
5
6
|
/**
|
|
6
|
-
* Enable memory tracking using native ArrayBuffers (v0.
|
|
7
|
+
* Enable memory tracking using native ArrayBuffers (v0.35+).
|
|
7
8
|
* When enabled, memory usage is tracked after each inference call
|
|
8
9
|
* using `NitroModules.createNativeArrayBuffer()` for zero-copy storage.
|
|
9
10
|
* @default false
|
|
@@ -17,14 +18,18 @@ export interface UseModelConfig extends LLMConfig {
|
|
|
17
18
|
maxMemorySnapshots?: number;
|
|
18
19
|
}
|
|
19
20
|
export interface UseModelResult {
|
|
20
|
-
model:
|
|
21
|
+
model: LiteRTLMInstance | null;
|
|
21
22
|
isReady: boolean;
|
|
22
23
|
isGenerating: boolean;
|
|
23
24
|
downloadProgress: number;
|
|
24
25
|
error: string | null;
|
|
25
26
|
generate: (prompt: string) => Promise<string>;
|
|
26
27
|
reset: () => void;
|
|
27
|
-
|
|
28
|
+
/**
|
|
29
|
+
* Delete the model file. If no fileName is provided, derives it from
|
|
30
|
+
* the URL/path passed to useModel.
|
|
31
|
+
*/
|
|
32
|
+
deleteModel: (fileName?: string) => Promise<void>;
|
|
28
33
|
load: () => Promise<void>;
|
|
29
34
|
/**
|
|
30
35
|
* Memory tracker instance (available when enableMemoryTracking is true).
|
package/lib/hooks.js
CHANGED
|
@@ -3,6 +3,12 @@ Object.defineProperty(exports, "__esModule", { value: true });
|
|
|
3
3
|
exports.useModel = useModel;
|
|
4
4
|
const react_1 = require("react");
|
|
5
5
|
const modelFactory_1 = require("./modelFactory");
|
|
6
|
+
/**
 * Returns the text after the last "/" of a URL or file path.
 * Falls back to "model.bin" when that last segment is empty
 * (for example when the input is empty or ends with "/").
 */
function extractFileName(pathOrUrl) {
    const segments = pathOrUrl.split("/");
    const lastSegment = segments[segments.length - 1];
    if (lastSegment) {
        return lastSegment;
    }
    return "model.bin";
}
|
|
6
12
|
function useModel(pathOrUrl, config) {
|
|
7
13
|
const modelRef = (0, react_1.useRef)(null);
|
|
8
14
|
const [isReady, setIsReady] = (0, react_1.useState)(false);
|
|
@@ -10,10 +16,27 @@ function useModel(pathOrUrl, config) {
|
|
|
10
16
|
const [downloadProgress, setDownloadProgress] = (0, react_1.useState)(0);
|
|
11
17
|
const [error, setError] = (0, react_1.useState)(null);
|
|
12
18
|
const [memorySummary, setMemorySummary] = (0, react_1.useState)(null);
|
|
13
|
-
//
|
|
19
|
+
// Destructure config into primitive values for stable dependency arrays.
|
|
20
|
+
// This prevents infinite re-render loops when consumers pass inline config
|
|
21
|
+
// objects (e.g. useModel(url, { backend: 'cpu' })) without useMemo.
|
|
14
22
|
const autoLoad = config?.autoLoad ?? true;
|
|
15
23
|
const enableMemoryTracking = config?.enableMemoryTracking ?? false;
|
|
16
24
|
const maxMemorySnapshots = config?.maxMemorySnapshots ?? 256;
|
|
25
|
+
const backend = config?.backend;
|
|
26
|
+
const systemPrompt = config?.systemPrompt;
|
|
27
|
+
const maxTokens = config?.maxTokens;
|
|
28
|
+
const temperature = config?.temperature;
|
|
29
|
+
const topK = config?.topK;
|
|
30
|
+
const topP = config?.topP;
|
|
31
|
+
// Build a stable config object from the destructured primitives
|
|
32
|
+
const nativeConfig = (0, react_1.useMemo)(() => ({
|
|
33
|
+
...(backend !== undefined && { backend }),
|
|
34
|
+
...(systemPrompt !== undefined && { systemPrompt }),
|
|
35
|
+
...(maxTokens !== undefined && { maxTokens }),
|
|
36
|
+
...(temperature !== undefined && { temperature }),
|
|
37
|
+
...(topK !== undefined && { topK }),
|
|
38
|
+
...(topP !== undefined && { topP }),
|
|
39
|
+
}), [backend, systemPrompt, maxTokens, temperature, topK, topP]);
|
|
17
40
|
/**
|
|
18
41
|
* Refresh memory summary from the tracker's native buffer.
|
|
19
42
|
*/
|
|
@@ -28,10 +51,8 @@ function useModel(pathOrUrl, config) {
|
|
|
28
51
|
enableMemoryTracking,
|
|
29
52
|
maxMemorySnapshots,
|
|
30
53
|
});
|
|
31
|
-
let isMounted = true;
|
|
32
54
|
// Cleanup on unmount
|
|
33
55
|
return () => {
|
|
34
|
-
isMounted = false;
|
|
35
56
|
try {
|
|
36
57
|
modelRef.current?.close();
|
|
37
58
|
}
|
|
@@ -45,21 +66,13 @@ function useModel(pathOrUrl, config) {
|
|
|
45
66
|
setError(null);
|
|
46
67
|
setDownloadProgress(0);
|
|
47
68
|
try {
|
|
48
|
-
let modelPath = pathOrUrl;
|
|
49
|
-
// Handle URL download manually to capture progress
|
|
50
|
-
if (pathOrUrl.startsWith("http://") || pathOrUrl.startsWith("https://")) {
|
|
51
|
-
const fileName = pathOrUrl.split("/").pop() || "model.bin";
|
|
52
|
-
if (modelRef.current) {
|
|
53
|
-
modelPath = await modelRef.current.downloadModel(pathOrUrl, fileName, (progress) => {
|
|
54
|
-
setDownloadProgress(progress);
|
|
55
|
-
});
|
|
56
|
-
}
|
|
57
|
-
}
|
|
58
69
|
if (modelRef.current) {
|
|
59
|
-
//
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
await modelRef.current.loadModel(
|
|
70
|
+
// Delegate URL handling + download to the factory's loadModel,
|
|
71
|
+
// passing our progress setter as the callback (eliminates
|
|
72
|
+
// duplicate download logic that was previously in this hook).
|
|
73
|
+
await modelRef.current.loadModel(pathOrUrl, nativeConfig, (progress) => {
|
|
74
|
+
setDownloadProgress(progress);
|
|
75
|
+
});
|
|
63
76
|
setIsReady(true);
|
|
64
77
|
}
|
|
65
78
|
}
|
|
@@ -67,7 +80,7 @@ function useModel(pathOrUrl, config) {
|
|
|
67
80
|
setError(e.message || "Failed to load model");
|
|
68
81
|
console.error(e);
|
|
69
82
|
}
|
|
70
|
-
}, [pathOrUrl,
|
|
83
|
+
}, [pathOrUrl, nativeConfig]);
|
|
71
84
|
(0, react_1.useEffect)(() => {
|
|
72
85
|
if (autoLoad) {
|
|
73
86
|
load();
|
|
@@ -110,11 +123,12 @@ function useModel(pathOrUrl, config) {
|
|
|
110
123
|
}, []);
|
|
111
124
|
const deleteModel = (0, react_1.useCallback)(async (fileName) => {
|
|
112
125
|
if (modelRef.current) {
|
|
113
|
-
|
|
126
|
+
const resolvedName = fileName ?? extractFileName(pathOrUrl);
|
|
127
|
+
await modelRef.current.deleteModel(resolvedName);
|
|
114
128
|
setIsReady(false);
|
|
115
129
|
setDownloadProgress(0);
|
|
116
130
|
}
|
|
117
|
-
}, []);
|
|
131
|
+
}, [pathOrUrl]);
|
|
118
132
|
return {
|
|
119
133
|
model: modelRef.current,
|
|
120
134
|
isReady,
|
package/lib/index.d.ts
CHANGED
|
@@ -4,6 +4,7 @@ export type { ChatMessage } from "./templates";
|
|
|
4
4
|
export { applyGemmaTemplate, applyPhiTemplate, applyLlamaTemplate, } from "./templates";
|
|
5
5
|
export type { MemorySnapshot, MemoryTracker, MemoryTrackerSummary, } from "./memoryTracker";
|
|
6
6
|
export { createMemoryTracker, createNativeBuffer } from "./memoryTracker";
|
|
7
|
+
export type { LiteRTLMInstance } from "./modelFactory";
|
|
7
8
|
export * from "./hooks";
|
|
8
9
|
/**
|
|
9
10
|
* Creates a new LiteRT-LM inference engine instance.
|
package/lib/index.js
CHANGED
|
@@ -116,12 +116,9 @@ function checkBackendSupport(backend) {
|
|
|
116
116
|
return "NPU backend requires compatible hardware (Qualcomm Hexagon, MediaTek APU, etc.). Will fall back to GPU if unavailable.";
|
|
117
117
|
}
|
|
118
118
|
if (react_native_1.Platform.OS === "ios") {
|
|
119
|
-
return "NPU
|
|
119
|
+
return "NPU (Neural Engine) is not yet supported on iOS. Use 'gpu' (Metal) or 'cpu' instead.";
|
|
120
120
|
}
|
|
121
121
|
}
|
|
122
|
-
if (react_native_1.Platform.OS === "ios" && backend !== "cpu") {
|
|
123
|
-
return "LiteRT-LM iOS is not yet released. Only CPU backend may work via fallback.";
|
|
124
|
-
}
|
|
125
122
|
return undefined;
|
|
126
123
|
}
|
|
127
124
|
/**
|
|
@@ -143,7 +140,7 @@ function checkBackendSupport(backend) {
|
|
|
143
140
|
*/
|
|
144
141
|
function checkMultimodalSupport() {
|
|
145
142
|
if (react_native_1.Platform.OS === "ios") {
|
|
146
|
-
return "Multimodal (image/audio) is
|
|
143
|
+
return "Multimodal (image/audio) is experimental on iOS. Vision and audio executors may not be available in the current build.";
|
|
147
144
|
}
|
|
148
145
|
return undefined;
|
|
149
146
|
}
|
package/lib/memoryTracker.d.ts
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
*
|
|
4
4
|
* Records real memory usage from OS-level APIs via `getMemoryUsage()`,
|
|
5
5
|
* and stores snapshots in a native-backed ArrayBuffer allocated via
|
|
6
|
-
* `NitroModules.createNativeArrayBuffer()` (v0.
|
|
6
|
+
* `NitroModules.createNativeArrayBuffer()` (v0.35+) for zero-copy interop.
|
|
7
7
|
*
|
|
8
8
|
* @example
|
|
9
9
|
* ```typescript
|
package/lib/memoryTracker.js
CHANGED
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
*
|
|
5
5
|
* Records real memory usage from OS-level APIs via `getMemoryUsage()`,
|
|
6
6
|
* and stores snapshots in a native-backed ArrayBuffer allocated via
|
|
7
|
-
* `NitroModules.createNativeArrayBuffer()` (v0.
|
|
7
|
+
* `NitroModules.createNativeArrayBuffer()` (v0.35+) for zero-copy interop.
|
|
8
8
|
*
|
|
9
9
|
* @example
|
|
10
10
|
* ```typescript
|
package/lib/modelFactory.d.ts
CHANGED
|
@@ -1,10 +1,18 @@
|
|
|
1
|
-
import { LiteRTLM } from "./specs/LiteRTLM.nitro";
|
|
1
|
+
import { LiteRTLM, LLMConfig } from "./specs/LiteRTLM.nitro";
|
|
2
2
|
import { MemoryTracker } from "./memoryTracker";
|
|
3
|
+
/**
|
|
4
|
+
* Extended LiteRT-LM instance with optional memory tracking and
|
|
5
|
+
* augmented loadModel that accepts a download progress callback.
|
|
6
|
+
*/
|
|
7
|
+
export type LiteRTLMInstance = Omit<LiteRTLM, "loadModel"> & {
|
|
8
|
+
memoryTracker?: MemoryTracker;
|
|
9
|
+
loadModel: (pathOrUrl: string, config?: LLMConfig, onDownloadProgress?: (progress: number) => void) => Promise<void>;
|
|
10
|
+
};
|
|
3
11
|
/**
|
|
4
12
|
* Creates a new LiteRT-LM inference engine instance.
|
|
5
13
|
*
|
|
6
14
|
* Optionally creates a native-backed memory tracker using
|
|
7
|
-
* `NitroModules.createNativeArrayBuffer()` (v0.
|
|
15
|
+
* `NitroModules.createNativeArrayBuffer()` (v0.35+) for efficient
|
|
8
16
|
* zero-copy memory usage tracking.
|
|
9
17
|
*
|
|
10
18
|
* @param options.enableMemoryTracking Enable automatic memory tracking (default: false)
|
|
@@ -13,6 +21,4 @@ import { MemoryTracker } from "./memoryTracker";
|
|
|
13
21
|
export declare function createLLM(options?: {
|
|
14
22
|
enableMemoryTracking?: boolean;
|
|
15
23
|
maxMemorySnapshots?: number;
|
|
16
|
-
}):
|
|
17
|
-
memoryTracker?: MemoryTracker;
|
|
18
|
-
};
|
|
24
|
+
}): LiteRTLMInstance;
|
package/lib/modelFactory.js
CHANGED
|
@@ -7,7 +7,7 @@ const memoryTracker_1 = require("./memoryTracker");
|
|
|
7
7
|
* Creates a new LiteRT-LM inference engine instance.
|
|
8
8
|
*
|
|
9
9
|
* Optionally creates a native-backed memory tracker using
|
|
10
|
-
* `NitroModules.createNativeArrayBuffer()` (v0.
|
|
10
|
+
* `NitroModules.createNativeArrayBuffer()` (v0.35+) for efficient
|
|
11
11
|
* zero-copy memory usage tracking.
|
|
12
12
|
*
|
|
13
13
|
* @param options.enableMemoryTracking Enable automatic memory tracking (default: false)
|
|
@@ -41,10 +41,15 @@ function createLLM(options) {
|
|
|
41
41
|
return {
|
|
42
42
|
...native,
|
|
43
43
|
memoryTracker: tracker,
|
|
44
|
-
loadModel: async (pathOrUrl, config) => {
|
|
44
|
+
loadModel: async (pathOrUrl, config, onDownloadProgress) => {
|
|
45
45
|
let modelPath = pathOrUrl;
|
|
46
|
-
// Check if it's a URL
|
|
46
|
+
// Check if it's a URL — enforce HTTPS for model downloads
|
|
47
47
|
if (pathOrUrl.startsWith("http://") || pathOrUrl.startsWith("https://")) {
|
|
48
|
+
if (pathOrUrl.startsWith("http://")) {
|
|
49
|
+
throw new Error("Insecure HTTP URLs are not allowed for model downloads. " +
|
|
50
|
+
"Use HTTPS instead: " +
|
|
51
|
+
pathOrUrl.replace("http://", "https://"));
|
|
52
|
+
}
|
|
48
53
|
// Extract filename from URL
|
|
49
54
|
const fileName = pathOrUrl.split("/").pop();
|
|
50
55
|
if (!fileName) {
|
|
@@ -52,7 +57,7 @@ function createLLM(options) {
|
|
|
52
57
|
}
|
|
53
58
|
console.log(`Checking model at ${pathOrUrl}...`);
|
|
54
59
|
modelPath = await native.downloadModel(pathOrUrl, fileName, (progress) => {
|
|
55
|
-
|
|
60
|
+
onDownloadProgress?.(progress);
|
|
56
61
|
});
|
|
57
62
|
console.log(`Model downloaded to: ${modelPath}`);
|
|
58
63
|
}
|
|
@@ -28,12 +28,21 @@ int initialize(JavaVM* vm) {
|
|
|
28
28
|
});
|
|
29
29
|
}
|
|
30
30
|
|
|
31
|
+
// JNI handle for the Kotlin class com.margelo.nitro.dev.litert.litertlm.HybridLiteRTLM,
// used to instantiate it from C++ through its no-argument constructor.
struct JHybridLiteRTLMSpecImpl: public jni::JavaClass<JHybridLiteRTLMSpecImpl, JHybridLiteRTLMSpec::JavaPart> {
  static auto constexpr kJavaDescriptor = "Lcom/margelo/nitro/dev/litert/litertlm/HybridLiteRTLM;";

  // Constructs a new Java-side instance and returns the C++ spec object attached to it.
  static std::shared_ptr<JHybridLiteRTLMSpec> create() {
    // The constructor lookup is resolved once and reused by later calls.
    static auto defaultConstructor = javaClassStatic()->getConstructor<JHybridLiteRTLMSpecImpl::javaobject()>();
    jni::local_ref<JHybridLiteRTLMSpec::JavaPart> instance = javaClassStatic()->newObject(defaultConstructor);
    return instance->getJHybridLiteRTLMSpec();
  }
};
|
|
39
|
+
|
|
31
40
|
void registerAllNatives() {
|
|
32
41
|
using namespace margelo::nitro;
|
|
33
42
|
using namespace margelo::nitro::litertlm;
|
|
34
43
|
|
|
35
44
|
// Register native JNI methods
|
|
36
|
-
margelo::nitro::litertlm::JHybridLiteRTLMSpec::registerNatives();
|
|
45
|
+
margelo::nitro::litertlm::JHybridLiteRTLMSpec::CxxPart::registerNatives();
|
|
37
46
|
margelo::nitro::litertlm::JFunc_void_double_cxx::registerNatives();
|
|
38
47
|
margelo::nitro::litertlm::JFunc_void_std__string_bool_cxx::registerNatives();
|
|
39
48
|
|
|
@@ -41,9 +50,7 @@ void registerAllNatives() {
|
|
|
41
50
|
HybridObjectRegistry::registerHybridObjectConstructor(
|
|
42
51
|
"LiteRTLM",
|
|
43
52
|
[]() -> std::shared_ptr<HybridObject> {
|
|
44
|
-
|
|
45
|
-
auto instance = object.create();
|
|
46
|
-
return instance->cthis()->shared();
|
|
53
|
+
return JHybridLiteRTLMSpecImpl::create();
|
|
47
54
|
}
|
|
48
55
|
);
|
|
49
56
|
}
|
|
@@ -45,37 +45,31 @@ namespace margelo::nitro::litertlm { enum class Backend; }
|
|
|
45
45
|
|
|
46
46
|
namespace margelo::nitro::litertlm {
|
|
47
47
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
});
|
|
48
|
+
std::shared_ptr<JHybridLiteRTLMSpec> JHybridLiteRTLMSpec::JavaPart::getJHybridLiteRTLMSpec() {
|
|
49
|
+
auto hybridObject = JHybridObject::JavaPart::getJHybridObject();
|
|
50
|
+
auto castHybridObject = std::dynamic_pointer_cast<JHybridLiteRTLMSpec>(hybridObject);
|
|
51
|
+
if (castHybridObject == nullptr) [[unlikely]] {
|
|
52
|
+
throw std::runtime_error("Failed to downcast JHybridObject to JHybridLiteRTLMSpec!");
|
|
53
|
+
}
|
|
54
|
+
return castHybridObject;
|
|
56
55
|
}
|
|
57
56
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
return method(_javaPart);
|
|
57
|
+
// Native method registered under the name "initHybrid": builds the C++ part
// for the given Java object and returns its hybrid data handle.
jni::local_ref<JHybridLiteRTLMSpec::CxxPart::jhybriddata> JHybridLiteRTLMSpec::CxxPart::initHybrid(jni::alias_ref<jhybridobject> jThis) {
  auto hybridData = makeCxxInstance(jThis);
  return hybridData;
}
|
|
62
60
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
61
|
+
std::shared_ptr<JHybridObject> JHybridLiteRTLMSpec::CxxPart::createHybridObject(const jni::local_ref<JHybridObject::JavaPart>& javaPart) {
|
|
62
|
+
auto castJavaPart = jni::dynamic_ref_cast<JHybridLiteRTLMSpec::JavaPart>(javaPart);
|
|
63
|
+
if (castJavaPart == nullptr) [[unlikely]] {
|
|
64
|
+
throw std::runtime_error("Failed to cast JHybridObject::JavaPart to JHybridLiteRTLMSpec::JavaPart!");
|
|
66
65
|
}
|
|
67
|
-
return
|
|
68
|
-
}
|
|
69
|
-
|
|
70
|
-
void JHybridLiteRTLMSpec::dispose() noexcept {
|
|
71
|
-
static const auto method = javaClassStatic()->getMethod<void()>("dispose");
|
|
72
|
-
method(_javaPart);
|
|
66
|
+
return std::make_shared<JHybridLiteRTLMSpec>(castJavaPart);
|
|
73
67
|
}
|
|
74
68
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
69
|
+
// Registers the native methods of the C++ part with JNI; the only entry is
// "initHybrid", which maps to CxxPart::initHybrid.
void JHybridLiteRTLMSpec::CxxPart::registerNatives() {
  registerHybrid({makeNativeMethod("initHybrid", JHybridLiteRTLMSpec::CxxPart::initHybrid)});
}
|
|
80
74
|
|
|
81
75
|
// Properties
|
|
@@ -83,7 +77,7 @@ namespace margelo::nitro::litertlm {
|
|
|
83
77
|
|
|
84
78
|
// Methods
|
|
85
79
|
std::shared_ptr<Promise<void>> JHybridLiteRTLMSpec::loadModel(const std::string& modelPath, const std::optional<LLMConfig>& config) {
|
|
86
|
-
static const auto method = javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* modelPath */, jni::alias_ref<JLLMConfig> /* config */)>("loadModel");
|
|
80
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* modelPath */, jni::alias_ref<JLLMConfig> /* config */)>("loadModel");
|
|
87
81
|
auto __result = method(_javaPart, jni::make_jstring(modelPath), config.has_value() ? JLLMConfig::fromCpp(config.value()) : nullptr);
|
|
88
82
|
return [&]() {
|
|
89
83
|
auto __promise = Promise<void>::create();
|
|
@@ -98,7 +92,7 @@ namespace margelo::nitro::litertlm {
|
|
|
98
92
|
}();
|
|
99
93
|
}
|
|
100
94
|
std::shared_ptr<Promise<std::string>> JHybridLiteRTLMSpec::sendMessage(const std::string& message) {
|
|
101
|
-
static const auto method = javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* message */)>("sendMessage");
|
|
95
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* message */)>("sendMessage");
|
|
102
96
|
auto __result = method(_javaPart, jni::make_jstring(message));
|
|
103
97
|
return [&]() {
|
|
104
98
|
auto __promise = Promise<std::string>::create();
|
|
@@ -114,7 +108,7 @@ namespace margelo::nitro::litertlm {
|
|
|
114
108
|
}();
|
|
115
109
|
}
|
|
116
110
|
std::shared_ptr<Promise<std::string>> JHybridLiteRTLMSpec::sendMessageWithImage(const std::string& message, const std::string& imagePath) {
|
|
117
|
-
static const auto method = javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<jni::JString> /* imagePath */)>("sendMessageWithImage");
|
|
111
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<jni::JString> /* imagePath */)>("sendMessageWithImage");
|
|
118
112
|
auto __result = method(_javaPart, jni::make_jstring(message), jni::make_jstring(imagePath));
|
|
119
113
|
return [&]() {
|
|
120
114
|
auto __promise = Promise<std::string>::create();
|
|
@@ -130,7 +124,7 @@ namespace margelo::nitro::litertlm {
|
|
|
130
124
|
}();
|
|
131
125
|
}
|
|
132
126
|
std::shared_ptr<Promise<std::string>> JHybridLiteRTLMSpec::downloadModel(const std::string& url, const std::string& fileName, const std::optional<std::function<void(double /* progress */)>>& onProgress) {
|
|
133
|
-
static const auto method = javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* url */, jni::alias_ref<jni::JString> /* fileName */, jni::alias_ref<JFunc_void_double::javaobject> /* onProgress */)>("downloadModel_cxx");
|
|
127
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* url */, jni::alias_ref<jni::JString> /* fileName */, jni::alias_ref<JFunc_void_double::javaobject> /* onProgress */)>("downloadModel_cxx");
|
|
134
128
|
auto __result = method(_javaPart, jni::make_jstring(url), jni::make_jstring(fileName), onProgress.has_value() ? JFunc_void_double_cxx::fromCpp(onProgress.value()) : nullptr);
|
|
135
129
|
return [&]() {
|
|
136
130
|
auto __promise = Promise<std::string>::create();
|
|
@@ -146,7 +140,7 @@ namespace margelo::nitro::litertlm {
|
|
|
146
140
|
}();
|
|
147
141
|
}
|
|
148
142
|
std::shared_ptr<Promise<void>> JHybridLiteRTLMSpec::deleteModel(const std::string& fileName) {
|
|
149
|
-
static const auto method = javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* fileName */)>("deleteModel");
|
|
143
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* fileName */)>("deleteModel");
|
|
150
144
|
auto __result = method(_javaPart, jni::make_jstring(fileName));
|
|
151
145
|
return [&]() {
|
|
152
146
|
auto __promise = Promise<void>::create();
|
|
@@ -161,7 +155,7 @@ namespace margelo::nitro::litertlm {
|
|
|
161
155
|
}();
|
|
162
156
|
}
|
|
163
157
|
std::shared_ptr<Promise<std::string>> JHybridLiteRTLMSpec::sendMessageWithAudio(const std::string& message, const std::string& audioPath) {
|
|
164
|
-
static const auto method = javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<jni::JString> /* audioPath */)>("sendMessageWithAudio");
|
|
158
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JPromise::javaobject>(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<jni::JString> /* audioPath */)>("sendMessageWithAudio");
|
|
165
159
|
auto __result = method(_javaPart, jni::make_jstring(message), jni::make_jstring(audioPath));
|
|
166
160
|
return [&]() {
|
|
167
161
|
auto __promise = Promise<std::string>::create();
|
|
@@ -177,11 +171,11 @@ namespace margelo::nitro::litertlm {
|
|
|
177
171
|
}();
|
|
178
172
|
}
|
|
179
173
|
void JHybridLiteRTLMSpec::sendMessageAsync(const std::string& message, const std::function<void(const std::string& /* token */, bool /* done */)>& onToken) {
|
|
180
|
-
static const auto method = javaClassStatic()->getMethod<void(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<JFunc_void_std__string_bool::javaobject> /* onToken */)>("sendMessageAsync_cxx");
|
|
174
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<void(jni::alias_ref<jni::JString> /* message */, jni::alias_ref<JFunc_void_std__string_bool::javaobject> /* onToken */)>("sendMessageAsync_cxx");
|
|
181
175
|
method(_javaPart, jni::make_jstring(message), JFunc_void_std__string_bool_cxx::fromCpp(onToken));
|
|
182
176
|
}
|
|
183
177
|
std::vector<Message> JHybridLiteRTLMSpec::getHistory() {
|
|
184
|
-
static const auto method = javaClassStatic()->getMethod<jni::local_ref<jni::JArrayClass<JMessage>>()>("getHistory");
|
|
178
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<jni::JArrayClass<JMessage>>()>("getHistory");
|
|
185
179
|
auto __result = method(_javaPart);
|
|
186
180
|
return [&]() {
|
|
187
181
|
size_t __size = __result->size();
|
|
@@ -195,26 +189,26 @@ namespace margelo::nitro::litertlm {
|
|
|
195
189
|
}();
|
|
196
190
|
}
|
|
197
191
|
void JHybridLiteRTLMSpec::resetConversation() {
|
|
198
|
-
static const auto method = javaClassStatic()->getMethod<void()>("resetConversation");
|
|
192
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<void()>("resetConversation");
|
|
199
193
|
method(_javaPart);
|
|
200
194
|
}
|
|
201
195
|
bool JHybridLiteRTLMSpec::isReady() {
|
|
202
|
-
static const auto method = javaClassStatic()->getMethod<jboolean()>("isReady");
|
|
196
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<jboolean()>("isReady");
|
|
203
197
|
auto __result = method(_javaPart);
|
|
204
198
|
return static_cast<bool>(__result);
|
|
205
199
|
}
|
|
206
200
|
GenerationStats JHybridLiteRTLMSpec::getStats() {
|
|
207
|
-
static const auto method = javaClassStatic()->getMethod<jni::local_ref<JGenerationStats>()>("getStats");
|
|
201
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JGenerationStats>()>("getStats");
|
|
208
202
|
auto __result = method(_javaPart);
|
|
209
203
|
return __result->toCpp();
|
|
210
204
|
}
|
|
211
205
|
MemoryUsage JHybridLiteRTLMSpec::getMemoryUsage() {
|
|
212
|
-
static const auto method = javaClassStatic()->getMethod<jni::local_ref<JMemoryUsage>()>("getMemoryUsage");
|
|
206
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<jni::local_ref<JMemoryUsage>()>("getMemoryUsage");
|
|
213
207
|
auto __result = method(_javaPart);
|
|
214
208
|
return __result->toCpp();
|
|
215
209
|
}
|
|
216
210
|
void JHybridLiteRTLMSpec::close() {
|
|
217
|
-
static const auto method = javaClassStatic()->getMethod<void()>("close");
|
|
211
|
+
static const auto method = _javaPart->javaClassStatic()->getMethod<void()>("close");
|
|
218
212
|
method(_javaPart);
|
|
219
213
|
}
|
|
220
214
|
|