react-native-litert-lm 0.3.7 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. package/README.md +153 -135
  2. package/android/build.gradle +12 -0
  3. package/android/src/main/AndroidManifest.xml +8 -0
  4. package/android/src/main/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLM.kt +276 -62
  5. package/android/src/main/java/dev/litert/litertlm/LiteRTLMPackage.kt +19 -2
  6. package/android/src/test/java/com/margelo/nitro/core/Promise.kt +46 -0
  7. package/android/src/test/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMTest.kt +105 -0
  8. package/ios/HybridLiteRTLM.swift +1344 -0
  9. package/ios/Tests/HybridLiteRTLMTests.swift +113 -0
  10. package/lib/__mocks__/react-native-nitro-modules.d.ts +65 -0
  11. package/lib/__mocks__/react-native-nitro-modules.js +60 -0
  12. package/lib/__tests__/hooks.test.d.ts +1 -0
  13. package/lib/__tests__/hooks.test.js +124 -0
  14. package/lib/__tests__/memoryTracker.test.d.ts +1 -0
  15. package/lib/__tests__/memoryTracker.test.js +74 -0
  16. package/lib/__tests__/modelFactory.test.d.ts +1 -0
  17. package/lib/__tests__/modelFactory.test.js +68 -0
  18. package/lib/hooks.js +27 -3
  19. package/lib/index.d.ts +6 -2
  20. package/lib/index.js +8 -8
  21. package/lib/modelFactory.js +82 -63
  22. package/lib/specs/LiteRTLM.nitro.d.ts +87 -2
  23. package/nitrogen/generated/android/LiteRTLMOnLoad.cpp +2 -2
  24. package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.cpp +94 -9
  25. package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.hpp +5 -1
  26. package/nitrogen/generated/android/c++/JLLMConfig.hpp +40 -3
  27. package/nitrogen/generated/android/c++/JMultimodalPart.hpp +74 -0
  28. package/nitrogen/generated/android/c++/JPartType.hpp +61 -0
  29. package/nitrogen/generated/android/c++/JToolDefinition.hpp +65 -0
  30. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/GenerationStats.kt +23 -0
  31. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMSpec.kt +28 -2
  32. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/LLMConfig.kt +46 -3
  33. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MemoryUsage.kt +19 -0
  34. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/Message.kt +15 -0
  35. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MultimodalPart.kt +66 -0
  36. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/PartType.kt +24 -0
  37. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/ToolDefinition.kt +61 -0
  38. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.cpp +57 -1
  39. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.hpp +414 -3
  40. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Umbrella.hpp +41 -3
  41. package/nitrogen/generated/ios/LiteRTLMAutolinking.mm +4 -6
  42. package/nitrogen/generated/ios/LiteRTLMAutolinking.swift +10 -0
  43. package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.cpp +11 -0
  44. package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.hpp +240 -0
  45. package/nitrogen/generated/ios/swift/Backend.swift +44 -0
  46. package/nitrogen/generated/ios/swift/Func_void.swift +46 -0
  47. package/nitrogen/generated/ios/swift/Func_void_double.swift +46 -0
  48. package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +46 -0
  49. package/nitrogen/generated/ios/swift/Func_void_std__string.swift +46 -0
  50. package/nitrogen/generated/ios/swift/Func_void_std__string_bool.swift +46 -0
  51. package/nitrogen/generated/ios/swift/GenerationStats.swift +54 -0
  52. package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec.swift +71 -0
  53. package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec_cxx.swift +431 -0
  54. package/nitrogen/generated/ios/swift/LLMConfig.swift +203 -0
  55. package/nitrogen/generated/ios/swift/MemoryUsage.swift +44 -0
  56. package/nitrogen/generated/ios/swift/Message.swift +34 -0
  57. package/nitrogen/generated/ios/swift/MultimodalPart.swift +83 -0
  58. package/nitrogen/generated/ios/swift/PartType.swift +44 -0
  59. package/nitrogen/generated/ios/swift/Role.swift +44 -0
  60. package/nitrogen/generated/ios/swift/ToolDefinition.swift +39 -0
  61. package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.cpp +4 -0
  62. package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.hpp +9 -2
  63. package/nitrogen/generated/shared/c++/LLMConfig.hpp +22 -2
  64. package/nitrogen/generated/shared/c++/MultimodalPart.hpp +99 -0
  65. package/nitrogen/generated/shared/c++/PartType.hpp +80 -0
  66. package/nitrogen/generated/shared/c++/ToolDefinition.hpp +91 -0
  67. package/package.json +22 -11
  68. package/react-native-litert-lm.podspec +17 -19
  69. package/scripts/download-ios-frameworks.sh +17 -50
  70. package/scripts/framework-source.js +46 -0
  71. package/scripts/postinstall.js +40 -18
  72. package/src/__mocks__/react-native-nitro-modules.ts +58 -0
  73. package/src/__tests__/hooks.test.ts +153 -0
  74. package/src/__tests__/memoryTracker.test.ts +87 -0
  75. package/src/__tests__/modelFactory.test.ts +96 -0
  76. package/src/hooks.ts +29 -7
  77. package/src/index.ts +7 -10
  78. package/src/modelFactory.ts +104 -80
  79. package/src/specs/LiteRTLM.nitro.ts +106 -2
  80. package/cpp/HybridLiteRTLM.cpp +0 -939
  81. package/cpp/HybridLiteRTLM.hpp +0 -169
  82. package/cpp/IOSDownloadHelper.h +0 -24
  83. package/ios/IOSDownloadHelper.mm +0 -129
  84. package/scripts/build-ios-engine.sh +0 -302
  85. package/scripts/stubs/cxx_bridge_stubs.cc +0 -224
  86. package/scripts/stubs/gemma_model_constraint_provider.cc +0 -46
  87. package/scripts/stubs/llguidance_stubs.c +0 -101
  88. package/src/templates.ts +0 -105
