react-native-litert-lm 0.3.6 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +207 -158
- package/android/build.gradle +12 -0
- package/android/src/main/AndroidManifest.xml +5 -0
- package/android/src/main/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLM.kt +316 -63
- package/android/src/main/java/dev/litert/litertlm/LiteRTLMPackage.kt +19 -2
- package/android/src/test/java/com/margelo/nitro/core/Promise.kt +46 -0
- package/android/src/test/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMTest.kt +83 -0
- package/cpp/include/README.md +9 -11
- package/ios/HybridLiteRTLM.swift +1058 -0
- package/ios/Tests/HybridLiteRTLMTests.swift +67 -0
- package/lib/__mocks__/react-native-nitro-modules.d.ts +61 -0
- package/lib/__mocks__/react-native-nitro-modules.js +50 -0
- package/lib/__tests__/hooks.test.d.ts +1 -0
- package/lib/__tests__/hooks.test.js +124 -0
- package/lib/__tests__/memoryTracker.test.d.ts +1 -0
- package/lib/__tests__/memoryTracker.test.js +74 -0
- package/lib/__tests__/modelFactory.test.d.ts +1 -0
- package/lib/__tests__/modelFactory.test.js +52 -0
- package/lib/hooks.js +1 -1
- package/lib/index.d.ts +2 -4
- package/lib/index.js +12 -7
- package/lib/modelFactory.js +62 -63
- package/lib/specs/LiteRTLM.nitro.d.ts +71 -2
- package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.cpp +62 -7
- package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.hpp +3 -1
- package/nitrogen/generated/android/c++/JLLMConfig.hpp +40 -3
- package/nitrogen/generated/android/c++/JMultimodalPart.hpp +74 -0
- package/nitrogen/generated/android/c++/JPartType.hpp +61 -0
- package/nitrogen/generated/android/c++/JToolDefinition.hpp +65 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/GenerationStats.kt +23 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMSpec.kt +10 -2
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/LLMConfig.kt +46 -3
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MemoryUsage.kt +19 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/Message.kt +15 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MultimodalPart.kt +66 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/PartType.kt +24 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/ToolDefinition.kt +61 -0
- package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.cpp +57 -1
- package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.hpp +414 -3
- package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Umbrella.hpp +41 -3
- package/nitrogen/generated/ios/LiteRTLMAutolinking.mm +4 -6
- package/nitrogen/generated/ios/LiteRTLMAutolinking.swift +10 -0
- package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.hpp +224 -0
- package/nitrogen/generated/ios/swift/Backend.swift +44 -0
- package/nitrogen/generated/ios/swift/Func_void.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_double.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_std__string.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_std__string_bool.swift +46 -0
- package/nitrogen/generated/ios/swift/GenerationStats.swift +54 -0
- package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec.swift +69 -0
- package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec_cxx.swift +383 -0
- package/nitrogen/generated/ios/swift/LLMConfig.swift +203 -0
- package/nitrogen/generated/ios/swift/MemoryUsage.swift +44 -0
- package/nitrogen/generated/ios/swift/Message.swift +34 -0
- package/nitrogen/generated/ios/swift/MultimodalPart.swift +83 -0
- package/nitrogen/generated/ios/swift/PartType.swift +44 -0
- package/nitrogen/generated/ios/swift/Role.swift +44 -0
- package/nitrogen/generated/ios/swift/ToolDefinition.swift +39 -0
- package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.cpp +2 -0
- package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.hpp +7 -2
- package/nitrogen/generated/shared/c++/LLMConfig.hpp +22 -2
- package/nitrogen/generated/shared/c++/MultimodalPart.hpp +99 -0
- package/nitrogen/generated/shared/c++/PartType.hpp +80 -0
- package/nitrogen/generated/shared/c++/ToolDefinition.hpp +91 -0
- package/package.json +16 -8
- package/react-native-litert-lm.podspec +15 -19
- package/scripts/download-ios-frameworks.sh +14 -48
- package/scripts/postinstall.js +1 -2
- package/src/__mocks__/react-native-nitro-modules.ts +48 -0
- package/src/__tests__/hooks.test.ts +153 -0
- package/src/__tests__/memoryTracker.test.ts +87 -0
- package/src/__tests__/modelFactory.test.ts +68 -0
- package/src/hooks.ts +1 -1
- package/src/index.ts +12 -9
- package/src/modelFactory.ts +82 -80
- package/src/specs/LiteRTLM.nitro.ts +80 -2
- package/cpp/HybridLiteRTLM.cpp +0 -838
- package/cpp/HybridLiteRTLM.hpp +0 -167
- package/cpp/IOSDownloadHelper.h +0 -24
- package/ios/IOSDownloadHelper.mm +0 -129
- package/scripts/build-ios-engine.sh +0 -302
- package/scripts/stubs/cxx_bridge_stubs.cc +0 -224
- package/scripts/stubs/gemma_model_constraint_provider.cc +0 -46
- package/scripts/stubs/llguidance_stubs.c +0 -101
- package/src/templates.ts +0 -105
|
@@ -17,6 +17,7 @@ import com.google.ai.edge.litertlm.Engine
|
|
|
17
17
|
import com.google.ai.edge.litertlm.Conversation
|
|
18
18
|
import com.google.ai.edge.litertlm.EngineConfig
|
|
19
19
|
import com.google.ai.edge.litertlm.ConversationConfig
|
|
20
|
+
import com.google.ai.edge.litertlm.SamplerConfig
|
|
20
21
|
import com.margelo.nitro.dev.litert.litertlm.Backend
|
|
21
22
|
import com.margelo.nitro.dev.litert.litertlm.GenerationStats
|
|
22
23
|
import com.margelo.nitro.dev.litert.litertlm.HybridLiteRTLMSpec
|
|
@@ -25,6 +26,15 @@ import com.margelo.nitro.dev.litert.litertlm.Message
|
|
|
25
26
|
import com.margelo.nitro.dev.litert.litertlm.Role
|
|
26
27
|
import com.margelo.nitro.core.Promise
|
|
27
28
|
import com.google.ai.edge.litertlm.Content
|
|
29
|
+
import com.google.ai.edge.litertlm.Contents
|
|
30
|
+
import com.google.ai.edge.litertlm.ExperimentalApi
|
|
31
|
+
import com.google.ai.edge.litertlm.ExperimentalFlags
|
|
32
|
+
import com.google.ai.edge.litertlm.OpenApiTool
|
|
33
|
+
import com.google.ai.edge.litertlm.ToolProvider
|
|
34
|
+
import java.util.concurrent.CountDownLatch
|
|
35
|
+
import java.util.concurrent.TimeUnit
|
|
36
|
+
import java.util.concurrent.atomic.AtomicBoolean
|
|
37
|
+
import java.util.concurrent.atomic.AtomicReference
|
|
28
38
|
|
|
29
39
|
|
|
30
40
|
// Alias to avoid confusion with our generated Message type
|
|
@@ -42,13 +52,26 @@ internal class StreamingCallbackListener(
|
|
|
42
52
|
private val onToken: (String, Boolean) -> Unit,
|
|
43
53
|
private val responseBuilder: StringBuilder,
|
|
44
54
|
private val history: MutableList<Message>,
|
|
55
|
+
private val userMessage: String,
|
|
56
|
+
private val onStatsReady: (GenerationStats) -> Unit,
|
|
45
57
|
) : com.google.ai.edge.litertlm.MessageCallback {
|
|
46
58
|
|
|
47
|
-
|
|
48
|
-
|
|
59
|
+
private val startTime = System.nanoTime()
|
|
60
|
+
private var firstTokenTime = 0L
|
|
61
|
+
private var tokenCount = 0
|
|
62
|
+
|
|
63
|
+
override fun onMessage(message: com.google.ai.edge.litertlm.Message) {
|
|
64
|
+
val chunk = message.contents.contents
|
|
49
65
|
.filterIsInstance<com.google.ai.edge.litertlm.Content.Text>()
|
|
50
66
|
.joinToString("") { it.text }
|
|
51
67
|
|
|
68
|
+
if (firstTokenTime == 0L && chunk.isNotEmpty()) {
|
|
69
|
+
firstTokenTime = System.nanoTime()
|
|
70
|
+
}
|
|
71
|
+
if (chunk.isNotEmpty()) {
|
|
72
|
+
tokenCount++
|
|
73
|
+
}
|
|
74
|
+
|
|
52
75
|
onToken(chunk, false)
|
|
53
76
|
|
|
54
77
|
if (chunk.isNotEmpty()) {
|
|
@@ -60,12 +83,27 @@ internal class StreamingCallbackListener(
|
|
|
60
83
|
onToken("", true)
|
|
61
84
|
val fullResponse = responseBuilder.toString()
|
|
62
85
|
history.add(Message(Role.MODEL, fullResponse))
|
|
63
|
-
|
|
86
|
+
|
|
87
|
+
// Compute stats using heuristic token counts (~4 chars/token)
|
|
88
|
+
val elapsedMs = (System.nanoTime() - startTime) / 1_000_000.0
|
|
89
|
+
val ttftMs = if (firstTokenTime > 0) (firstTokenTime - startTime) / 1_000_000.0 else 0.0
|
|
90
|
+
val promptTokens = userMessage.length / 4.0
|
|
91
|
+
val completionTokens = fullResponse.length / 4.0
|
|
92
|
+
onStatsReady(GenerationStats(
|
|
93
|
+
promptTokens = promptTokens,
|
|
94
|
+
completionTokens = completionTokens,
|
|
95
|
+
totalTokens = promptTokens + completionTokens,
|
|
96
|
+
timeToFirstToken = ttftMs,
|
|
97
|
+
totalTime = elapsedMs,
|
|
98
|
+
tokensPerSecond = if (elapsedMs > 0) completionTokens / (elapsedMs / 1000.0) else 0.0
|
|
99
|
+
))
|
|
100
|
+
|
|
101
|
+
Log.d("StreamingCallbackListener", "Streaming done. Length: ${fullResponse.length}, TTFT: ${ttftMs.toLong()}ms, Total: ${elapsedMs.toLong()}ms")
|
|
64
102
|
}
|
|
65
103
|
|
|
66
|
-
override fun onError(
|
|
67
|
-
Log.e("StreamingCallbackListener", "Async generation failed",
|
|
68
|
-
onToken("Error: ${
|
|
104
|
+
override fun onError(throwable: Throwable) {
|
|
105
|
+
Log.e("StreamingCallbackListener", "Async generation failed", throwable)
|
|
106
|
+
onToken("Error: ${throwable.message}", true)
|
|
69
107
|
}
|
|
70
108
|
}
|
|
71
109
|
|
|
@@ -80,6 +118,10 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
80
118
|
companion object {
|
|
81
119
|
private const val TAG = "HybridLiteRTLM"
|
|
82
120
|
private val initLock = Any()
|
|
121
|
+
|
|
122
|
+
/** Cached result of OpenCL availability probe (null = not yet checked). */
|
|
123
|
+
@Volatile
|
|
124
|
+
private var openCLAvailable: Boolean? = null
|
|
83
125
|
|
|
84
126
|
/**
|
|
85
127
|
* Initialize the native library.
|
|
@@ -129,6 +171,8 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
129
171
|
private var topP: Double = 0.95
|
|
130
172
|
private var maxTokens: Int = 1024
|
|
131
173
|
private var systemPrompt: String? = null
|
|
174
|
+
private var tools: Array<ToolDefinition>? = null
|
|
175
|
+
private var enableSpeculativeDecoding: Boolean = false
|
|
132
176
|
|
|
133
177
|
override val memorySize: Long
|
|
134
178
|
get() = 1024L * 1024L * 1024L // ~1GB (models are large)
|
|
@@ -158,9 +202,43 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
158
202
|
cfg.topP?.let { topP = it }
|
|
159
203
|
cfg.maxTokens?.let { maxTokens = it.toInt() }
|
|
160
204
|
cfg.systemPrompt?.let { systemPrompt = it }
|
|
205
|
+
tools = cfg.tools
|
|
206
|
+
enableSpeculativeDecoding = cfg.enableSpeculativeDecoding ?: false
|
|
161
207
|
}
|
|
162
208
|
|
|
209
|
+
// Whether to run engine validation after loading
|
|
210
|
+
val shouldValidate = config?.validate?: false
|
|
211
|
+
|
|
163
212
|
try {
|
|
213
|
+
// Early GPU hardware check: probe for OpenCL library before
|
|
214
|
+
// spending time on engine creation. LiteRT-LM's GPU delegate
|
|
215
|
+
// requires OpenCL, which is absent on most Samsung/Qualcomm devices.
|
|
216
|
+
if (backend == Backend.GPU) {
|
|
217
|
+
val hasOpenCL = openCLAvailable ?: run {
|
|
218
|
+
val result = try {
|
|
219
|
+
System.loadLibrary("OpenCL")
|
|
220
|
+
true
|
|
221
|
+
} catch (_: UnsatisfiedLinkError) {
|
|
222
|
+
try {
|
|
223
|
+
// Some devices have it at a non-standard path
|
|
224
|
+
System.load("/system/vendor/lib64/libOpenCL.so")
|
|
225
|
+
true
|
|
226
|
+
} catch (_: UnsatisfiedLinkError) {
|
|
227
|
+
false
|
|
228
|
+
}
|
|
229
|
+
}
|
|
230
|
+
openCLAvailable = result
|
|
231
|
+
result
|
|
232
|
+
}
|
|
233
|
+
if (!hasOpenCL) {
|
|
234
|
+
throw RuntimeException(
|
|
235
|
+
"GPU backend is not supported on this device (OpenCL library not found). " +
|
|
236
|
+
"Please use CPU backend instead."
|
|
237
|
+
)
|
|
238
|
+
}
|
|
239
|
+
Log.i(TAG, "OpenCL library found — GPU backend is available")
|
|
240
|
+
}
|
|
241
|
+
|
|
164
242
|
// Map our Backend enum to LiteRT-LM Backend sealed class
|
|
165
243
|
val lmBackend = when (backend) {
|
|
166
244
|
Backend.GPU -> com.google.ai.edge.litertlm.Backend.GPU()
|
|
@@ -171,12 +249,12 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
171
249
|
else -> com.google.ai.edge.litertlm.Backend.CPU()
|
|
172
250
|
}
|
|
173
251
|
|
|
174
|
-
// Detect multimodal support
|
|
252
|
+
// Detect multimodal support. Check config.multimodal flag first, then fall back to filename sniffing.
|
|
175
253
|
// Only Gemma 3n bundles vision/audio executors; Gemma 4 E2B is text-only.
|
|
176
254
|
// Passing vision/audio backends to a text-only model causes
|
|
177
255
|
// vision_litert_compiled_model_executor init failures.
|
|
178
256
|
val modelFileName = modelPath.substringAfterLast("/").lowercase()
|
|
179
|
-
val isMultimodal = modelFileName.contains("3n") || modelFileName.contains("gemma3")
|
|
257
|
+
val isMultimodal = config?.multimodal ?: (modelFileName.contains("3n") || modelFileName.contains("gemma3"))
|
|
180
258
|
|
|
181
259
|
val lmVisionBackend = if (isMultimodal) com.google.ai.edge.litertlm.Backend.GPU() else null
|
|
182
260
|
val lmAudioBackend = if (isMultimodal) com.google.ai.edge.litertlm.Backend.CPU() else null
|
|
@@ -208,6 +286,11 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
208
286
|
|
|
209
287
|
if (isClosed) return@synchronized
|
|
210
288
|
|
|
289
|
+
if (enableSpeculativeDecoding) {
|
|
290
|
+
@OptIn(ExperimentalApi::class)
|
|
291
|
+
ExperimentalFlags.enableSpeculativeDecoding = true
|
|
292
|
+
}
|
|
293
|
+
|
|
211
294
|
// Initialize Engine
|
|
212
295
|
engine = Engine(engineConfig).also { it.initialize() }
|
|
213
296
|
Log.i(TAG, "Engine created and initialized successfully")
|
|
@@ -215,9 +298,24 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
215
298
|
// Create Conversation
|
|
216
299
|
createNewConversation()
|
|
217
300
|
Log.i(TAG, "Conversation created successfully")
|
|
301
|
+
|
|
302
|
+
// Validate the engine actually works with a quick test inference.
|
|
303
|
+
// GPU/NPU backends can initialize without error but silently fail to
|
|
304
|
+
// produce tokens — enabling this catches those failures at load time.
|
|
305
|
+
// CPU is always reliable so validation is never run on it, even when
|
|
306
|
+
// the `validate` flag is set.
|
|
307
|
+
if (shouldValidate) {
|
|
308
|
+
if (backend == Backend.GPU || backend == Backend.NPU) {
|
|
309
|
+
validateEngine()
|
|
310
|
+
} else {
|
|
311
|
+
Log.i(TAG, "Validation skipped: CPU backend is always reliable")
|
|
312
|
+
}
|
|
313
|
+
}
|
|
218
314
|
|
|
219
315
|
} catch (e: Exception) {
|
|
220
316
|
Log.e(TAG, "Failed to load model: ${e.message}", e)
|
|
317
|
+
// Clean up partial state so isReady() returns false
|
|
318
|
+
cleanupInternal()
|
|
221
319
|
throw RuntimeException("Failed to load model: ${e.message}", e)
|
|
222
320
|
}
|
|
223
321
|
}
|
|
@@ -241,7 +339,7 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
241
339
|
Log.i(TAG, "sendMessage (Promise): $message")
|
|
242
340
|
|
|
243
341
|
// Blocking inference (safe here because we are in Promise.parallel worker thread)
|
|
244
|
-
val userMsg = LiteRTMessage.
|
|
342
|
+
val userMsg = LiteRTMessage.user(message)
|
|
245
343
|
val startTime = System.nanoTime()
|
|
246
344
|
val responseMsg = conversation!!.sendMessage(message = userMsg)
|
|
247
345
|
val elapsedMs = (System.nanoTime() - startTime) / 1_000_000.0
|
|
@@ -276,30 +374,48 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
276
374
|
// -------------------------------------------------------------------------
|
|
277
375
|
// sendMessageAsync - Streaming inference
|
|
278
376
|
// -------------------------------------------------------------------------
|
|
279
|
-
override fun sendMessageAsync(message: String, onToken: (String, Boolean) -> Unit) {
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
ensureLoaded()
|
|
377
|
+
override fun sendMessageAsync(message: String, onToken: (String, Boolean) -> Unit): Promise<Unit> {
|
|
378
|
+
return Promise.parallel {
|
|
379
|
+
val latch = CountDownLatch(1)
|
|
380
|
+
val errorRef = AtomicReference<Throwable?>(null)
|
|
284
381
|
|
|
285
|
-
|
|
286
|
-
history.add(Message(Role.USER, message))
|
|
287
|
-
Log.d(TAG, "sendMessageAsync: $message")
|
|
382
|
+
ensureLoaded()
|
|
288
383
|
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
onToken = onToken,
|
|
293
|
-
responseBuilder = fullResponseBuilder,
|
|
294
|
-
history = history,
|
|
295
|
-
)
|
|
384
|
+
// Add user message to history
|
|
385
|
+
history.add(Message(Role.USER, message))
|
|
386
|
+
Log.d(TAG, "sendMessageAsync: $message")
|
|
296
387
|
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
388
|
+
val fullResponseBuilder = StringBuilder()
|
|
389
|
+
|
|
390
|
+
val listener = StreamingCallbackListener(
|
|
391
|
+
onToken = { token, done ->
|
|
392
|
+
onToken(token, done)
|
|
393
|
+
if (done) {
|
|
394
|
+
latch.countDown()
|
|
395
|
+
}
|
|
396
|
+
},
|
|
397
|
+
responseBuilder = fullResponseBuilder,
|
|
398
|
+
history = history,
|
|
399
|
+
userMessage = message,
|
|
400
|
+
onStatsReady = { stats -> lastStats = stats },
|
|
401
|
+
)
|
|
402
|
+
|
|
403
|
+
try {
|
|
404
|
+
val userMsg = LiteRTMessage.user(message)
|
|
405
|
+
conversation!!.sendMessageAsync(message = userMsg, callback = listener)
|
|
406
|
+
} catch (e: Exception) {
|
|
407
|
+
Log.e(TAG, "Failed to initiate async generation", e)
|
|
408
|
+
errorRef.set(e)
|
|
409
|
+
onToken("Error: ${e.message}", true)
|
|
410
|
+
latch.countDown()
|
|
411
|
+
}
|
|
412
|
+
|
|
413
|
+
// Wait for completion or error
|
|
414
|
+
latch.await()
|
|
415
|
+
val err = errorRef.get()
|
|
416
|
+
if (err != null) {
|
|
417
|
+
throw RuntimeException("Async inference failed: ${err.message}", err)
|
|
418
|
+
}
|
|
303
419
|
}
|
|
304
420
|
}
|
|
305
421
|
|
|
@@ -359,7 +475,7 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
359
475
|
// Use factory method Message.of passing a list of Content
|
|
360
476
|
val textContent = Content.Text(message)
|
|
361
477
|
|
|
362
|
-
val userMsg = LiteRTMessage.of(textContent, Content.ImageFile(processedImagePath))
|
|
478
|
+
val userMsg = LiteRTMessage.user(Contents.of(textContent, Content.ImageFile(processedImagePath)))
|
|
363
479
|
|
|
364
480
|
// Add to history
|
|
365
481
|
history.add(Message(Role.USER, "$message [Image]"))
|
|
@@ -389,6 +505,14 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
389
505
|
return Promise.parallel {
|
|
390
506
|
Log.i(TAG, "downloadModel: $url -> $fileName")
|
|
391
507
|
|
|
508
|
+
if (!url.startsWith("https://", ignoreCase = true)) {
|
|
509
|
+
throw IllegalArgumentException("Invalid download URL: HTTPS is required for security.")
|
|
510
|
+
}
|
|
511
|
+
|
|
512
|
+
if (fileName.contains("..") || fileName.contains("/") || fileName.contains("\\")) {
|
|
513
|
+
throw IllegalArgumentException("Invalid filename: path traversal or directory separators are not allowed.")
|
|
514
|
+
}
|
|
515
|
+
|
|
392
516
|
val context = LiteRTLMInitProvider.applicationContext ?: throw RuntimeException("Context not available")
|
|
393
517
|
val modelsDir = java.io.File(context.filesDir, "models")
|
|
394
518
|
if (!modelsDir.exists()) {
|
|
@@ -470,6 +594,11 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
470
594
|
override fun deleteModel(fileName: String): Promise<Unit> {
|
|
471
595
|
return Promise.parallel {
|
|
472
596
|
Log.i(TAG, "deleteModel: $fileName")
|
|
597
|
+
|
|
598
|
+
if (fileName.contains("..") || fileName.contains("/") || fileName.contains("\\")) {
|
|
599
|
+
throw IllegalArgumentException("Invalid filename: path traversal or directory separators are not allowed.")
|
|
600
|
+
}
|
|
601
|
+
|
|
473
602
|
val context = LiteRTLMInitProvider.applicationContext ?: throw RuntimeException("Context not available")
|
|
474
603
|
val modelsDir = java.io.File(context.filesDir, "models")
|
|
475
604
|
val modelFile = java.io.File(modelsDir, fileName)
|
|
@@ -501,10 +630,10 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
501
630
|
|
|
502
631
|
// Load audio
|
|
503
632
|
|
|
504
|
-
val userMsg = LiteRTMessage.of(
|
|
633
|
+
val userMsg = LiteRTMessage.user(Contents.of(
|
|
505
634
|
Content.Text(message),
|
|
506
635
|
Content.AudioFile(audioPath)
|
|
507
|
-
)
|
|
636
|
+
))
|
|
508
637
|
|
|
509
638
|
history.add(Message(Role.USER, "$message [Audio]"))
|
|
510
639
|
|
|
@@ -601,19 +730,9 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
601
730
|
|
|
602
731
|
private fun cleanupInternal() {
|
|
603
732
|
try {
|
|
733
|
+
conversation?.close()
|
|
604
734
|
conversation = null
|
|
605
|
-
//
|
|
606
|
-
// Assuming Engine implements AutoCloseable or has close()
|
|
607
|
-
if (engine is AutoCloseable) {
|
|
608
|
-
(engine as AutoCloseable).close()
|
|
609
|
-
} else {
|
|
610
|
-
// Try reflection or just null it if no close method
|
|
611
|
-
try {
|
|
612
|
-
engine?.javaClass?.getMethod("close")?.invoke(engine)
|
|
613
|
-
} catch (e: Exception) {
|
|
614
|
-
// Method not found, rely on GC
|
|
615
|
-
}
|
|
616
|
-
}
|
|
735
|
+
engine?.close() // Direct call
|
|
617
736
|
engine = null
|
|
618
737
|
} catch (e: Exception) {
|
|
619
738
|
Log.e(TAG, "Error closing resources", e)
|
|
@@ -631,33 +750,167 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
631
750
|
// v0.10.2 enforces single-session: close existing conversation first
|
|
632
751
|
conversation?.let { oldConv ->
|
|
633
752
|
try {
|
|
634
|
-
|
|
635
|
-
oldConv.close()
|
|
636
|
-
} else {
|
|
637
|
-
oldConv.javaClass.getMethod("close").invoke(oldConv)
|
|
638
|
-
}
|
|
753
|
+
oldConv.close()
|
|
639
754
|
} catch (e: Exception) {
|
|
640
755
|
Log.w(TAG, "Failed to close old conversation: ${e.message}")
|
|
641
756
|
}
|
|
642
757
|
conversation = null
|
|
643
758
|
}
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
val systemMsg = LiteRTMessage.of(Content.Text(prompt))
|
|
653
|
-
conversation!!.sendMessage(message = systemMsg)
|
|
654
|
-
Log.i(TAG, "System prompt applied (${prompt.length} chars)")
|
|
655
|
-
} catch (e: Exception) {
|
|
656
|
-
Log.w(TAG, "Failed to apply system prompt: ${e.message}")
|
|
759
|
+
// Map tools
|
|
760
|
+
val lmTools: List<ToolProvider>? = tools?.map { tool ->
|
|
761
|
+
val apiTool = object : OpenApiTool {
|
|
762
|
+
override fun getToolDescriptionJsonString(): String {
|
|
763
|
+
return tool.parametersJson
|
|
764
|
+
}
|
|
765
|
+
override fun execute(paramsJsonString: String): String {
|
|
766
|
+
return "{}"
|
|
657
767
|
}
|
|
658
768
|
}
|
|
769
|
+
(apiTool as Any) as ToolProvider
|
|
659
770
|
}
|
|
771
|
+
|
|
772
|
+
// Create conversation with explicit SamplerConfig (required by Gallery pattern).
|
|
773
|
+
// GPU backend may fail silently without proper sampler params.
|
|
774
|
+
val convConfig = ConversationConfig(
|
|
775
|
+
samplerConfig = SamplerConfig(
|
|
776
|
+
topK = topK,
|
|
777
|
+
topP = topP.toDouble(),
|
|
778
|
+
temperature = temperature.toDouble(),
|
|
779
|
+
),
|
|
780
|
+
systemInstruction = systemPrompt?.let { Contents.of(Content.Text(it)) },
|
|
781
|
+
tools = lmTools ?: emptyList()
|
|
782
|
+
)
|
|
783
|
+
conversation = engine!!.createConversation(convConfig)
|
|
660
784
|
}
|
|
661
785
|
|
|
786
|
+
/**
 * Validate that the engine can actually produce inference output.
 *
 * Some GPU backends initialize without error but silently hang during inference.
 * This sends a minimal test prompt ("Hi") on the current [conversation] and waits up to
 * 30s for the first callback (a message, completion, or an error). If an error is
 * reported, or no message arrives, we throw so the model does NOT appear as loaded.
 * On success the conversation is re-created via [createNewConversation] because the
 * probe consumed one turn.
 *
 * @throws RuntimeException if there is no conversation, the probe cannot be started,
 *   the backend reports an error, or no message is received within the timeout.
 */
private fun validateEngine() {
    val backendName = when (backend) {
        Backend.GPU -> "GPU"
        Backend.NPU -> "NPU"
        else -> "CPU"
    }
    Log.i(TAG, "Validating $backendName backend with test inference...")

    val latch = CountDownLatch(1)
    val gotToken = AtomicBoolean(false)
    // Keep the Throwable itself (not just its message): a null message must still be
    // reported as an inference error rather than being mistaken for a timeout.
    val errorRef = AtomicReference<Throwable?>(null)

    // Use the existing conversation for validation (single-session constraint).
    val validationConv = conversation
        ?: throw RuntimeException("$backendName backend: no conversation available for validation")

    try {
        val testMsg = LiteRTMessage.user("Hi")
        validationConv.sendMessageAsync(
            message = testMsg,
            callback = object : com.google.ai.edge.litertlm.MessageCallback {
                override fun onMessage(msg: com.google.ai.edge.litertlm.Message) {
                    // Any message counts as proof of life; release the waiter immediately.
                    gotToken.set(true)
                    latch.countDown()
                }

                override fun onDone() {
                    latch.countDown()
                }

                override fun onError(t: Throwable) {
                    errorRef.set(t)
                    latch.countDown()
                }
            }
        )
    } catch (e: Exception) {
        throw RuntimeException(
            "$backendName backend failed to run inference: ${e.message}. " +
                "This device may not support the $backendName backend. Please try CPU.",
            e
        )
    }

    // Wait up to 30s for any response
    val completed = latch.await(30, TimeUnit.SECONDS)

    val failure = errorRef.get()
    if (failure != null) {
        // Fall back to the exception class name when the throwable carries no message.
        val reason = failure.message ?: failure.javaClass.name
        throw RuntimeException(
            "$backendName backend inference error: $reason. " +
                "This device may not support the $backendName backend. Please try CPU.",
            failure
        )
    }
    if (!completed || !gotToken.get()) {
        throw RuntimeException(
            "$backendName backend produced no response within 30 seconds. " +
                "This device may not support the $backendName backend. Please try CPU."
        )
    }

    Log.i(TAG, "$backendName backend validated successfully")

    // Re-create the real conversation (validation consumed one turn).
    // NOTE(review): the latch is released on the first message, so the probe generation
    // may still be running when the conversation is replaced — confirm this is safe.
    createNewConversation()
}
|
|
857
|
+
|
|
858
|
+
/**
 * Send one user turn made of text, image and audio parts and return the model's text reply.
 *
 * Parts whose payload for their declared type is null are skipped. The user turn is
 * recorded in [history] as the text parts plus "[Image Buffer]" / "[Audio Buffer]"
 * placeholders, and [lastStats] is updated with heuristic token counts (~4 chars/token).
 *
 * @param parts ordered parts of the message.
 * @return a promise resolving to the concatenated text contents of the model response.
 */
override fun sendMultimodalMessage(parts: Array<MultimodalPart>): Promise<String> {
    // Copy buffer payloads on the calling thread, before going async. Nitro documents
    // ArrayBuffers received from JS as non-owning and only safe to read synchronously,
    // so reading them inside Promise.parallel (a worker thread) is not reliable.
    val contents = mutableListOf<Content>()
    val labels = mutableListOf<String>()
    for (part in parts) {
        when (part.type) {
            PartType.TEXT -> part.text?.let {
                contents.add(Content.Text(it))
                labels.add(it)
            }
            PartType.IMAGE -> part.imageBuffer?.let { buffer ->
                val byteBuffer = buffer.getBuffer(false)
                val bytes = ByteArray(byteBuffer.remaining())
                byteBuffer.get(bytes)
                contents.add(Content.ImageBytes(bytes))
                labels.add("[Image Buffer]")
            }
            PartType.AUDIO -> part.audioBuffer?.let { buffer ->
                val byteBuffer = buffer.getBuffer(false)
                val bytes = ByteArray(byteBuffer.remaining())
                byteBuffer.get(bytes)
                contents.add(Content.AudioBytes(bytes))
                labels.add("[Audio Buffer]")
            }
        }
    }
    // Same text as appending "<piece> " for every part and trimming the result.
    val userTextRepresentation = labels.joinToString(" ").trim()

    return Promise.parallel {
        ensureLoaded()
        history.add(Message(Role.USER, userTextRepresentation))

        val userMsg = LiteRTMessage.user(Contents.of(contents))
        val startTime = System.nanoTime()
        // Blocking inference; runs inside the Promise.parallel block.
        val responseMsg = conversation!!.sendMessage(message = userMsg)
        val elapsedMs = (System.nanoTime() - startTime) / 1_000_000.0

        val response = responseMsg.contents.contents
            .filterIsInstance<Content.Text>()
            .joinToString("") { it.text }

        history.add(Message(Role.MODEL, response))

        // Heuristic token counts (~4 chars/token); binary parts only contribute their placeholder.
        val promptTokens = userTextRepresentation.length / 4.0
        val completionTokens = response.length / 4.0
        lastStats = GenerationStats(
            promptTokens = promptTokens,
            completionTokens = completionTokens,
            totalTokens = promptTokens + completionTokens,
            timeToFirstToken = 0.0, // not measurable for a blocking call
            totalTime = elapsedMs,
            tokensPerSecond = if (elapsedMs > 0) completionTokens / (elapsedMs / 1000.0) else 0.0
        )
        response
    }
}
|
|
912
|
+
|
|
913
|
+
/**
 * Token counting entry point; this implementation ignores [text] and always returns -1.0.
 *
 * NOTE(review): -1.0 presumably signals "not supported" to callers — confirm against the spec.
 */
override fun countTokens(text: String): Double = -1.0
|
|
663
916
|
}
|
|
@@ -1,18 +1,35 @@
|
|
|
1
1
|
package dev.litert.litertlm
|
|
2
2
|
|
|
3
|
+
import android.os.Build
|
|
4
|
+
import android.util.Log
|
|
3
5
|
import com.facebook.react.TurboReactPackage
|
|
4
6
|
import com.facebook.react.bridge.NativeModule
|
|
5
7
|
import com.facebook.react.bridge.ReactApplicationContext
|
|
6
8
|
import com.facebook.react.module.model.ReactModuleInfo
|
|
7
9
|
import com.facebook.react.module.model.ReactModuleInfoProvider
|
|
8
|
-
import com.margelo.nitro.core.HybridObject
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
import com.margelo.nitro.dev.litert.litertlm.LiteRTLMOnLoad
|
|
12
13
|
|
|
13
14
|
class LiteRTLMPackage : TurboReactPackage() {
|
|
15
|
+
companion object {
    private const val TAG = "LiteRTLMPackage"

    /**
     * Returns true only when the first entry of [Build.SUPPORTED_64_BIT_ABIS] is
     * "arm64-v8a"; an empty list (no 64-bit ABI reported) yields false.
     */
    private fun isSupportedPrimaryAbi(): Boolean =
        Build.SUPPORTED_64_BIT_ABIS.firstOrNull() == "arm64-v8a"
}
|
|
14
23
|
init {
|
|
15
|
-
|
|
24
|
+
if (!isSupportedPrimaryAbi()) {
|
|
25
|
+
Log.w(TAG, "Skipping LiteRTLM native init on unsupported primary ABI: ${Build.SUPPORTED_64_BIT_ABIS.firstOrNull()}")
|
|
26
|
+
} else {
|
|
27
|
+
try {
|
|
28
|
+
LiteRTLMOnLoad.initializeNative()
|
|
29
|
+
} catch (e: UnsatisfiedLinkError) {
|
|
30
|
+
Log.e(TAG, "LiteRTLM native init failed; disabling LiteRTLM for this process.", e)
|
|
31
|
+
}
|
|
32
|
+
}
|
|
16
33
|
}
|
|
17
34
|
|
|
18
35
|
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
package com.margelo.nitro.core
|
|
2
|
+
|
|
3
|
+
import androidx.annotation.Keep
|
|
4
|
+
import com.facebook.proguard.annotations.DoNotStrip
|
|
5
|
+
|
|
6
|
+
/**
 * Minimal stand-in for a promise type, declared in package `com.margelo.nitro.core`
 * under the test source set.
 *
 * [parallel] runs its block synchronously on the calling thread and returns an already
 * settled promise. A promise settles at most once: the first [resolve] or [reject]
 * wins and later calls are ignored, so [result] and [error] are never both set.
 */
@Keep
@DoNotStrip
class Promise<T> {
    /** Value passed to the winning [resolve] call, or null if not resolved. */
    var result: T? = null
        private set

    /** Throwable passed to the winning [reject] call, or null if not rejected. */
    var error: Throwable? = null
        private set

    /** True once the promise has been resolved or rejected. */
    var isCompleted = false
        private set

    // Nothing visible in this class registers callbacks; the list is kept so the
    // settle paths stay unchanged if a registration method is added later.
    private val callbacks = mutableListOf<(T?, Throwable?) -> Unit>()

    /** Settle the promise with [value]; ignored if the promise is already settled. */
    fun resolve(value: T) {
        synchronized(this) {
            if (isCompleted) return
            result = value
            isCompleted = true
            callbacks.forEach { it(value, null) }
        }
    }

    /** Settle the promise with [exception]; ignored if the promise is already settled. */
    fun reject(exception: Throwable) {
        synchronized(this) {
            if (isCompleted) return
            error = exception
            isCompleted = true
            callbacks.forEach { it(null, exception) }
        }
    }

    companion object {
        /**
         * Run [block] immediately and return a promise resolved with its result, or
         * rejected with whatever [Throwable] it threw.
         */
        @JvmStatic
        fun <T> parallel(block: () -> T): Promise<T> {
            val promise = Promise<T>()
            try {
                val result = block()
                promise.resolve(result)
            } catch (e: Throwable) {
                promise.reject(e)
            }
            return promise
        }
    }
}
|