react-native-litert-lm 0.3.7 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +153 -135
- package/android/build.gradle +12 -0
- package/android/src/main/AndroidManifest.xml +5 -0
- package/android/src/main/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLM.kt +159 -62
- package/android/src/main/java/dev/litert/litertlm/LiteRTLMPackage.kt +19 -2
- package/android/src/test/java/com/margelo/nitro/core/Promise.kt +46 -0
- package/android/src/test/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMTest.kt +83 -0
- package/ios/HybridLiteRTLM.swift +1058 -0
- package/ios/Tests/HybridLiteRTLMTests.swift +67 -0
- package/lib/__mocks__/react-native-nitro-modules.d.ts +61 -0
- package/lib/__mocks__/react-native-nitro-modules.js +50 -0
- package/lib/__tests__/hooks.test.d.ts +1 -0
- package/lib/__tests__/hooks.test.js +124 -0
- package/lib/__tests__/memoryTracker.test.d.ts +1 -0
- package/lib/__tests__/memoryTracker.test.js +74 -0
- package/lib/__tests__/modelFactory.test.d.ts +1 -0
- package/lib/__tests__/modelFactory.test.js +52 -0
- package/lib/hooks.js +1 -1
- package/lib/index.d.ts +0 -2
- package/lib/index.js +1 -5
- package/lib/modelFactory.js +62 -63
- package/lib/specs/LiteRTLM.nitro.d.ts +71 -2
- package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.cpp +62 -7
- package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.hpp +3 -1
- package/nitrogen/generated/android/c++/JLLMConfig.hpp +40 -3
- package/nitrogen/generated/android/c++/JMultimodalPart.hpp +74 -0
- package/nitrogen/generated/android/c++/JPartType.hpp +61 -0
- package/nitrogen/generated/android/c++/JToolDefinition.hpp +65 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/GenerationStats.kt +23 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMSpec.kt +10 -2
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/LLMConfig.kt +46 -3
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MemoryUsage.kt +19 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/Message.kt +15 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MultimodalPart.kt +66 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/PartType.kt +24 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/ToolDefinition.kt +61 -0
- package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.cpp +57 -1
- package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.hpp +414 -3
- package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Umbrella.hpp +41 -3
- package/nitrogen/generated/ios/LiteRTLMAutolinking.mm +4 -6
- package/nitrogen/generated/ios/LiteRTLMAutolinking.swift +10 -0
- package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.hpp +224 -0
- package/nitrogen/generated/ios/swift/Backend.swift +44 -0
- package/nitrogen/generated/ios/swift/Func_void.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_double.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_std__string.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_std__string_bool.swift +46 -0
- package/nitrogen/generated/ios/swift/GenerationStats.swift +54 -0
- package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec.swift +69 -0
- package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec_cxx.swift +383 -0
- package/nitrogen/generated/ios/swift/LLMConfig.swift +203 -0
- package/nitrogen/generated/ios/swift/MemoryUsage.swift +44 -0
- package/nitrogen/generated/ios/swift/Message.swift +34 -0
- package/nitrogen/generated/ios/swift/MultimodalPart.swift +83 -0
- package/nitrogen/generated/ios/swift/PartType.swift +44 -0
- package/nitrogen/generated/ios/swift/Role.swift +44 -0
- package/nitrogen/generated/ios/swift/ToolDefinition.swift +39 -0
- package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.cpp +2 -0
- package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.hpp +7 -2
- package/nitrogen/generated/shared/c++/LLMConfig.hpp +22 -2
- package/nitrogen/generated/shared/c++/MultimodalPart.hpp +99 -0
- package/nitrogen/generated/shared/c++/PartType.hpp +80 -0
- package/nitrogen/generated/shared/c++/ToolDefinition.hpp +91 -0
- package/package.json +16 -8
- package/react-native-litert-lm.podspec +15 -19
- package/scripts/download-ios-frameworks.sh +14 -48
- package/scripts/postinstall.js +1 -2
- package/src/__mocks__/react-native-nitro-modules.ts +48 -0
- package/src/__tests__/hooks.test.ts +153 -0
- package/src/__tests__/memoryTracker.test.ts +87 -0
- package/src/__tests__/modelFactory.test.ts +68 -0
- package/src/hooks.ts +1 -1
- package/src/index.ts +0 -7
- package/src/modelFactory.ts +82 -80
- package/src/specs/LiteRTLM.nitro.ts +80 -2
- package/cpp/HybridLiteRTLM.cpp +0 -939
- package/cpp/HybridLiteRTLM.hpp +0 -169
- package/cpp/IOSDownloadHelper.h +0 -24
- package/ios/IOSDownloadHelper.mm +0 -129
- package/scripts/build-ios-engine.sh +0 -302
- package/scripts/stubs/cxx_bridge_stubs.cc +0 -224
- package/scripts/stubs/gemma_model_constraint_provider.cc +0 -46
- package/scripts/stubs/llguidance_stubs.c +0 -101
- package/src/templates.ts +0 -105
|
@@ -27,6 +27,10 @@ import com.margelo.nitro.dev.litert.litertlm.Role
|
|
|
27
27
|
import com.margelo.nitro.core.Promise
|
|
28
28
|
import com.google.ai.edge.litertlm.Content
|
|
29
29
|
import com.google.ai.edge.litertlm.Contents
|
|
30
|
+
import com.google.ai.edge.litertlm.ExperimentalApi
|
|
31
|
+
import com.google.ai.edge.litertlm.ExperimentalFlags
|
|
32
|
+
import com.google.ai.edge.litertlm.OpenApiTool
|
|
33
|
+
import com.google.ai.edge.litertlm.ToolProvider
|
|
30
34
|
import java.util.concurrent.CountDownLatch
|
|
31
35
|
import java.util.concurrent.TimeUnit
|
|
32
36
|
import java.util.concurrent.atomic.AtomicBoolean
|
|
@@ -167,6 +171,8 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
167
171
|
private var topP: Double = 0.95
|
|
168
172
|
private var maxTokens: Int = 1024
|
|
169
173
|
private var systemPrompt: String? = null
|
|
174
|
+
private var tools: Array<ToolDefinition>? = null
|
|
175
|
+
private var enableSpeculativeDecoding: Boolean = false
|
|
170
176
|
|
|
171
177
|
override val memorySize: Long
|
|
172
178
|
get() = 1024L * 1024L * 1024L // ~1GB (models are large)
|
|
@@ -196,8 +202,13 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
196
202
|
cfg.topP?.let { topP = it }
|
|
197
203
|
cfg.maxTokens?.let { maxTokens = it.toInt() }
|
|
198
204
|
cfg.systemPrompt?.let { systemPrompt = it }
|
|
205
|
+
tools = cfg.tools
|
|
206
|
+
enableSpeculativeDecoding = cfg.enableSpeculativeDecoding ?: false
|
|
199
207
|
}
|
|
200
208
|
|
|
209
|
+
// Whether to run engine validation after loading
|
|
210
|
+
val shouldValidate = config?.validate?: false
|
|
211
|
+
|
|
201
212
|
try {
|
|
202
213
|
// Early GPU hardware check: probe for OpenCL library before
|
|
203
214
|
// spending time on engine creation. LiteRT-LM's GPU delegate
|
|
@@ -238,12 +249,12 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
238
249
|
else -> com.google.ai.edge.litertlm.Backend.CPU()
|
|
239
250
|
}
|
|
240
251
|
|
|
241
|
-
// Detect multimodal support
|
|
252
|
+
// Detect multimodal support. Check config.multimodal flag first, then fall back to filename sniffing.
|
|
242
253
|
// Only Gemma 3n bundles vision/audio executors; Gemma 4 E2B is text-only.
|
|
243
254
|
// Passing vision/audio backends to a text-only model causes
|
|
244
255
|
// vision_litert_compiled_model_executor init failures.
|
|
245
256
|
val modelFileName = modelPath.substringAfterLast("/").lowercase()
|
|
246
|
-
val isMultimodal = modelFileName.contains("3n") || modelFileName.contains("gemma3")
|
|
257
|
+
val isMultimodal = config?.multimodal ?: (modelFileName.contains("3n") || modelFileName.contains("gemma3"))
|
|
247
258
|
|
|
248
259
|
val lmVisionBackend = if (isMultimodal) com.google.ai.edge.litertlm.Backend.GPU() else null
|
|
249
260
|
val lmAudioBackend = if (isMultimodal) com.google.ai.edge.litertlm.Backend.CPU() else null
|
|
@@ -275,6 +286,11 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
275
286
|
|
|
276
287
|
if (isClosed) return@synchronized
|
|
277
288
|
|
|
289
|
+
if (enableSpeculativeDecoding) {
|
|
290
|
+
@OptIn(ExperimentalApi::class)
|
|
291
|
+
ExperimentalFlags.enableSpeculativeDecoding = true
|
|
292
|
+
}
|
|
293
|
+
|
|
278
294
|
// Initialize Engine
|
|
279
295
|
engine = Engine(engineConfig).also { it.initialize() }
|
|
280
296
|
Log.i(TAG, "Engine created and initialized successfully")
|
|
@@ -284,8 +300,17 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
284
300
|
Log.i(TAG, "Conversation created successfully")
|
|
285
301
|
|
|
286
302
|
// Validate the engine actually works with a quick test inference.
|
|
287
|
-
// GPU
|
|
288
|
-
|
|
303
|
+
// GPU/NPU backends can initialize without error but silently fail to
|
|
304
|
+
// produce tokens — enabling this catches those failures at load time.
|
|
305
|
+
// CPU is always reliable so validation is never run on it, even when
|
|
306
|
+
// the `validate` flag is set.
|
|
307
|
+
if (shouldValidate) {
|
|
308
|
+
if (backend == Backend.GPU || backend == Backend.NPU) {
|
|
309
|
+
validateEngine()
|
|
310
|
+
} else {
|
|
311
|
+
Log.i(TAG, "Validation skipped: CPU backend is always reliable")
|
|
312
|
+
}
|
|
313
|
+
}
|
|
289
314
|
|
|
290
315
|
} catch (e: Exception) {
|
|
291
316
|
Log.e(TAG, "Failed to load model: ${e.message}", e)
|
|
@@ -349,32 +374,48 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
349
374
|
// -------------------------------------------------------------------------
|
|
350
375
|
// sendMessageAsync - Streaming inference
|
|
351
376
|
// -------------------------------------------------------------------------
|
|
352
|
-
override fun sendMessageAsync(message: String, onToken: (String, Boolean) -> Unit) {
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
ensureLoaded()
|
|
377
|
+
/**
 * Streams a model response for [message], forwarding every chunk to [onToken].
 *
 * The body runs inside [Promise.parallel] and blocks on a latch until the streaming
 * listener delivers its final chunk (`done == true`) or until starting the generation fails.
 *
 * @param message user text; it is appended to the history before generation starts.
 * @param onToken invoked with `(chunk, done)` for each streamed chunk. When generation cannot
 *   be started it is invoked once with `"Error: <reason>"` and `done = true`.
 * @return a promise that completes after the final chunk was delivered, and is rejected with a
 *   [RuntimeException] ("Async inference failed: ...") when generation could not be started.
 */
override fun sendMessageAsync(message: String, onToken: (String, Boolean) -> Unit): Promise<Unit> {
    return Promise.parallel {
        val latch = CountDownLatch(1)
        val errorRef = AtomicReference<Throwable?>(null)

        ensureLoaded()

        // Add user message to history
        history.add(Message(Role.USER, message))
        Log.d(TAG, "sendMessageAsync: $message")

        val fullResponseBuilder = StringBuilder()

        val listener = StreamingCallbackListener(
            onToken = { token, done ->
                try {
                    onToken(token, done)
                } finally {
                    // Release the waiter even if the caller's callback throws on the final
                    // chunk; otherwise latch.await() below would block forever.
                    if (done) {
                        latch.countDown()
                    }
                }
            },
            responseBuilder = fullResponseBuilder,
            history = history,
            userMessage = message,
            onStatsReady = { stats -> lastStats = stats },
        )

        try {
            // A missing conversation is reported with a readable reason instead of a
            // message-less NullPointerException (which surfaced as "Error: null").
            val activeConversation = checkNotNull(conversation) { "Conversation is not initialized" }
            val userMsg = LiteRTMessage.user(message)
            activeConversation.sendMessageAsync(message = userMsg, callback = listener)
        } catch (e: Exception) {
            Log.e(TAG, "Failed to initiate async generation", e)
            errorRef.set(e)
            onToken("Error: ${e.message}", true)
            latch.countDown()
        }

        // Wait for completion or error.
        // NOTE(review): this wait has no timeout; if StreamingCallbackListener can fail
        // without ever emitting a done=true chunk, the promise never settles — confirm.
        latch.await()
        val err = errorRef.get()
        if (err != null) {
            throw RuntimeException("Async inference failed: ${err.message}", err)
        }
    }
}
|
|
380
421
|
|
|
@@ -464,6 +505,14 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
464
505
|
return Promise.parallel {
|
|
465
506
|
Log.i(TAG, "downloadModel: $url -> $fileName")
|
|
466
507
|
|
|
508
|
+
if (!url.startsWith("https://", ignoreCase = true)) {
|
|
509
|
+
throw IllegalArgumentException("Invalid download URL: HTTPS is required for security.")
|
|
510
|
+
}
|
|
511
|
+
|
|
512
|
+
if (fileName.contains("..") || fileName.contains("/") || fileName.contains("\\")) {
|
|
513
|
+
throw IllegalArgumentException("Invalid filename: path traversal or directory separators are not allowed.")
|
|
514
|
+
}
|
|
515
|
+
|
|
467
516
|
val context = LiteRTLMInitProvider.applicationContext ?: throw RuntimeException("Context not available")
|
|
468
517
|
val modelsDir = java.io.File(context.filesDir, "models")
|
|
469
518
|
if (!modelsDir.exists()) {
|
|
@@ -545,6 +594,11 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
545
594
|
override fun deleteModel(fileName: String): Promise<Unit> {
|
|
546
595
|
return Promise.parallel {
|
|
547
596
|
Log.i(TAG, "deleteModel: $fileName")
|
|
597
|
+
|
|
598
|
+
if (fileName.contains("..") || fileName.contains("/") || fileName.contains("\\")) {
|
|
599
|
+
throw IllegalArgumentException("Invalid filename: path traversal or directory separators are not allowed.")
|
|
600
|
+
}
|
|
601
|
+
|
|
548
602
|
val context = LiteRTLMInitProvider.applicationContext ?: throw RuntimeException("Context not available")
|
|
549
603
|
val modelsDir = java.io.File(context.filesDir, "models")
|
|
550
604
|
val modelFile = java.io.File(modelsDir, fileName)
|
|
@@ -676,19 +730,9 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
676
730
|
|
|
677
731
|
private fun cleanupInternal() {
|
|
678
732
|
try {
|
|
733
|
+
conversation?.close()
|
|
679
734
|
conversation = null
|
|
680
|
-
//
|
|
681
|
-
// Assuming Engine implements AutoCloseable or has close()
|
|
682
|
-
if (engine is AutoCloseable) {
|
|
683
|
-
(engine as AutoCloseable).close()
|
|
684
|
-
} else {
|
|
685
|
-
// Try reflection or just null it if no close method
|
|
686
|
-
try {
|
|
687
|
-
engine?.javaClass?.getMethod("close")?.invoke(engine)
|
|
688
|
-
} catch (e: Exception) {
|
|
689
|
-
// Method not found, rely on GC
|
|
690
|
-
}
|
|
691
|
-
}
|
|
735
|
+
engine?.close() // Direct call
|
|
692
736
|
engine = null
|
|
693
737
|
} catch (e: Exception) {
|
|
694
738
|
Log.e(TAG, "Error closing resources", e)
|
|
@@ -706,41 +750,37 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
706
750
|
// v0.10.2 enforces single-session: close existing conversation first
|
|
707
751
|
conversation?.let { oldConv ->
|
|
708
752
|
try {
|
|
709
|
-
|
|
710
|
-
oldConv.close()
|
|
711
|
-
} else {
|
|
712
|
-
oldConv.javaClass.getMethod("close").invoke(oldConv)
|
|
713
|
-
}
|
|
753
|
+
oldConv.close()
|
|
714
754
|
} catch (e: Exception) {
|
|
715
755
|
Log.w(TAG, "Failed to close old conversation: ${e.message}")
|
|
716
756
|
}
|
|
717
757
|
conversation = null
|
|
718
758
|
}
|
|
759
|
+
// Map tools
|
|
760
|
+
val lmTools: List<ToolProvider>? = tools?.map { tool ->
|
|
761
|
+
val apiTool = object : OpenApiTool {
|
|
762
|
+
override fun getToolDescriptionJsonString(): String {
|
|
763
|
+
return tool.parametersJson
|
|
764
|
+
}
|
|
765
|
+
override fun execute(paramsJsonString: String): String {
|
|
766
|
+
return "{}"
|
|
767
|
+
}
|
|
768
|
+
}
|
|
769
|
+
(apiTool as Any) as ToolProvider
|
|
770
|
+
}
|
|
771
|
+
|
|
719
772
|
// Create conversation with explicit SamplerConfig (required by Gallery pattern).
|
|
720
773
|
// GPU backend may fail silently without proper sampler params.
|
|
721
774
|
val convConfig = ConversationConfig(
|
|
722
775
|
samplerConfig = SamplerConfig(
|
|
723
776
|
topK = topK,
|
|
724
|
-
topP = topP,
|
|
725
|
-
temperature = temperature,
|
|
726
|
-
)
|
|
777
|
+
topP = topP.toDouble(),
|
|
778
|
+
temperature = temperature.toDouble(),
|
|
779
|
+
),
|
|
780
|
+
systemInstruction = systemPrompt?.let { Contents.of(Content.Text(it)) },
|
|
781
|
+
tools = lmTools ?: emptyList()
|
|
727
782
|
)
|
|
728
783
|
conversation = engine!!.createConversation(convConfig)
|
|
729
|
-
// Apply system prompt/instruction if set
|
|
730
|
-
systemPrompt?.let { prompt ->
|
|
731
|
-
if (prompt.isNotEmpty()) {
|
|
732
|
-
try {
|
|
733
|
-
// Send system instruction as the first turn to prime the conversation.
|
|
734
|
-
// LiteRT-LM's Conversation API handles chat template formatting,
|
|
735
|
-
// including Gemma's <start_of_turn>system block.
|
|
736
|
-
val systemMsg = LiteRTMessage.system(prompt)
|
|
737
|
-
conversation!!.sendMessage(message = systemMsg)
|
|
738
|
-
Log.i(TAG, "System prompt applied (${prompt.length} chars)")
|
|
739
|
-
} catch (e: Exception) {
|
|
740
|
-
Log.w(TAG, "Failed to apply system prompt: ${e.message}")
|
|
741
|
-
}
|
|
742
|
-
}
|
|
743
|
-
}
|
|
744
784
|
}
|
|
745
785
|
|
|
746
786
|
/**
|
|
@@ -815,5 +855,62 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
|
|
|
815
855
|
createNewConversation()
|
|
816
856
|
}
|
|
817
857
|
|
|
858
|
+
/**
 * Sends one user turn assembled from text, image and audio [parts] and returns the
 * concatenated text of the model's reply.
 *
 * Parts whose payload for their declared type is `null` are skipped. The history receives a
 * textual stand-in for the turn (text as-is, `[Image Buffer]` / `[Audio Buffer]` placeholders)
 * followed by the model reply, and [lastStats] is refreshed with character-based estimates.
 *
 * @param parts ordered message parts.
 * @return a promise resolving to the reply text; it is rejected with an
 *   [IllegalArgumentException] when no part carries any content, or with whatever
 *   `ensureLoaded()` / the conversation throws.
 */
override fun sendMultimodalMessage(parts: Array<MultimodalPart>): Promise<String> {
    return Promise.parallel {
        ensureLoaded()

        // Copies the readable bytes through a duplicate view so the position of the
        // buffer handed out by getBuffer() is left untouched.
        fun readAll(source: java.nio.ByteBuffer): ByteArray {
            val view = source.duplicate()
            return ByteArray(view.remaining()).also { view.get(it) }
        }

        val contents = mutableListOf<Content>()
        val labels = mutableListOf<String>()
        for (part in parts) {
            when (part.type) {
                PartType.TEXT -> part.text?.let { text ->
                    contents.add(Content.Text(text))
                    labels.add(text)
                }
                PartType.IMAGE -> part.imageBuffer?.let { buffer ->
                    contents.add(Content.ImageBytes(readAll(buffer.getBuffer(false))))
                    labels.add("[Image Buffer]")
                }
                PartType.AUDIO -> part.audioBuffer?.let { buffer ->
                    contents.add(Content.AudioBytes(readAll(buffer.getBuffer(false))))
                    labels.add("[Audio Buffer]")
                }
            }
        }

        // Reject an empty turn up front instead of recording a blank user message in the
        // history and handing empty contents to the engine.
        require(contents.isNotEmpty()) {
            "sendMultimodalMessage requires at least one part with text, image or audio content."
        }

        val userTextRepresentation = labels.joinToString(" ").trim()
        history.add(Message(Role.USER, userTextRepresentation))

        val activeConversation = checkNotNull(conversation) { "Conversation is not initialized" }
        val userMsg = LiteRTMessage.user(Contents.of(contents))
        val startTime = System.nanoTime()
        val responseMsg = activeConversation.sendMessage(message = userMsg)
        val elapsedMs = (System.nanoTime() - startTime) / 1_000_000.0

        // Only the text parts of the reply are returned; other content kinds are dropped.
        val response = responseMsg.contents.contents
            .filterIsInstance<Content.Text>()
            .joinToString("") { it.text }

        history.add(Message(Role.MODEL, response))

        // Token counts are estimated as characters / 4, not taken from a tokenizer.
        val promptTokens = userTextRepresentation.length / 4.0
        val completionTokens = response.length / 4.0
        lastStats = GenerationStats(
            promptTokens = promptTokens,
            completionTokens = completionTokens,
            totalTokens = promptTokens + completionTokens,
            timeToFirstToken = 0.0, // not measured for this blocking call
            totalTime = elapsedMs,
            tokensPerSecond = if (elapsedMs > 0) completionTokens / (elapsedMs / 1000.0) else 0.0
        )
        response
    }
}
|
|
818
912
|
|
|
913
|
+
/**
 * Token-count entry point. This implementation ignores [text] and always returns `-1.0`.
 *
 * NOTE(review): -1.0 presumably tells callers that counting is unavailable here — confirm
 * against the spec's documentation.
 */
override fun countTokens(text: String): Double = -1.0
|
|
819
916
|
}
|
|
@@ -1,18 +1,35 @@
|
|
|
1
1
|
package dev.litert.litertlm
|
|
2
2
|
|
|
3
|
+
import android.os.Build
|
|
4
|
+
import android.util.Log
|
|
3
5
|
import com.facebook.react.TurboReactPackage
|
|
4
6
|
import com.facebook.react.bridge.NativeModule
|
|
5
7
|
import com.facebook.react.bridge.ReactApplicationContext
|
|
6
8
|
import com.facebook.react.module.model.ReactModuleInfo
|
|
7
9
|
import com.facebook.react.module.model.ReactModuleInfoProvider
|
|
8
|
-
import com.margelo.nitro.core.HybridObject
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
import com.margelo.nitro.dev.litert.litertlm.LiteRTLMOnLoad
|
|
12
13
|
|
|
13
14
|
class LiteRTLMPackage : TurboReactPackage() {
|
|
15
|
+
companion object {
    private const val TAG = "LiteRTLMPackage"

    /**
     * Returns `true` only when the first entry of [Build.SUPPORTED_64_BIT_ABIS] is
     * `arm64-v8a`; an empty list yields `false`.
     */
    private fun isSupportedPrimaryAbi(): Boolean =
        Build.SUPPORTED_64_BIT_ABIS.firstOrNull() == "arm64-v8a"
}
|
|
14
23
|
init {
    // Native initialization is attempted only when the primary ABI check passes.
    // A link failure is logged and not rethrown.
    if (isSupportedPrimaryAbi()) {
        try {
            LiteRTLMOnLoad.initializeNative()
        } catch (linkError: UnsatisfiedLinkError) {
            Log.e(TAG, "LiteRTLM native init failed; disabling LiteRTLM for this process.", linkError)
        }
    } else {
        val primaryAbi = Build.SUPPORTED_64_BIT_ABIS.firstOrNull()
        Log.w(TAG, "Skipping LiteRTLM native init on unsupported primary ABI: $primaryAbi")
    }
}
|
|
17
34
|
|
|
18
35
|
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
package com.margelo.nitro.core
|
|
2
|
+
|
|
3
|
+
import androidx.annotation.Keep
|
|
4
|
+
import com.facebook.proguard.annotations.DoNotStrip
|
|
5
|
+
|
|
6
|
+
/**
 * Minimal test double for the Nitro `Promise` type.
 *
 * [parallel] executes its block immediately on the calling thread, so the returned promise is
 * already settled when the caller receives it and tests can inspect [result], [error] and
 * [isCompleted] directly.
 */
@Keep
@DoNotStrip
class Promise<T> {
    /** Value handed to [resolve]; stays `null` until then. */
    var result: T? = null
        private set

    /** Throwable handed to [reject]; stays `null` until then. */
    var error: Throwable? = null
        private set

    /** Becomes `true` once [resolve] or [reject] has been called. */
    var isCompleted = false
        private set

    // Listeners notified on settlement. Nothing in this class registers one.
    private val callbacks = mutableListOf<(T?, Throwable?) -> Unit>()

    /** Stores [value] as the result, marks the promise completed and notifies listeners. */
    fun resolve(value: T) {
        synchronized(this) {
            result = value
            markCompleted(value, null)
        }
    }

    /** Stores [exception] as the error, marks the promise completed and notifies listeners. */
    fun reject(exception: Throwable) {
        synchronized(this) {
            error = exception
            markCompleted(null, exception)
        }
    }

    // Shared tail of resolve/reject; callers already hold the monitor on `this`.
    private fun markCompleted(value: T?, failure: Throwable?) {
        isCompleted = true
        for (callback in callbacks) {
            callback(value, failure)
        }
    }

    companion object {
        /**
         * Runs [block] synchronously and returns a promise resolved with its value, or
         * rejected with whatever [Throwable] was thrown.
         */
        @JvmStatic
        fun <T> parallel(block: () -> T): Promise<T> {
            val settled = Promise<T>()
            runCatching { settled.resolve(block()) }
                .onFailure { failure -> settled.reject(failure) }
            return settled
        }
    }
}
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
package com.margelo.nitro.dev.litert.litertlm
|
|
2
|
+
|
|
3
|
+
import org.junit.Assert.*
|
|
4
|
+
import org.junit.Before
|
|
5
|
+
import org.junit.After
|
|
6
|
+
import org.junit.Test
|
|
7
|
+
import org.junit.runner.RunWith
|
|
8
|
+
import org.robolectric.RobolectricTestRunner
|
|
9
|
+
import org.robolectric.RuntimeEnvironment
|
|
10
|
+
import dev.litert.litertlm.LiteRTLMInitProvider
|
|
11
|
+
import java.lang.IllegalArgumentException
|
|
12
|
+
|
|
13
|
+
/**
 * Robolectric tests for the input validation and telemetry getters of [HybridLiteRTLM].
 *
 * The promise-based tests read `isCompleted` / `error` immediately after the call, so they
 * rely on the promise being settled by the time the bridge method returns.
 */
@RunWith(RobolectricTestRunner::class)
class HybridLiteRTLMTest {
    private lateinit var bridge: HybridLiteRTLM

    @Before
    fun setUp() {
        // Initialize the static applicationContext inside LiteRTLMInitProvider via reflection
        try {
            val field = LiteRTLMInitProvider::class.java.getDeclaredField("applicationContext")
            field.isAccessible = true
            field.set(null, RuntimeEnvironment.getApplication())
        } catch (e: Exception) {
            e.printStackTrace()
        }

        bridge = HybridLiteRTLM()
    }

    @After
    fun tearDown() {
        bridge.close()
    }

    /**
     * Asserts that [promise] is already settled with an error whose message (or whose cause's
     * message) contains [expectedFragment].
     *
     * @param rejectionMessage failure text used when the promise carries no error.
     * @param warningLabel short name of the expected warning, used in the mismatch text.
     */
    private fun assertRejectedWith(
        promise: com.margelo.nitro.core.Promise<*>,
        expectedFragment: String,
        rejectionMessage: String,
        warningLabel: String,
    ) {
        assertNotNull("Promise should not be null", promise)
        assertTrue("Promise should be completed", promise.isCompleted)
        assertNotNull(rejectionMessage, promise.error)
        val error = promise.error!!
        val errMsg = error.message ?: error.cause?.message ?: ""
        assertTrue(
            "Expected message to contain $warningLabel warning, got: $errMsg",
            errMsg.contains(expectedFragment),
        )
    }

    @Test
    fun testAndroidPathTraversalPrevention() {
        val traversals = arrayOf("../secret", "/etc/hosts", "nested\\..\\file", "..", "../", "..\\")
        for (traversal in traversals) {
            assertRejectedWith(
                promise = bridge.deleteModel(traversal),
                expectedFragment = "path traversal or directory separators are not allowed",
                rejectionMessage = "Promise should have rejected with an error for filename: $traversal",
                warningLabel = "traversal",
            )
        }
    }

    // downloadModel applies the same filename guard as deleteModel; a valid HTTPS URL is
    // used so the request gets past the scheme check and reaches the filename check.
    @Test
    fun testAndroidDownloadPathTraversalPrevention() {
        val traversals = arrayOf("../secret", "/etc/hosts", "nested\\..\\file", "..", "../", "..\\")
        for (traversal in traversals) {
            assertRejectedWith(
                promise = bridge.downloadModel("https://example.com/model.bin", traversal, null),
                expectedFragment = "path traversal or directory separators are not allowed",
                rejectionMessage = "Promise should have rejected with an error for filename: $traversal",
                warningLabel = "traversal",
            )
        }
    }

    @Test
    fun testAndroidHTTPSDownloadEnforcement() {
        assertRejectedWith(
            promise = bridge.downloadModel("http://insecure.site/model.bin", "model.bin", null),
            expectedFragment = "HTTPS is required for security",
            rejectionMessage = "Promise should have rejected with an error",
            warningLabel = "HTTPS",
        )
    }

    @Test
    fun testAndroidMemoryTelemetry() {
        val mem = bridge.getMemoryUsage()
        assertNotNull(mem)
        assertTrue(mem.nativeHeapBytes >= 0.0)
        assertTrue(mem.residentBytes >= 0.0)
        assertTrue(mem.availableMemoryBytes >= 0.0)
    }

    // A freshly constructed bridge reports all-zero generation stats.
    @Test
    fun testAndroidInitialStats() {
        val stats = bridge.getStats()
        assertNotNull(stats)
        assertEquals(0.0, stats.promptTokens, 0.0)
        assertEquals(0.0, stats.completionTokens, 0.0)
        assertEquals(0.0, stats.totalTokens, 0.0)
        assertEquals(0.0, stats.timeToFirstToken, 0.0)
        assertEquals(0.0, stats.totalTime, 0.0)
        assertEquals(0.0, stats.tokensPerSecond, 0.0)
    }
}
|