@@ -27,6 +27,10 @@ import com.margelo.nitro.dev.litert.litertlm.Role
27
27
  import com.margelo.nitro.core.Promise
28
28
  import com.google.ai.edge.litertlm.Content
29
29
  import com.google.ai.edge.litertlm.Contents
30
+ import com.google.ai.edge.litertlm.ExperimentalApi
31
+ import com.google.ai.edge.litertlm.ExperimentalFlags
32
+ import com.google.ai.edge.litertlm.OpenApiTool
33
+ import com.google.ai.edge.litertlm.ToolProvider
30
34
  import java.util.concurrent.CountDownLatch
31
35
  import java.util.concurrent.TimeUnit
32
36
  import java.util.concurrent.atomic.AtomicBoolean
@@ -167,6 +171,8 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
167
171
  private var topP: Double = 0.95
168
172
  private var maxTokens: Int = 1024
169
173
  private var systemPrompt: String? = null
174
+ private var tools: Array<ToolDefinition>? = null
175
+ private var enableSpeculativeDecoding: Boolean = false
170
176
 
171
177
  override val memorySize: Long
172
178
  get() = 1024L * 1024L * 1024L // ~1GB (models are large)
@@ -196,8 +202,13 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
196
202
  cfg.topP?.let { topP = it }
197
203
  cfg.maxTokens?.let { maxTokens = it.toInt() }
198
204
  cfg.systemPrompt?.let { systemPrompt = it }
205
+ tools = cfg.tools
206
+ enableSpeculativeDecoding = cfg.enableSpeculativeDecoding ?: false
199
207
  }
200
208
 
209
+ // Whether to run engine validation after loading
210
+ val shouldValidate = config?.validate?: false
211
+
201
212
  try {
202
213
  // Early GPU hardware check: probe for OpenCL library before
203
214
  // spending time on engine creation. LiteRT-LM's GPU delegate
@@ -238,12 +249,12 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
238
249
  else -> com.google.ai.edge.litertlm.Backend.CPU()
239
250
  }
240
251
 
241
- // Detect multimodal support from model filename.
252
+ // Detect multimodal support. Check config.multimodal flag first, then fall back to filename sniffing.
242
253
  // Only Gemma 3n bundles vision/audio executors; Gemma 4 E2B is text-only.
243
254
  // Passing vision/audio backends to a text-only model causes
244
255
  // vision_litert_compiled_model_executor init failures.
245
256
  val modelFileName = modelPath.substringAfterLast("/").lowercase()
246
- val isMultimodal = modelFileName.contains("3n") || modelFileName.contains("gemma3")
257
+ val isMultimodal = config?.multimodal ?: (modelFileName.contains("3n") || modelFileName.contains("gemma3"))
247
258
 
