react-native-litert-lm 0.3.7 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +153 -135
- package/android/build.gradle +12 -0
- package/android/src/main/AndroidManifest.xml +8 -0
- package/android/src/main/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLM.kt +276 -62
- package/android/src/main/java/dev/litert/litertlm/LiteRTLMPackage.kt +19 -2
- package/android/src/test/java/com/margelo/nitro/core/Promise.kt +46 -0
- package/android/src/test/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMTest.kt +105 -0
- package/ios/HybridLiteRTLM.swift +1344 -0
- package/ios/Tests/HybridLiteRTLMTests.swift +113 -0
- package/lib/__mocks__/react-native-nitro-modules.d.ts +65 -0
- package/lib/__mocks__/react-native-nitro-modules.js +60 -0
- package/lib/__tests__/hooks.test.d.ts +1 -0
- package/lib/__tests__/hooks.test.js +124 -0
- package/lib/__tests__/memoryTracker.test.d.ts +1 -0
- package/lib/__tests__/memoryTracker.test.js +74 -0
- package/lib/__tests__/modelFactory.test.d.ts +1 -0
- package/lib/__tests__/modelFactory.test.js +68 -0
- package/lib/hooks.js +27 -3
- package/lib/index.d.ts +6 -2
- package/lib/index.js +8 -8
- package/lib/modelFactory.js +82 -63
- package/lib/specs/LiteRTLM.nitro.d.ts +87 -2
- package/nitrogen/generated/android/LiteRTLMOnLoad.cpp +2 -2
- package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.cpp +94 -9
- package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.hpp +5 -1
- package/nitrogen/generated/android/c++/JLLMConfig.hpp +40 -3
- package/nitrogen/generated/android/c++/JMultimodalPart.hpp +74 -0
- package/nitrogen/generated/android/c++/JPartType.hpp +61 -0
- package/nitrogen/generated/android/c++/JToolDefinition.hpp +65 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/GenerationStats.kt +23 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMSpec.kt +28 -2
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/LLMConfig.kt +46 -3
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MemoryUsage.kt +19 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/Message.kt +15 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MultimodalPart.kt +66 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/PartType.kt +24 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/ToolDefinition.kt +61 -0
- package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.cpp +57 -1
- package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.hpp +414 -3
- package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Umbrella.hpp +41 -3
- package/nitrogen/generated/ios/LiteRTLMAutolinking.mm +4 -6
- package/nitrogen/generated/ios/LiteRTLMAutolinking.swift +10 -0
- package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.hpp +240 -0
- package/nitrogen/generated/ios/swift/Backend.swift +44 -0
- package/nitrogen/generated/ios/swift/Func_void.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_double.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_std__string.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_std__string_bool.swift +46 -0
- package/nitrogen/generated/ios/swift/GenerationStats.swift +54 -0
- package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec.swift +71 -0
- package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec_cxx.swift +431 -0
- package/nitrogen/generated/ios/swift/LLMConfig.swift +203 -0
- package/nitrogen/generated/ios/swift/MemoryUsage.swift +44 -0
- package/nitrogen/generated/ios/swift/Message.swift +34 -0
- package/nitrogen/generated/ios/swift/MultimodalPart.swift +83 -0
- package/nitrogen/generated/ios/swift/PartType.swift +44 -0
- package/nitrogen/generated/ios/swift/Role.swift +44 -0
- package/nitrogen/generated/ios/swift/ToolDefinition.swift +39 -0
- package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.cpp +4 -0
- package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.hpp +9 -2
- package/nitrogen/generated/shared/c++/LLMConfig.hpp +22 -2
- package/nitrogen/generated/shared/c++/MultimodalPart.hpp +99 -0
- package/nitrogen/generated/shared/c++/PartType.hpp +80 -0
- package/nitrogen/generated/shared/c++/ToolDefinition.hpp +91 -0
- package/package.json +22 -11
- package/react-native-litert-lm.podspec +17 -19
- package/scripts/download-ios-frameworks.sh +17 -50
- package/scripts/framework-source.js +46 -0
- package/scripts/postinstall.js +40 -18
- package/src/__mocks__/react-native-nitro-modules.ts +58 -0
- package/src/__tests__/hooks.test.ts +153 -0
- package/src/__tests__/memoryTracker.test.ts +87 -0
- package/src/__tests__/modelFactory.test.ts +96 -0
- package/src/hooks.ts +29 -7
- package/src/index.ts +7 -10
- package/src/modelFactory.ts +104 -80
- package/src/specs/LiteRTLM.nitro.ts +106 -2
- package/cpp/HybridLiteRTLM.cpp +0 -939
- package/cpp/HybridLiteRTLM.hpp +0 -169
- package/cpp/IOSDownloadHelper.h +0 -24
- package/ios/IOSDownloadHelper.mm +0 -129
- package/scripts/build-ios-engine.sh +0 -302
- package/scripts/stubs/cxx_bridge_stubs.cc +0 -224
- package/scripts/stubs/gemma_model_constraint_provider.cc +0 -46
- package/scripts/stubs/llguidance_stubs.c +0 -101
- package/src/templates.ts +0 -105
|
@@ -27,6 +27,10 @@ import com.margelo.nitro.dev.litert.litertlm.Role
|
|
|
27
27
|
import com.margelo.nitro.core.Promise
|
|
28
28
|
import com.google.ai.edge.litertlm.Content
|
|
29
29
|
import com.google.ai.edge.litertlm.Contents
|
|
30
|
+
import com.google.ai.edge.litertlm.ExperimentalApi
|
|
31
|
+
import com.google.ai.edge.litertlm.ExperimentalFlags
|
|
32
|
+
import com.google.ai.edge.litertlm.OpenApiTool
|
|
33
|
+
import com.google.ai.edge.litertlm.ToolProvider
|
|
30
34
|
import java.util.concurrent.CountDownLatch
|
|
31
35
|
import java.util.concurrent.TimeUnit
|
|
32
36
|
import java.util.concurrent.atomic.AtomicBoolean
|
|
@@ -167,6 +171,8 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
167
171
|
private var topP: Double = 0.95
|
|
168
172
|
private var maxTokens: Int = 1024
|
|
169
173
|
private var systemPrompt: String? = null
|
|
174
|
+
private var tools: Array<ToolDefinition>? = null
|
|
175
|
+
private var enableSpeculativeDecoding: Boolean = false
|
|
170
176
|
|
|
171
177
|
override val memorySize: Long
|
|
172
178
|
get() = 1024L * 1024L * 1024L // ~1GB (models are large)
|
|
@@ -196,8 +202,13 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
196
202
|
cfg.topP?.let { topP = it }
|
|
197
203
|
cfg.maxTokens?.let { maxTokens = it.toInt() }
|
|
198
204
|
cfg.systemPrompt?.let { systemPrompt = it }
|
|
205
|
+
tools = cfg.tools
|
|
206
|
+
enableSpeculativeDecoding = cfg.enableSpeculativeDecoding ?: false
|
|
199
207
|
}
|
|
200
208
|
|
|
209
|
+
// Whether to run engine validation after loading
|
|
210
|
+
val shouldValidate = config?.validate?: false
|
|
211
|
+
|
|
201
212
|
try {
|
|
202
213
|
// Early GPU hardware check: probe for OpenCL library before
|
|
203
214
|
// spending time on engine creation. LiteRT-LM's GPU delegate
|
|
@@ -238,12 +249,12 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
238
249
|
else -> com.google.ai.edge.litertlm.Backend.CPU()
|
|
239
250
|
}
|
|
240
251
|
|
|
241
|
-
// Detect multimodal support
|
|
252
|
+
// Detect multimodal support. Check config.multimodal flag first, then fall back to filename sniffing.
|
|
242
253
|
// Only Gemma 3n bundles vision/audio executors; Gemma 4 E2B is text-only.
|
|
243
254
|
// Passing vision/audio backends to a text-only model causes
|
|
244
255
|
// vision_litert_compiled_model_executor init failures.
|
|
245
256
|
val modelFileName = modelPath.substringAfterLast("/").lowercase()
|
|
246
|
-
val isMultimodal = modelFileName.contains("3n") || modelFileName.contains("gemma3")
|
|
257
|
+
val isMultimodal = config?.multimodal ?: (modelFileName.contains("3n") || modelFileName.contains("gemma3"))
|
|
247
258
|
|
|
248
259
|
val lmVisionBackend = if (isMultimodal) com.google.ai.edge.litertlm.Backend.GPU() else null
|
|
249
260
|
val lmAudioBackend = if (isMultimodal) com.google.ai.edge.litertlm.Backend.CPU() else null
|
|
@@ -275,6 +286,11 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
275
286
|
|
|
276
287
|
if (isClosed) return@synchronized
|
|
277
288
|
|
|
289
|
+
if (enableSpeculativeDecoding) {
|
|
290
|
+
@OptIn(ExperimentalApi::class)
|
|
291
|
+
ExperimentalFlags.enableSpeculativeDecoding = true
|
|
292
|
+
}
|
|
293
|
+
|
|
278
294
|
// Initialize Engine
|
|
279
295
|
engine = Engine(engineConfig).also { it.initialize() }
|
|
280
296
|
Log.i(TAG, "Engine created and initialized successfully")
|
|
@@ -284,8 +300,17 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
284
300
|
Log.i(TAG, "Conversation created successfully")
|
|
285
301
|
|
|
286
302
|
// Validate the engine actually works with a quick test inference.
|
|
287
|
-
// GPU
|
|
288
|
-
|
|
303
|
+
// GPU/NPU backends can initialize without error but silently fail to
|
|
304
|
+
// produce tokens — enabling this catches those failures at load time.
|
|
305
|
+
// CPU is always reliable so validation is never run on it, even when
|
|
306
|
+
// the `validate` flag is set.
|
|
307
|
+
if (shouldValidate) {
|
|
308
|
+
if (backend == Backend.GPU || backend == Backend.NPU) {
|
|
309
|
+
validateEngine()
|
|
310
|
+
} else {
|
|
311
|
+
Log.i(TAG, "Validation skipped: CPU backend is always reliable")
|
|
312
|
+
}
|
|
313
|
+
}
|
|
289
314
|
|
|
290
315
|
} catch (e: Exception) {
|
|
291
316
|
Log.e(TAG, "Failed to load model: ${e.message}", e)
|
|
@@ -349,32 +374,48 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
349
374
|
// -------------------------------------------------------------------------
|
|
350
375
|
// sendMessageAsync - Streaming inference
|
|
351
376
|
// -------------------------------------------------------------------------
|
|
352
|
-
override fun sendMessageAsync(message: String, onToken: (String, Boolean) -> Unit) {
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
ensureLoaded()
|
|
377
|
+
override fun sendMessageAsync(message: String, onToken: (String, Boolean) -> Unit): Promise<Unit> {
|
|
378
|
+
return Promise.parallel {
|
|
379
|
+
val latch = CountDownLatch(1)
|
|
380
|
+
val errorRef = AtomicReference<Throwable?>(null)
|
|
357
381
|
|
|
358
|
-
|
|
359
|
-
history.add(Message(Role.USER, message))
|
|
360
|
-
Log.d(TAG, "sendMessageAsync: $message")
|
|
382
|
+
ensureLoaded()
|
|
361
383
|
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
onToken = onToken,
|
|
366
|
-
responseBuilder = fullResponseBuilder,
|
|
367
|
-
history = history,
|
|
368
|
-
userMessage = message,
|
|
369
|
-
onStatsReady = { stats -> lastStats = stats },
|
|
370
|
-
)
|
|
384
|
+
// Add user message to history
|
|
385
|
+
history.add(Message(Role.USER, message))
|
|
386
|
+
Log.d(TAG, "sendMessageAsync: $message")
|
|
371
387
|
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
388
|
+
val fullResponseBuilder = StringBuilder()
|
|
389
|
+
|
|
390
|
+
val listener = StreamingCallbackListener(
|
|
391
|
+
onToken = { token, done ->
|
|
392
|
+
onToken(token, done)
|
|
393
|
+
if (done) {
|
|
394
|
+
latch.countDown()
|
|
395
|
+
}
|
|
396
|
+
},
|
|
397
|
+
responseBuilder = fullResponseBuilder,
|
|
398
|
+
history = history,
|
|
399
|
+
userMessage = message,
|
|
400
|
+
onStatsReady = { stats -> lastStats = stats },
|
|
401
|
+
)
|
|
402
|
+
|
|
403
|
+
try {
|
|
404
|
+
val userMsg = LiteRTMessage.user(message)
|
|
405
|
+
conversation!!.sendMessageAsync(message = userMsg, callback = listener)
|
|
406
|
+
} catch (e: Exception) {
|
|
407
|
+
Log.e(TAG, "Failed to initiate async generation", e)
|
|
408
|
+
errorRef.set(e)
|
|
409
|
+
onToken("Error: ${e.message}", true)
|
|
410
|
+
latch.countDown()
|
|
411
|
+
}
|
|
412
|
+
|
|
413
|
+
// Wait for completion or error
|
|
414
|
+
latch.await()
|
|
415
|
+
val err = errorRef.get()
|
|
416
|
+
if (err != null) {
|
|
417
|
+
throw RuntimeException("Async inference failed: ${err.message}", err)
|
|
418
|
+
}
|
|
378
419
|
}
|
|
379
420
|
}
|
|
380
421
|
|
|
@@ -460,10 +501,87 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
460
501
|
}
|
|
461
502
|
}
|
|
462
503
|
|
|
504
|
+
override fun sendMessageWithImageAsync(message: String, imagePath: String, onToken: (String, Boolean) -> Unit): Promise<Unit> {
|
|
505
|
+
return Promise.parallel {
|
|
506
|
+
val latch = CountDownLatch(1)
|
|
507
|
+
val errorRef = AtomicReference<Throwable?>(null)
|
|
508
|
+
|
|
509
|
+
ensureLoaded()
|
|
510
|
+
|
|
511
|
+
Log.i(TAG, "sendMessageWithImageAsync: $message, path=$imagePath")
|
|
512
|
+
|
|
513
|
+
// Resize image to prevent OOM on high-resolution photos
|
|
514
|
+
val processedImagePath = resizeImageIfNeeded(imagePath)
|
|
515
|
+
|
|
516
|
+
val fullResponseBuilder = StringBuilder()
|
|
517
|
+
|
|
518
|
+
val listener = StreamingCallbackListener(
|
|
519
|
+
onToken = { token, done ->
|
|
520
|
+
onToken(token, done)
|
|
521
|
+
if (done) {
|
|
522
|
+
latch.countDown()
|
|
523
|
+
}
|
|
524
|
+
},
|
|
525
|
+
responseBuilder = fullResponseBuilder,
|
|
526
|
+
history = history,
|
|
527
|
+
userMessage = message,
|
|
528
|
+
onStatsReady = { stats -> lastStats = stats },
|
|
529
|
+
)
|
|
530
|
+
|
|
531
|
+
try {
|
|
532
|
+
val textContent = Content.Text(message)
|
|
533
|
+
val userMsg = LiteRTMessage.user(Contents.of(textContent, Content.ImageFile(processedImagePath)))
|
|
534
|
+
|
|
535
|
+
history.add(Message(Role.USER, "$message [Image]"))
|
|
536
|
+
|
|
537
|
+
conversation!!.sendMessageAsync(message = userMsg, callback = listener)
|
|
538
|
+
} catch (e: Exception) {
|
|
539
|
+
// Clean up temp resized image to prevent cache dir bloat
|
|
540
|
+
if (processedImagePath != imagePath) {
|
|
541
|
+
try {
|
|
542
|
+
java.io.File(processedImagePath).delete()
|
|
543
|
+
} catch (e: Exception) {
|
|
544
|
+
Log.w(TAG, "Failed to clean up temp image: ${e.message}")
|
|
545
|
+
}
|
|
546
|
+
}
|
|
547
|
+
|
|
548
|
+
Log.e(TAG, "Failed to initiate async multimodal generation", e)
|
|
549
|
+
errorRef.set(e)
|
|
550
|
+
onToken("Error: ${e.message}", true)
|
|
551
|
+
latch.countDown()
|
|
552
|
+
}
|
|
553
|
+
|
|
554
|
+
// Wait for completion or error
|
|
555
|
+
latch.await()
|
|
556
|
+
|
|
557
|
+
// Clean up temp resized image to prevent cache dir bloat
|
|
558
|
+
if (processedImagePath != imagePath) {
|
|
559
|
+
try {
|
|
560
|
+
java.io.File(processedImagePath).delete()
|
|
561
|
+
} catch (e: Exception) {
|
|
562
|
+
Log.w(TAG, "Failed to clean up temp image: ${e.message}")
|
|
563
|
+
}
|
|
564
|
+
}
|
|
565
|
+
|
|
566
|
+
val err = errorRef.get()
|
|
567
|
+
if (err != null) {
|
|
568
|
+
throw RuntimeException("Async multimodal inference failed: ${err.message}", err)
|
|
569
|
+
}
|
|
570
|
+
}
|
|
571
|
+
}
|
|
572
|
+
|
|
463
573
|
override fun downloadModel(url: String, fileName: String, onProgress: ((Double) -> Unit)?): Promise<String> {
|
|
464
574
|
return Promise.parallel {
|
|
465
575
|
Log.i(TAG, "downloadModel: $url -> $fileName")
|
|
466
576
|
|
|
577
|
+
if (!url.startsWith("https://", ignoreCase = true)) {
|
|
578
|
+
throw IllegalArgumentException("Invalid download URL: HTTPS is required for security.")
|
|
579
|
+
}
|
|
580
|
+
|
|
581
|
+
if (fileName.contains("..") || fileName.contains("/") || fileName.contains("\\")) {
|
|
582
|
+
throw IllegalArgumentException("Invalid filename: path traversal or directory separators are not allowed.")
|
|
583
|
+
}
|
|
584
|
+
|
|
467
585
|
val context = LiteRTLMInitProvider.applicationContext ?: throw RuntimeException("Context not available")
|
|
468
586
|
val modelsDir = java.io.File(context.filesDir, "models")
|
|
469
587
|
if (!modelsDir.exists()) {
|
|
@@ -545,6 +663,11 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
545
663
|
override fun deleteModel(fileName: String): Promise<Unit> {
|
|
546
664
|
return Promise.parallel {
|
|
547
665
|
Log.i(TAG, "deleteModel: $fileName")
|
|
666
|
+
|
|
667
|
+
if (fileName.contains("..") || fileName.contains("/") || fileName.contains("\\")) {
|
|
668
|
+
throw IllegalArgumentException("Invalid filename: path traversal or directory separators are not allowed.")
|
|
669
|
+
}
|
|
670
|
+
|
|
548
671
|
val context = LiteRTLMInitProvider.applicationContext ?: throw RuntimeException("Context not available")
|
|
549
672
|
val modelsDir = java.io.File(context.filesDir, "models")
|
|
550
673
|
val modelFile = java.io.File(modelsDir, fileName)
|
|
@@ -569,6 +692,54 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
569
692
|
}
|
|
570
693
|
}
|
|
571
694
|
|
|
695
|
+
override fun sendMessageWithAudioAsync(message: String, audioPath: String, onToken: (String, Boolean) -> Unit): Promise<Unit> {
|
|
696
|
+
return Promise.parallel {
|
|
697
|
+
val latch = CountDownLatch(1)
|
|
698
|
+
val errorRef = AtomicReference<Throwable?>(null)
|
|
699
|
+
|
|
700
|
+
ensureLoaded()
|
|
701
|
+
|
|
702
|
+
Log.i(TAG, "sendMessageWithAudioAsync: $message, path=$audioPath")
|
|
703
|
+
|
|
704
|
+
val fullResponseBuilder = StringBuilder()
|
|
705
|
+
|
|
706
|
+
val listener = StreamingCallbackListener(
|
|
707
|
+
onToken = { token, done ->
|
|
708
|
+
onToken(token, done)
|
|
709
|
+
if (done) {
|
|
710
|
+
latch.countDown()
|
|
711
|
+
}
|
|
712
|
+
},
|
|
713
|
+
responseBuilder = fullResponseBuilder,
|
|
714
|
+
history = history,
|
|
715
|
+
userMessage = message,
|
|
716
|
+
onStatsReady = { stats -> lastStats = stats },
|
|
717
|
+
)
|
|
718
|
+
|
|
719
|
+
try {
|
|
720
|
+
val userMsg = LiteRTMessage.user(Contents.of(
|
|
721
|
+
Content.Text(message),
|
|
722
|
+
Content.AudioFile(audioPath)
|
|
723
|
+
))
|
|
724
|
+
|
|
725
|
+
history.add(Message(Role.USER, "$message [Audio]"))
|
|
726
|
+
|
|
727
|
+
conversation!!.sendMessageAsync(message = userMsg, callback = listener)
|
|
728
|
+
} catch (e: Exception) {
|
|
729
|
+
Log.e(TAG, "Failed to initiate async audio generation", e)
|
|
730
|
+
errorRef.set(e)
|
|
731
|
+
onToken("Error: ${e.message}", true)
|
|
732
|
+
latch.countDown()
|
|
733
|
+
}
|
|
734
|
+
|
|
735
|
+
latch.await()
|
|
736
|
+
val err = errorRef.get()
|
|
737
|
+
if (err != null) {
|
|
738
|
+
throw RuntimeException("Async audio inference failed: ${err.message}", err)
|
|
739
|
+
}
|
|
740
|
+
}
|
|
741
|
+
}
|
|
742
|
+
|
|
572
743
|
override fun sendMessageWithAudio(message: String, audioPath: String): Promise<String> {
|
|
573
744
|
return Promise.parallel {
|
|
574
745
|
ensureLoaded()
|
|
@@ -676,19 +847,9 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
676
847
|
|
|
677
848
|
private fun cleanupInternal() {
|
|
678
849
|
try {
|
|
850
|
+
conversation?.close()
|
|
679
851
|
conversation = null
|
|
680
|
-
//
|
|
681
|
-
// Assuming Engine implements AutoCloseable or has close()
|
|
682
|
-
if (engine is AutoCloseable) {
|
|
683
|
-
(engine as AutoCloseable).close()
|
|
684
|
-
} else {
|
|
685
|
-
// Try reflection or just null it if no close method
|
|
686
|
-
try {
|
|
687
|
-
engine?.javaClass?.getMethod("close")?.invoke(engine)
|
|
688
|
-
} catch (e: Exception) {
|
|
689
|
-
// Method not found, rely on GC
|
|
690
|
-
}
|
|
691
|
-
}
|
|
852
|
+
engine?.close() // Direct call
|
|
692
853
|
engine = null
|
|
693
854
|
} catch (e: Exception) {
|
|
694
855
|
Log.e(TAG, "Error closing resources", e)
|
|
@@ -706,41 +867,37 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
706
867
|
// v0.10.2 enforces single-session: close existing conversation first
|
|
707
868
|
conversation?.let { oldConv ->
|
|
708
869
|
try {
|
|
709
|
-
|
|
710
|
-
oldConv.close()
|
|
711
|
-
} else {
|
|
712
|
-
oldConv.javaClass.getMethod("close").invoke(oldConv)
|
|
713
|
-
}
|
|
870
|
+
oldConv.close()
|
|
714
871
|
} catch (e: Exception) {
|
|
715
872
|
Log.w(TAG, "Failed to close old conversation: ${e.message}")
|
|
716
873
|
}
|
|
717
874
|
conversation = null
|
|
718
875
|
}
|
|
876
|
+
// Map tools
|
|
877
|
+
val lmTools: List<ToolProvider>? = tools?.map { tool ->
|
|
878
|
+
val apiTool = object : OpenApiTool {
|
|
879
|
+
override fun getToolDescriptionJsonString(): String {
|
|
880
|
+
return tool.parametersJson
|
|
881
|
+
}
|
|
882
|
+
override fun execute(paramsJsonString: String): String {
|
|
883
|
+
return "{}"
|
|
884
|
+
}
|
|
885
|
+
}
|
|
886
|
+
(apiTool as Any) as ToolProvider
|
|
887
|
+
}
|
|
888
|
+
|
|
719
889
|
// Create conversation with explicit SamplerConfig (required by Gallery pattern).
|
|
720
890
|
// GPU backend may fail silently without proper sampler params.
|
|
721
891
|
val convConfig = ConversationConfig(
|
|
722
892
|
samplerConfig = SamplerConfig(
|
|
723
893
|
topK = topK,
|
|
724
|
-
topP = topP,
|
|
725
|
-
temperature = temperature,
|
|
726
|
-
)
|
|
894
|
+
topP = topP.toDouble(),
|
|
895
|
+
temperature = temperature.toDouble(),
|
|
896
|
+
),
|
|
897
|
+
systemInstruction = systemPrompt?.let { Contents.of(Content.Text(it)) },
|
|
898
|
+
tools = lmTools ?: emptyList()
|
|
727
899
|
)
|
|
728
900
|
conversation = engine!!.createConversation(convConfig)
|
|
729
|
-
// Apply system prompt/instruction if set
|
|
730
|
-
systemPrompt?.let { prompt ->
|
|
731
|
-
if (prompt.isNotEmpty()) {
|
|
732
|
-
try {
|
|
733
|
-
// Send system instruction as the first turn to prime the conversation.
|
|
734
|
-
// LiteRT-LM's Conversation API handles chat template formatting,
|
|
735
|
-
// including Gemma's <start_of_turn>system block.
|
|
736
|
-
val systemMsg = LiteRTMessage.system(prompt)
|
|
737
|
-
conversation!!.sendMessage(message = systemMsg)
|
|
738
|
-
Log.i(TAG, "System prompt applied (${prompt.length} chars)")
|
|
739
|
-
} catch (e: Exception) {
|
|
740
|
-
Log.w(TAG, "Failed to apply system prompt: ${e.message}")
|
|
741
|
-
}
|
|
742
|
-
}
|
|
743
|
-
}
|
|
744
901
|
}
|
|
745
902
|
|
|
746
903
|
/**
|
|
@@ -815,5 +972,62 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
815
972
|
createNewConversation()
|
|
816
973
|
}
|
|
817
974
|
|
|
975
|
+
override fun sendMultimodalMessage(parts: Array<MultimodalPart>): Promise<String> {
|
|
976
|
+
return Promise.parallel {
|
|
977
|
+
ensureLoaded()
|
|
978
|
+
val contents = mutableListOf<Content>()
|
|
979
|
+
var userTextRepresentation = ""
|
|
980
|
+
for (part in parts) {
|
|
981
|
+
when (part.type) {
|
|
982
|
+
PartType.TEXT -> part.text?.let {
|
|
983
|
+
contents.add(Content.Text(it))
|
|
984
|
+
userTextRepresentation += "$it "
|
|
985
|
+
}
|
|
986
|
+
PartType.IMAGE -> part.imageBuffer?.let { buffer ->
|
|
987
|
+
val byteBuffer = buffer.getBuffer(false)
|
|
988
|
+
val bytes = ByteArray(byteBuffer.remaining())
|
|
989
|
+
byteBuffer.get(bytes)
|
|
990
|
+
contents.add(Content.ImageBytes(bytes))
|
|
991
|
+
userTextRepresentation += "[Image Buffer] "
|
|
992
|
+
}
|
|
993
|
+
PartType.AUDIO -> part.audioBuffer?.let { buffer ->
|
|
994
|
+
val byteBuffer = buffer.getBuffer(false)
|
|
995
|
+
val bytes = ByteArray(byteBuffer.remaining())
|
|
996
|
+
byteBuffer.get(bytes)
|
|
997
|
+
contents.add(Content.AudioBytes(bytes))
|
|
998
|
+
userTextRepresentation += "[Audio Buffer] "
|
|
999
|
+
}
|
|
1000
|
+
}
|
|
1001
|
+
}
|
|
1002
|
+
userTextRepresentation = userTextRepresentation.trim()
|
|
1003
|
+
history.add(Message(Role.USER, userTextRepresentation))
|
|
1004
|
+
|
|
1005
|
+
val userMsg = LiteRTMessage.user(Contents.of(contents))
|
|
1006
|
+
val startTime = System.nanoTime()
|
|
1007
|
+
val responseMsg = conversation!!.sendMessage(message = userMsg)
|
|
1008
|
+
val elapsedMs = (System.nanoTime() - startTime) / 1_000_000.0
|
|
1009
|
+
|
|
1010
|
+
val response = responseMsg.contents.contents
|
|
1011
|
+
.filterIsInstance<Content.Text>()
|
|
1012
|
+
.joinToString("") { it.text }
|
|
1013
|
+
|
|
1014
|
+
history.add(Message(Role.MODEL, response))
|
|
1015
|
+
|
|
1016
|
+
val promptTokens = userTextRepresentation.length / 4.0
|
|
1017
|
+
val completionTokens = response.length / 4.0
|
|
1018
|
+
lastStats = GenerationStats(
|
|
1019
|
+
promptTokens = promptTokens,
|
|
1020
|
+
completionTokens = completionTokens,
|
|
1021
|
+
totalTokens = promptTokens + completionTokens,
|
|
1022
|
+
timeToFirstToken = 0.0,
|
|
1023
|
+
totalTime = elapsedMs,
|
|
1024
|
+
tokensPerSecond = if (elapsedMs > 0) completionTokens / (elapsedMs / 1000.0) else 0.0
|
|
1025
|
+
)
|
|
1026
|
+
response
|
|
1027
|
+
}
|
|
1028
|
+
}
|
|
818
1029
|
|
|
1030
|
+
override fun countTokens(text: String): Double {
|
|
1031
|
+
return -1.0
|
|
1032
|
+
}
|
|
819
1033
|
}
|
|
@@ -1,18 +1,35 @@
|
|
|
1
1
|
package dev.litert.litertlm
|
|
2
2
|
|
|
3
|
+
import android.os.Build
|
|
4
|
+
import android.util.Log
|
|
3
5
|
import com.facebook.react.TurboReactPackage
|
|
4
6
|
import com.facebook.react.bridge.NativeModule
|
|
5
7
|
import com.facebook.react.bridge.ReactApplicationContext
|
|
6
8
|
import com.facebook.react.module.model.ReactModuleInfo
|
|
7
9
|
import com.facebook.react.module.model.ReactModuleInfoProvider
|
|
8
|
-
import com.margelo.nitro.core.HybridObject
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
import com.margelo.nitro.dev.litert.litertlm.LiteRTLMOnLoad
|
|
12
13
|
|
|
13
14
|
class LiteRTLMPackage : TurboReactPackage() {
|
|
15
|
+
companion object {
|
|
16
|
+
private const val TAG = "LiteRTLMPackage"
|
|
17
|
+
|
|
18
|
+
private fun isSupportedPrimaryAbi(): Boolean {
|
|
19
|
+
val primaryAbi = Build.SUPPORTED_64_BIT_ABIS.firstOrNull() ?: return false
|
|
20
|
+
return primaryAbi == "arm64-v8a"
|
|
21
|
+
}
|
|
22
|
+
}
|
|
14
23
|
init {
|
|
15
|
-
|
|
24
|
+
if (!isSupportedPrimaryAbi()) {
|
|
25
|
+
Log.w(TAG, "Skipping LiteRTLM native init on unsupported primary ABI: ${Build.SUPPORTED_64_BIT_ABIS.firstOrNull()}")
|
|
26
|
+
} else {
|
|
27
|
+
try {
|
|
28
|
+
LiteRTLMOnLoad.initializeNative()
|
|
29
|
+
} catch (e: UnsatisfiedLinkError) {
|
|
30
|
+
Log.e(TAG, "LiteRTLM native init failed; disabling LiteRTLM for this process.", e)
|
|
31
|
+
}
|
|
32
|
+
}
|
|
16
33
|
}
|
|
17
34
|
|
|
18
35
|
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
package com.margelo.nitro.core
|
|
2
|
+
|
|
3
|
+
import androidx.annotation.Keep
|
|
4
|
+
import com.facebook.proguard.annotations.DoNotStrip
|
|
5
|
+
|
|
6
|
+
@Keep
|
|
7
|
+
@DoNotStrip
|
|
8
|
+
class Promise<T> {
|
|
9
|
+
companion object {
|
|
10
|
+
@JvmStatic
|
|
11
|
+
fun <T> parallel(block: () -> T): Promise<T> {
|
|
12
|
+
val promise = Promise<T>()
|
|
13
|
+
try {
|
|
14
|
+
val result = block()
|
|
15
|
+
promise.resolve(result)
|
|
16
|
+
} catch (e: Throwable) {
|
|
17
|
+
promise.reject(e)
|
|
18
|
+
}
|
|
19
|
+
return promise
|
|
20
|
+
}
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
var result: T? = null
|
|
24
|
+
private set
|
|
25
|
+
var error: Throwable? = null
|
|
26
|
+
private set
|
|
27
|
+
var isCompleted = false
|
|
28
|
+
private set
|
|
29
|
+
private val callbacks = mutableListOf<(T?, Throwable?) -> Unit>()
|
|
30
|
+
|
|
31
|
+
fun resolve(value: T) {
|
|
32
|
+
synchronized(this) {
|
|
33
|
+
result = value
|
|
34
|
+
isCompleted = true
|
|
35
|
+
callbacks.forEach { it(value, null) }
|
|
36
|
+
}
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
fun reject(exception: Throwable) {
|
|
40
|
+
synchronized(this) {
|
|
41
|
+
error = exception
|
|
42
|
+
isCompleted = true
|
|
43
|
+
callbacks.forEach { it(null, exception) }
|
|
44
|
+
}
|
|
45
|
+
}
|
|
46
|
+
}
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
package com.margelo.nitro.dev.litert.litertlm
|
|
2
|
+
|
|
3
|
+
import org.junit.Assert.*
|
|
4
|
+
import org.junit.Before
|
|
5
|
+
import org.junit.After
|
|
6
|
+
import org.junit.Test
|
|
7
|
+
import org.junit.runner.RunWith
|
|
8
|
+
import org.robolectric.RobolectricTestRunner
|
|
9
|
+
import org.robolectric.RuntimeEnvironment
|
|
10
|
+
import dev.litert.litertlm.LiteRTLMInitProvider
|
|
11
|
+
import java.lang.IllegalArgumentException
|
|
12
|
+
|
|
13
|
+
@RunWith(RobolectricTestRunner::class)
|
|
14
|
+
class HybridLiteRTLMTest {
|
|
15
|
+
private lateinit var bridge: HybridLiteRTLM
|
|
16
|
+
|
|
17
|
+
@Before
|
|
18
|
+
fun setUp() {
|
|
19
|
+
// Initialize the static applicationContext inside LiteRTLMInitProvider via reflection
|
|
20
|
+
try {
|
|
21
|
+
val field = LiteRTLMInitProvider::class.java.getDeclaredField("applicationContext")
|
|
22
|
+
field.isAccessible = true
|
|
23
|
+
field.set(null, RuntimeEnvironment.getApplication())
|
|
24
|
+
} catch (e: Exception) {
|
|
25
|
+
e.printStackTrace()
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
bridge = HybridLiteRTLM()
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
@After
|
|
32
|
+
fun tearDown() {
|
|
33
|
+
bridge.close()
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
@Test
|
|
37
|
+
fun testAndroidPathTraversalPrevention() {
|
|
38
|
+
val traversals = arrayOf("../secret", "/etc/hosts", "nested\\..\\file", "..", "../", "..\\")
|
|
39
|
+
for (traversal in traversals) {
|
|
40
|
+
val promise = bridge.deleteModel(traversal)
|
|
41
|
+
assertNotNull("Promise should not be null", promise)
|
|
42
|
+
assertTrue("Promise should be completed", promise.isCompleted)
|
|
43
|
+
assertNotNull("Promise should have rejected with an error for filename: $traversal", promise.error)
|
|
44
|
+
val error = promise.error!!
|
|
45
|
+
val errMsg = error.message ?: error.cause?.message ?: ""
|
|
46
|
+
assertTrue("Expected message to contain traversal warning, got: $errMsg",
|
|
47
|
+
errMsg.contains("path traversal or directory separators are not allowed"))
|
|
48
|
+
}
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
@Test
|
|
52
|
+
fun testAndroidHTTPSDownloadEnforcement() {
|
|
53
|
+
val promise = bridge.downloadModel("http://insecure.site/model.bin", "model.bin", null)
|
|
54
|
+
assertNotNull("Promise should not be null", promise)
|
|
55
|
+
assertTrue("Promise should be completed", promise.isCompleted)
|
|
56
|
+
assertNotNull("Promise should have rejected with an error", promise.error)
|
|
57
|
+
val error = promise.error!!
|
|
58
|
+
val errMsg = error.message ?: error.cause?.message ?: ""
|
|
59
|
+
assertTrue("Expected message to contain HTTPS warning, got: $errMsg",
|
|
60
|
+
errMsg.contains("HTTPS is required for security"))
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
@Test
|
|
64
|
+
fun testAndroidMemoryTelemetry() {
|
|
65
|
+
val mem = bridge.getMemoryUsage()
|
|
66
|
+
assertNotNull(mem)
|
|
67
|
+
assertTrue(mem.nativeHeapBytes >= 0.0)
|
|
68
|
+
assertTrue(mem.residentBytes >= 0.0)
|
|
69
|
+
assertTrue(mem.availableMemoryBytes >= 0.0)
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
@Test
|
|
73
|
+
fun testSendMessageWithImageAsyncRejectsWithoutModel() {
|
|
74
|
+
val promise = bridge.sendMessageWithImageAsync("hello", "/tmp/image.jpg") { _, _ -> }
|
|
75
|
+
assertNotNull("Promise should not be null", promise)
|
|
76
|
+
assertTrue("Promise should be completed", promise.isCompleted)
|
|
77
|
+
assertNotNull("Promise should have rejected without model", promise.error)
|
|
78
|
+
val errMsg = promise.error!!.message ?: promise.error!!.cause?.message ?: ""
|
|
79
|
+
assertTrue("Expected no-model error, got: $errMsg",
|
|
80
|
+
errMsg.contains("No model loaded"))
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
@Test
|
|
84
|
+
fun testSendMessageWithAudioAsyncRejectsWithoutModel() {
|
|
85
|
+
val promise = bridge.sendMessageWithAudioAsync("hello", "/tmp/audio.wav") { _, _ -> }
|
|
86
|
+
assertNotNull("Promise should not be null", promise)
|
|
87
|
+
assertTrue("Promise should be completed", promise.isCompleted)
|
|
88
|
+
assertNotNull("Promise should have rejected without model", promise.error)
|
|
89
|
+
val errMsg = promise.error!!.message ?: promise.error!!.cause?.message ?: ""
|
|
90
|
+
assertTrue("Expected no-model error, got: $errMsg",
|
|
91
|
+
errMsg.contains("No model loaded"))
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
@Test
|
|
95
|
+
fun testAndroidInitialStats() {
|
|
96
|
+
val stats = bridge.getStats()
|
|
97
|
+
assertNotNull(stats)
|
|
98
|
+
assertEquals(0.0, stats.promptTokens, 0.0)
|
|
99
|
+
assertEquals(0.0, stats.completionTokens, 0.0)
|
|
100
|
+
assertEquals(0.0, stats.totalTokens, 0.0)
|
|
101
|
+
assertEquals(0.0, stats.timeToFirstToken, 0.0)
|
|
102
|
+
assertEquals(0.0, stats.totalTime, 0.0)
|
|
103
|
+
assertEquals(0.0, stats.tokensPerSecond, 0.0)
|
|
104
|
+
}
|
|
105
|
+
}
|