248
259
  val lmVisionBackend = if (isMultimodal) com.google.ai.edge.litertlm.Backend.GPU() else null
249
260
  val lmAudioBackend = if (isMultimodal) com.google.ai.edge.litertlm.Backend.CPU() else null
@@ -275,6 +286,11 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
275
286
 
276
287
  if (isClosed) return@synchronized
277
288
 
289
+ if (enableSpeculativeDecoding) {
290
+ @OptIn(ExperimentalApi::class)
291
+ ExperimentalFlags.enableSpeculativeDecoding = true
292
+ }
293
+
278
294
  // Initialize Engine
279
295
  engine = Engine(engineConfig).also { it.initialize() }
280
296
  Log.i(TAG, "Engine created and initialized successfully")
@@ -284,8 +300,17 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
284
300
  Log.i(TAG, "Conversation created successfully")
285
301
 
286
302
  // Validate the engine actually works with a quick test inference.
287
- // GPU backend can initialize without error but silently fail to produce tokens.
288
- validateEngine()
303
+ // GPU/NPU backends can initialize without error but silently fail to
304
+ // produce tokens — enabling this catches those failures at load time.
305
+ // CPU is always reliable so validation is never run on it, even when
306
+ // the `validate` flag is set.
307
+ if (shouldValidate) {
308
+ if (backend == Backend.GPU || backend == Backend.NPU) {
309
+ validateEngine()
310
+ } else {
311
+ Log.i(TAG, "Validation skipped: CPU backend is always reliable")
312
+ }
313
+ }
289
314
 
290
315
  } catch (e: Exception) {
291
316
  Log.e(TAG, "Failed to load model: ${e.message}", e)
@@ -349,32 +374,48 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
349
374
  // -------------------------------------------------------------------------
350
375
  // sendMessageAsync - Streaming inference
351
376
  // -------------------------------------------------------------------------
352
/**
 * Streams a reply to a plain-text [message], forwarding every chunk to [onToken].
 *
 * The user turn is appended to `history` before generation starts. The returned promise
 * settles only after the streaming listener reports `done == true`. If the request cannot
 * be started, [onToken] receives `"Error: …"` with `done = true` and the promise rejects
 * with a [RuntimeException] wrapping the original cause.
 */
override fun sendMessageAsync(message: String, onToken: (String, Boolean) -> Unit): Promise<Unit> {
    return Promise.parallel {
        // Opened once the stream emits its final chunk, or when start-up fails.
        val finished = CountDownLatch(1)
        val startupFailure = AtomicReference<Throwable?>(null)

        ensureLoaded()

        history.add(Message(Role.USER, message))
        Log.d(TAG, "sendMessageAsync: $message")

        val responseBuffer = StringBuilder()
        val forwardingListener = StreamingCallbackListener(
            onToken = { chunk, isDone ->
                onToken(chunk, isDone)
                if (isDone) finished.countDown()
            },
            responseBuilder = responseBuffer,
            history = history,
            userMessage = message,
            onStatsReady = { stats -> lastStats = stats },
        )

        try {
            conversation!!.sendMessageAsync(
                message = LiteRTMessage.user(message),
                callback = forwardingListener,
            )
        } catch (e: Exception) {
            Log.e(TAG, "Failed to initiate async generation", e)
            startupFailure.set(e)
            onToken("Error: ${e.message}", true)
            finished.countDown()
        }

        // Block this Nitro worker thread until the stream is over.
        finished.await()

        val failure = startupFailure.get()
        if (failure != null) {
            throw RuntimeException("Async inference failed: ${failure.message}", failure)
        }
    }
}
380
421
 
@@ -460,10 +501,87 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
460
501
  }
461
502
  }
462
503
 
504
/**
 * Streams a reply to a text + image prompt, forwarding every chunk to [onToken].
 *
 * The image at [imagePath] is first passed through `resizeImageIfNeeded` to avoid OOM on
 * large photos. When that returns a different path (a temporary resized copy), the copy is
 * deleted exactly once, in a single `finally`. This happens after generation completes,
 * when the request fails to start, or when the wait is interrupted.
 *
 * The returned promise settles after the listener reports `done == true`. If the request
 * cannot be started, [onToken] receives `"Error: …"` with `done = true` and the promise
 * rejects with a [RuntimeException] wrapping the cause.
 */
override fun sendMessageWithImageAsync(message: String, imagePath: String, onToken: (String, Boolean) -> Unit): Promise<Unit> {
    return Promise.parallel {
        val latch = CountDownLatch(1)
        val errorRef = AtomicReference<Throwable?>(null)

        ensureLoaded()

        Log.i(TAG, "sendMessageWithImageAsync: $message, path=$imagePath")

        // Resize image to prevent OOM on high-resolution photos
        val processedImagePath = resizeImageIfNeeded(imagePath)

        try {
            val listener = StreamingCallbackListener(
                onToken = { token, done ->
                    onToken(token, done)
                    if (done) {
                        latch.countDown()
                    }
                },
                responseBuilder = StringBuilder(),
                history = history,
                userMessage = message,
                onStatsReady = { stats -> lastStats = stats },
            )

            try {
                val userMsg = LiteRTMessage.user(
                    Contents.of(Content.Text(message), Content.ImageFile(processedImagePath))
                )

                history.add(Message(Role.USER, "$message [Image]"))

                conversation!!.sendMessageAsync(message = userMsg, callback = listener)
            } catch (e: Exception) {
                Log.e(TAG, "Failed to initiate async multimodal generation", e)
                errorRef.set(e)
                onToken("Error: ${e.message}", true)
                latch.countDown()
            }

            // Wait for completion or error
            latch.await()
        } finally {
            // Single cleanup point for the temp resized image (prevents cache dir bloat).
            // Previously this ran twice on the failure path, and not at all if listener
            // construction threw or the wait was interrupted.
            if (processedImagePath != imagePath) {
                try {
                    if (!java.io.File(processedImagePath).delete()) {
                        Log.w(TAG, "Failed to clean up temp image: $processedImagePath")
                    }
                } catch (cleanupError: Exception) {
                    Log.w(TAG, "Failed to clean up temp image: ${cleanupError.message}")
                }
            }
        }

        val err = errorRef.get()
        if (err != null) {
            throw RuntimeException("Async multimodal inference failed: ${err.message}", err)
        }
    }
}
572
+
463
573
  override fun downloadModel(url: String, fileName: String, onProgress: ((Double) -> Unit)?): Promise<String> {
464
574
  return Promise.parallel {
465
575
  Log.i(TAG, "downloadModel: $url -> $fileName")
466
576
 
577
+ if (!url.startsWith("https://", ignoreCase = true)) {
578
+ throw IllegalArgumentException("Invalid download URL: HTTPS is required for security.")
579
+ }
580
+
581
+ if (fileName.contains("..") || fileName.contains("/") || fileName.contains("\\")) {
582
+ throw IllegalArgumentException("Invalid filename: path traversal or directory separators are not allowed.")
583
+ }
584
+
467
585
  val context = LiteRTLMInitProvider.applicationContext ?: throw RuntimeException("Context not available")
468
586
  val modelsDir = java.io.File(context.filesDir, "models")
469
587
  if (!modelsDir.exists()) {
@@ -545,6 +663,11 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
545
663
  override fun deleteModel(fileName: String): Promise<Unit> {
546
664
  return Promise.parallel {
547
665
  Log.i(TAG, "deleteModel: $fileName")
666
+
667
+ if (fileName.contains("..") || fileName.contains("/") || fileName.contains("\\")) {
668
+ throw IllegalArgumentException("Invalid filename: path traversal or directory separators are not allowed.")
669
+ }
670
+
548
671
  val context = LiteRTLMInitProvider.applicationContext ?: throw RuntimeException("Context not available")
549
672
  val modelsDir = java.io.File(context.filesDir, "models")
550
673
  val modelFile = java.io.File(modelsDir, fileName)
@@ -569,6 +692,54 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
569
692
  }
570
693
  }
571
694
 
695
/**
 * Streams a reply to a text + audio prompt, forwarding every chunk to [onToken].
 *
 * [audioPath] is passed to the SDK as an audio file part; no pre-processing is done here.
 * The user turn is recorded in `history` as `"<message> [Audio]"` once the message has
 * been built. The promise settles after the listener reports `done == true`. On start-up
 * failure, [onToken] receives `"Error: …"` with `done = true` and the promise rejects.
 */
override fun sendMessageWithAudioAsync(message: String, audioPath: String, onToken: (String, Boolean) -> Unit): Promise<Unit> {
    return Promise.parallel {
        val done = CountDownLatch(1)
        val launchError = AtomicReference<Throwable?>(null)

        ensureLoaded()

        Log.i(TAG, "sendMessageWithAudioAsync: $message, path=$audioPath")

        val streamListener = StreamingCallbackListener(
            onToken = { piece, last ->
                onToken(piece, last)
                if (last) done.countDown()
            },
            responseBuilder = StringBuilder(),
            history = history,
            userMessage = message,
            onStatsReady = { stats -> lastStats = stats },
        )

        try {
            val parts = Contents.of(Content.Text(message), Content.AudioFile(audioPath))
            val audioTurn = LiteRTMessage.user(parts)

            history.add(Message(Role.USER, "$message [Audio]"))

            conversation!!.sendMessageAsync(message = audioTurn, callback = streamListener)
        } catch (e: Exception) {
            Log.e(TAG, "Failed to initiate async audio generation", e)
            launchError.set(e)
            onToken("Error: ${e.message}", true)
            done.countDown()
        }

        done.await()

        val failure = launchError.get()
        if (failure != null) {
            throw RuntimeException("Async audio inference failed: ${failure.message}", failure)
        }
    }
}
742
+
572
743
  override fun sendMessageWithAudio(message: String, audioPath: String): Promise<String> {
573
744
  return Promise.parallel {
574
745
  ensureLoaded()
@@ -676,19 +847,9 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
676
847
 
677
848
  private fun cleanupInternal() {
678
849
  try {
850
+ conversation?.close()
679
851
  conversation = null
680
- // Explicitly close engine if it supports it to free native memory immediately
681
- // Assuming Engine implements AutoCloseable or has close()
682
- if (engine is AutoCloseable) {
683
- (engine as AutoCloseable).close()
684
- } else {
685
- // Try reflection or just null it if no close method
686
- try {
687
- engine?.javaClass?.getMethod("close")?.invoke(engine)
688
- } catch (e: Exception) {
689
- // Method not found, rely on GC
690
- }
691
- }
852
+ engine?.close() // Direct call
692
853
  engine = null
693
854
  } catch (e: Exception) {
694
855
  Log.e(TAG, "Error closing resources", e)
@@ -706,41 +867,37 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
706
867
  // v0.10.2 enforces single-session: close existing conversation first
707
868
  conversation?.let { oldConv ->
708
869
  try {
709
- if (oldConv is AutoCloseable) {
710
- oldConv.close()
711
- } else {
712
- oldConv.javaClass.getMethod("close").invoke(oldConv)
713
- }
870
+ oldConv.close()
714
871
  } catch (e: Exception) {
715
872
  Log.w(TAG, "Failed to close old conversation: ${e.message}")
716
873
  }
717
874
  conversation = null
718
875
  }
876
+ // Map tools
877
+ val lmTools: List<ToolProvider>? = tools?.map { tool ->
878
+ val apiTool = object : OpenApiTool {
879
+ override fun getToolDescriptionJsonString(): String {
880
+ return tool.parametersJson
881
+ }
882
+ override fun execute(paramsJsonString: String): String {
883
+ return "{}"
884
+ }
885
+ }
886
+ (apiTool as Any) as ToolProvider
887
+ }
888
+
719
889
  // Create conversation with explicit SamplerConfig (required by Gallery pattern).
720
890
  // GPU backend may fail silently without proper sampler params.
721
891
  val convConfig = ConversationConfig(
722
892
  samplerConfig = SamplerConfig(
723
893
  topK = topK,
724
- topP = topP,
725
- temperature = temperature,
726
- )
894
+ topP = topP.toDouble(),
895
+ temperature = temperature.toDouble(),
896
+ ),
897
+ systemInstruction = systemPrompt?.let { Contents.of(Content.Text(it)) },
898
+ tools = lmTools ?: emptyList()
727
899
  )
728
900
  conversation = engine!!.createConversation(convConfig)
729
- // Apply system prompt/instruction if set
730
- systemPrompt?.let { prompt ->
731
- if (prompt.isNotEmpty()) {
732
- try {
733
- // Send system instruction as the first turn to prime the conversation.
734
- // LiteRT-LM's Conversation API handles chat template formatting,
735
- // including Gemma's <start_of_turn>system block.
736
- val systemMsg = LiteRTMessage.system(prompt)
737
- conversation!!.sendMessage(message = systemMsg)
738
- Log.i(TAG, "System prompt applied (${prompt.length} chars)")
739
- } catch (e: Exception) {
740
- Log.w(TAG, "Failed to apply system prompt: ${e.message}")
741
- }
742
- }
743
- }
744
901
  }
745
902
 
746
903
  /**
@@ -815,5 +972,62 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
815
972
  createNewConversation()
816
973
  }
817
974
 
975
/**
 * Sends an ordered list of text / image / audio [parts] as one blocking user turn and
 * resolves with the concatenated text of the model's reply.
 *
 * Parts whose payload is null (`text`, `imageBuffer`, `audioBuffer`) are skipped. If no
 * usable part remains, the promise rejects with [IllegalArgumentException] before anything
 * is written to `history`.
 *
 * `lastStats` is filled with rough estimates: token counts are characters / 4 of the history
 * text and reply. `timeToFirstToken` is 0.0 because this path does not stream.
 */
override fun sendMultimodalMessage(parts: Array<MultimodalPart>): Promise<String> {
    return Promise.parallel {
        ensureLoaded()

        val contents = mutableListOf<Content>()
        // One human-readable label per accepted part; joined into the history entry.
        val labels = mutableListOf<String>()
        for (part in parts) {
            when (part.type) {
                PartType.TEXT -> part.text?.let { text ->
                    contents.add(Content.Text(text))
                    labels.add(text)
                }
                PartType.IMAGE -> part.imageBuffer?.let { buffer ->
                    contents.add(Content.ImageBytes(copyRemainingBytes(buffer.getBuffer(false))))
                    labels.add("[Image Buffer]")
                }
                PartType.AUDIO -> part.audioBuffer?.let { buffer ->
                    contents.add(Content.AudioBytes(copyRemainingBytes(buffer.getBuffer(false))))
                    labels.add("[Audio Buffer]")
                }
            }
        }
        // Fail fast instead of recording an empty USER turn and sending an empty message.
        require(contents.isNotEmpty()) {
            "sendMultimodalMessage: no usable parts (all text/imageBuffer/audioBuffer payloads were null or parts was empty)"
        }

        val userTextRepresentation = labels.joinToString(" ").trim()
        history.add(Message(Role.USER, userTextRepresentation))

        val userMsg = LiteRTMessage.user(Contents.of(contents))
        val startTime = System.nanoTime()
        val responseMsg = conversation!!.sendMessage(message = userMsg)
        val elapsedMs = (System.nanoTime() - startTime) / 1_000_000.0

        val response = responseMsg.contents.contents
            .filterIsInstance<Content.Text>()
            .joinToString("") { it.text }

        history.add(Message(Role.MODEL, response))

        // Heuristic estimate (~4 characters per token); no tokenizer is consulted here.
        val promptTokens = userTextRepresentation.length / 4.0
        val completionTokens = response.length / 4.0
        lastStats = GenerationStats(
            promptTokens = promptTokens,
            completionTokens = completionTokens,
            totalTokens = promptTokens + completionTokens,
            timeToFirstToken = 0.0,
            totalTime = elapsedMs,
            tokensPerSecond = if (elapsedMs > 0) completionTokens / (elapsedMs / 1000.0) else 0.0
        )
        response
    }
}

/**
 * Copies the readable bytes of [source] into a new array without moving [source]'s
 * position. Reading through a duplicate keeps the caller's (possibly shared) buffer intact.
 */
private fun copyRemainingBytes(source: java.nio.ByteBuffer): ByteArray {
    val view = source.duplicate()
    val bytes = ByteArray(view.remaining())
    view.get(bytes)
    return bytes
}
818
1029
 
1030
/**
 * Token counting is not implemented in the Android bridge: [text] is ignored and -1.0 is
 * always returned.
 * NOTE(review): -1.0 appears to serve as an "unsupported" sentinel — confirm against the JS spec.
 */
override fun countTokens(text: String): Double = -1.0
819
1033
  }
@@ -1,18 +1,35 @@
1
1
  package dev.litert.litertlm
2
2
 
3
+ import android.os.Build
4
+ import android.util.Log
3
5
  import com.facebook.react.TurboReactPackage
4
6
  import com.facebook.react.bridge.NativeModule
5
7
  import com.facebook.react.bridge.ReactApplicationContext
6
8
  import com.facebook.react.module.model.ReactModuleInfo
7
9
  import com.facebook.react.module.model.ReactModuleInfoProvider
8
- import com.margelo.nitro.core.HybridObject
9
10
 
10
11
 
11
12
  import com.margelo.nitro.dev.litert.litertlm.LiteRTLMOnLoad
12
13
 
13
14
  class LiteRTLMPackage : TurboReactPackage() {
15
companion object {
    private const val TAG = "LiteRTLMPackage"

    /**
     * Returns true only when the device's preferred 64-bit ABI is `arm64-v8a`.
     *
     * Devices that report no 64-bit ABI, or whose first 64-bit ABI is anything else
     * (e.g. `x86_64`), return false.
     */
    private fun isSupportedPrimaryAbi(): Boolean =
        Build.SUPPORTED_64_BIT_ABIS.firstOrNull() == "arm64-v8a"
}
14
23
  init {
15
- LiteRTLMOnLoad.initializeNative()
24
+ if (!isSupportedPrimaryAbi()) {
25
+ Log.w(TAG, "Skipping LiteRTLM native init on unsupported primary ABI: ${Build.SUPPORTED_64_BIT_ABIS.firstOrNull()}")
26
+ } else {
27
+ try {
28
+ LiteRTLMOnLoad.initializeNative()
29
+ } catch (e: UnsatisfiedLinkError) {
30
+ Log.e(TAG, "LiteRTLM native init failed; disabling LiteRTLM for this process.", e)
31
+ }
32
+ }
16
33
  }
17
34
 
18
35
 
@@ -0,0 +1,46 @@
1
+ package com.margelo.nitro.core
2
+
3
+ import androidx.annotation.Keep
4
+ import com.facebook.proguard.annotations.DoNotStrip
5
+
6
/**
 * Minimal JVM-only stand-in for Nitro's `Promise`, used by unit tests.
 *
 * [parallel] runs its block synchronously on the calling thread. The returned promise is
 * therefore already settled when it is handed back, and tests can inspect [isCompleted],
 * [result] and [error] directly.
 *
 * A promise settles at most once: the first [resolve]/[reject] wins and later calls are
 * ignored. Listeners registered with [onComplete] are notified once, outside the lock.
 */
@Keep
@DoNotStrip
class Promise<T> {
    @Volatile
    var result: T? = null
        private set

    @Volatile
    var error: Throwable? = null
        private set

    @Volatile
    var isCompleted = false
        private set

    // Listeners waiting for settlement; cleared once they have been notified.
    private val callbacks = mutableListOf<(T?, Throwable?) -> Unit>()

    /** Settles this promise successfully with [value]; ignored if already settled. */
    fun resolve(value: T) {
        settle(value, null)
    }

    /** Settles this promise with [exception]; ignored if already settled. */
    fun reject(exception: Throwable) {
        settle(null, exception)
    }

    /**
     * Registers [callback] to receive `(result, error)` once this promise settles.
     * If it has already settled, [callback] is invoked immediately on the calling thread.
     */
    fun onComplete(callback: (T?, Throwable?) -> Unit) {
        val alreadySettled = synchronized(this) {
            if (isCompleted) {
                true
            } else {
                callbacks.add(callback)
                false
            }
        }
        if (alreadySettled) callback(result, error)
    }

    // Records the outcome exactly once, then notifies listeners without holding the lock.
    private fun settle(value: T?, exception: Throwable?) {
        val listeners = synchronized(this) {
            if (isCompleted) return
            result = value
            error = exception
            isCompleted = true
            callbacks.toList().also { callbacks.clear() }
        }
        listeners.forEach { it(value, exception) }
    }

    companion object {
        /**
         * Runs [block] synchronously and returns a promise resolved with its value, or
         * rejected with whatever it throws.
         */
        @JvmStatic
        fun <T> parallel(block: () -> T): Promise<T> {
            val promise = Promise<T>()
            try {
                promise.resolve(block())
            } catch (e: Throwable) {
                promise.reject(e)
            }
            return promise
        }
    }
}
@@ -0,0 +1,105 @@
1
+ package com.margelo.nitro.dev.litert.litertlm
2
+
3
+ import org.junit.Assert.*
4
+ import org.junit.Before
5
+ import org.junit.After
6
+ import org.junit.Test
7
+ import org.junit.runner.RunWith
8
+ import org.robolectric.RobolectricTestRunner
9
+ import org.robolectric.RuntimeEnvironment
10
+ import dev.litert.litertlm.LiteRTLMInitProvider
11
+ import java.lang.IllegalArgumentException
12
+
13
@RunWith(RobolectricTestRunner::class)
class HybridLiteRTLMTest {
    private lateinit var bridge: HybridLiteRTLM

    @Before
    fun setUp() {
        injectApplicationContext()
        bridge = HybridLiteRTLM()
    }

    @After
    fun tearDown() {
        bridge.close()
    }

    @Test
    fun testAndroidPathTraversalPrevention() {
        val badNames = listOf("../secret", "/etc/hosts", "nested\\..\\file", "..", "../", "..\\")
        badNames.forEach { badName ->
            assertRejected(
                promise = bridge.deleteModel(badName),
                rejectionMessage = "Promise should have rejected with an error for filename: $badName",
                mismatchPrefix = "Expected message to contain traversal warning",
                expectedFragment = "path traversal or directory separators are not allowed",
            )
        }
    }

    @Test
    fun testAndroidHTTPSDownloadEnforcement() {
        assertRejected(
            promise = bridge.downloadModel("http://insecure.site/model.bin", "model.bin", null),
            rejectionMessage = "Promise should have rejected with an error",
            mismatchPrefix = "Expected message to contain HTTPS warning",
            expectedFragment = "HTTPS is required for security",
        )
    }

    @Test
    fun testAndroidMemoryTelemetry() {
        val usage = bridge.getMemoryUsage()
        assertNotNull(usage)
        assertTrue(usage.nativeHeapBytes >= 0.0)
        assertTrue(usage.residentBytes >= 0.0)
        assertTrue(usage.availableMemoryBytes >= 0.0)
    }

    @Test
    fun testSendMessageWithImageAsyncRejectsWithoutModel() {
        assertRejected(
            promise = bridge.sendMessageWithImageAsync("hello", "/tmp/image.jpg") { _, _ -> },
            rejectionMessage = "Promise should have rejected without model",
            mismatchPrefix = "Expected no-model error",
            expectedFragment = "No model loaded",
        )
    }

    @Test
    fun testSendMessageWithAudioAsyncRejectsWithoutModel() {
        assertRejected(
            promise = bridge.sendMessageWithAudioAsync("hello", "/tmp/audio.wav") { _, _ -> },
            rejectionMessage = "Promise should have rejected without model",
            mismatchPrefix = "Expected no-model error",
            expectedFragment = "No model loaded",
        )
    }

    @Test
    fun testAndroidInitialStats() {
        val stats = bridge.getStats()
        assertNotNull(stats)
        val fresh = listOf(
            stats.promptTokens,
            stats.completionTokens,
            stats.totalTokens,
            stats.timeToFirstToken,
            stats.totalTime,
            stats.tokensPerSecond,
        )
        fresh.forEach { value -> assertEquals(0.0, value, 0.0) }
    }

    /**
     * Puts the Robolectric application into LiteRTLMInitProvider's static
     * `applicationContext` field via reflection. Failures are printed, not rethrown.
     */
    private fun injectApplicationContext() {
        try {
            val contextField = LiteRTLMInitProvider::class.java.getDeclaredField("applicationContext")
            contextField.isAccessible = true
            contextField.set(null, RuntimeEnvironment.getApplication())
        } catch (e: Exception) {
            e.printStackTrace()
        }
    }

    /**
     * Asserts that [promise] has already settled with an error. The error's message (or
     * its cause's message) must contain [expectedFragment].
     */
    private fun assertRejected(
        promise: com.margelo.nitro.core.Promise<*>,
        rejectionMessage: String,
        mismatchPrefix: String,
        expectedFragment: String,
    ) {
        assertNotNull("Promise should not be null", promise)
        assertTrue("Promise should be completed", promise.isCompleted)
        assertNotNull(rejectionMessage, promise.error)
        val failure = promise.error!!
        val text = failure.message ?: failure.cause?.message ?: ""
        assertTrue("$mismatchPrefix, got: $text", text.contains(expectedFragment))
    }
}