react-native-litert-lm 0.3.7 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. package/README.md +153 -135
  2. package/android/build.gradle +12 -0
  3. package/android/src/main/AndroidManifest.xml +5 -0
  4. package/android/src/main/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLM.kt +159 -62
  5. package/android/src/main/java/dev/litert/litertlm/LiteRTLMPackage.kt +19 -2
  6. package/android/src/test/java/com/margelo/nitro/core/Promise.kt +46 -0
  7. package/android/src/test/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMTest.kt +83 -0
  8. package/ios/HybridLiteRTLM.swift +1058 -0
  9. package/ios/Tests/HybridLiteRTLMTests.swift +67 -0
  10. package/lib/__mocks__/react-native-nitro-modules.d.ts +61 -0
  11. package/lib/__mocks__/react-native-nitro-modules.js +50 -0
  12. package/lib/__tests__/hooks.test.d.ts +1 -0
  13. package/lib/__tests__/hooks.test.js +124 -0
  14. package/lib/__tests__/memoryTracker.test.d.ts +1 -0
  15. package/lib/__tests__/memoryTracker.test.js +74 -0
  16. package/lib/__tests__/modelFactory.test.d.ts +1 -0
  17. package/lib/__tests__/modelFactory.test.js +52 -0
  18. package/lib/hooks.js +1 -1
  19. package/lib/index.d.ts +0 -2
  20. package/lib/index.js +1 -5
  21. package/lib/modelFactory.js +62 -63
  22. package/lib/specs/LiteRTLM.nitro.d.ts +71 -2
  23. package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.cpp +62 -7
  24. package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.hpp +3 -1
  25. package/nitrogen/generated/android/c++/JLLMConfig.hpp +40 -3
  26. package/nitrogen/generated/android/c++/JMultimodalPart.hpp +74 -0
  27. package/nitrogen/generated/android/c++/JPartType.hpp +61 -0
  28. package/nitrogen/generated/android/c++/JToolDefinition.hpp +65 -0
  29. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/GenerationStats.kt +23 -0
  30. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMSpec.kt +10 -2
  31. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/LLMConfig.kt +46 -3
  32. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MemoryUsage.kt +19 -0
  33. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/Message.kt +15 -0
  34. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MultimodalPart.kt +66 -0
  35. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/PartType.kt +24 -0
  36. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/ToolDefinition.kt +61 -0
  37. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.cpp +57 -1
  38. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.hpp +414 -3
  39. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Umbrella.hpp +41 -3
  40. package/nitrogen/generated/ios/LiteRTLMAutolinking.mm +4 -6
  41. package/nitrogen/generated/ios/LiteRTLMAutolinking.swift +10 -0
  42. package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.cpp +11 -0
  43. package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.hpp +224 -0
  44. package/nitrogen/generated/ios/swift/Backend.swift +44 -0
  45. package/nitrogen/generated/ios/swift/Func_void.swift +46 -0
  46. package/nitrogen/generated/ios/swift/Func_void_double.swift +46 -0
  47. package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +46 -0
  48. package/nitrogen/generated/ios/swift/Func_void_std__string.swift +46 -0
  49. package/nitrogen/generated/ios/swift/Func_void_std__string_bool.swift +46 -0
  50. package/nitrogen/generated/ios/swift/GenerationStats.swift +54 -0
  51. package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec.swift +69 -0
  52. package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec_cxx.swift +383 -0
  53. package/nitrogen/generated/ios/swift/LLMConfig.swift +203 -0
  54. package/nitrogen/generated/ios/swift/MemoryUsage.swift +44 -0
  55. package/nitrogen/generated/ios/swift/Message.swift +34 -0
  56. package/nitrogen/generated/ios/swift/MultimodalPart.swift +83 -0
  57. package/nitrogen/generated/ios/swift/PartType.swift +44 -0
  58. package/nitrogen/generated/ios/swift/Role.swift +44 -0
  59. package/nitrogen/generated/ios/swift/ToolDefinition.swift +39 -0
  60. package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.cpp +2 -0
  61. package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.hpp +7 -2
  62. package/nitrogen/generated/shared/c++/LLMConfig.hpp +22 -2
  63. package/nitrogen/generated/shared/c++/MultimodalPart.hpp +99 -0
  64. package/nitrogen/generated/shared/c++/PartType.hpp +80 -0
  65. package/nitrogen/generated/shared/c++/ToolDefinition.hpp +91 -0
  66. package/package.json +16 -8
  67. package/react-native-litert-lm.podspec +15 -19
  68. package/scripts/download-ios-frameworks.sh +14 -48
  69. package/scripts/postinstall.js +1 -2
  70. package/src/__mocks__/react-native-nitro-modules.ts +48 -0
  71. package/src/__tests__/hooks.test.ts +153 -0
  72. package/src/__tests__/memoryTracker.test.ts +87 -0
  73. package/src/__tests__/modelFactory.test.ts +68 -0
  74. package/src/hooks.ts +1 -1
  75. package/src/index.ts +0 -7
  76. package/src/modelFactory.ts +82 -80
  77. package/src/specs/LiteRTLM.nitro.ts +80 -2
  78. package/cpp/HybridLiteRTLM.cpp +0 -939
  79. package/cpp/HybridLiteRTLM.hpp +0 -169
  80. package/cpp/IOSDownloadHelper.h +0 -24
  81. package/ios/IOSDownloadHelper.mm +0 -129
  82. package/scripts/build-ios-engine.sh +0 -302
  83. package/scripts/stubs/cxx_bridge_stubs.cc +0 -224
  84. package/scripts/stubs/gemma_model_constraint_provider.cc +0 -46
  85. package/scripts/stubs/llguidance_stubs.c +0 -101
  86. package/src/templates.ts +0 -105
@@ -27,6 +27,10 @@ import com.margelo.nitro.dev.litert.litertlm.Role
27
27
  import com.margelo.nitro.core.Promise
28
28
  import com.google.ai.edge.litertlm.Content
29
29
  import com.google.ai.edge.litertlm.Contents
30
+ import com.google.ai.edge.litertlm.ExperimentalApi
31
+ import com.google.ai.edge.litertlm.ExperimentalFlags
32
+ import com.google.ai.edge.litertlm.OpenApiTool
33
+ import com.google.ai.edge.litertlm.ToolProvider
30
34
  import java.util.concurrent.CountDownLatch
31
35
  import java.util.concurrent.TimeUnit
32
36
  import java.util.concurrent.atomic.AtomicBoolean
@@ -167,6 +171,8 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
167
171
  private var topP: Double = 0.95
168
172
  private var maxTokens: Int = 1024
169
173
  private var systemPrompt: String? = null
174
+ private var tools: Array<ToolDefinition>? = null
175
+ private var enableSpeculativeDecoding: Boolean = false
170
176
 
171
177
  override val memorySize: Long
172
178
  get() = 1024L * 1024L * 1024L // ~1GB (models are large)
@@ -196,8 +202,13 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
196
202
  cfg.topP?.let { topP = it }
197
203
  cfg.maxTokens?.let { maxTokens = it.toInt() }
198
204
  cfg.systemPrompt?.let { systemPrompt = it }
205
+ tools = cfg.tools
206
+ enableSpeculativeDecoding = cfg.enableSpeculativeDecoding ?: false
199
207
  }
200
208
 
209
+ // Whether to run engine validation after loading
210
+ val shouldValidate = config?.validate?: false
211
+
201
212
  try {
202
213
  // Early GPU hardware check: probe for OpenCL library before
203
214
  // spending time on engine creation. LiteRT-LM's GPU delegate
@@ -238,12 +249,12 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
238
249
  else -> com.google.ai.edge.litertlm.Backend.CPU()
239
250
  }
240
251
 
241
- // Detect multimodal support from model filename.
252
+ // Detect multimodal support. Check config.multimodal flag first, then fall back to filename sniffing.
242
253
  // Only Gemma 3n bundles vision/audio executors; Gemma 4 E2B is text-only.
243
254
  // Passing vision/audio backends to a text-only model causes
244
255
  // vision_litert_compiled_model_executor init failures.
245
256
  val modelFileName = modelPath.substringAfterLast("/").lowercase()
246
- val isMultimodal = modelFileName.contains("3n") || modelFileName.contains("gemma3")
257
+ val isMultimodal = config?.multimodal ?: (modelFileName.contains("3n") || modelFileName.contains("gemma3"))
247
258
 
248
259
  val lmVisionBackend = if (isMultimodal) com.google.ai.edge.litertlm.Backend.GPU() else null
249
260
  val lmAudioBackend = if (isMultimodal) com.google.ai.edge.litertlm.Backend.CPU() else null
@@ -275,6 +286,11 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
275
286
 
276
287
  if (isClosed) return@synchronized
277
288
 
289
+ if (enableSpeculativeDecoding) {
290
+ @OptIn(ExperimentalApi::class)
291
+ ExperimentalFlags.enableSpeculativeDecoding = true
292
+ }
293
+
278
294
  // Initialize Engine
279
295
  engine = Engine(engineConfig).also { it.initialize() }
280
296
  Log.i(TAG, "Engine created and initialized successfully")
@@ -284,8 +300,17 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
284
300
  Log.i(TAG, "Conversation created successfully")
285
301
 
286
302
  // Validate the engine actually works with a quick test inference.
287
- // GPU backend can initialize without error but silently fail to produce tokens.
288
- validateEngine()
303
+ // GPU/NPU backends can initialize without error but silently fail to
304
+ // produce tokens — enabling this catches those failures at load time.
305
+ // CPU is always reliable so validation is never run on it, even when
306
+ // the `validate` flag is set.
307
+ if (shouldValidate) {
308
+ if (backend == Backend.GPU || backend == Backend.NPU) {
309
+ validateEngine()
310
+ } else {
311
+ Log.i(TAG, "Validation skipped: CPU backend is always reliable")
312
+ }
313
+ }
289
314
 
290
315
  } catch (e: Exception) {
291
316
  Log.e(TAG, "Failed to load model: ${e.message}", e)
@@ -349,32 +374,48 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
349
374
  // -------------------------------------------------------------------------
350
375
  // sendMessageAsync - Streaming inference
351
376
  // -------------------------------------------------------------------------
/**
 * Streams a model reply to [message] token-by-token through [onToken].
 *
 * The user turn is appended to `history` before generation starts; the
 * `StreamingCallbackListener` receives `history` and the user message for the model turn.
 * The returned promise settles only after the listener reports a token with
 * `done == true`, so JS callers can `await` the end of the stream.
 *
 * @param message the user's prompt text.
 * @param onToken receives each token and a `done` flag. If generation cannot be
 *   started it is invoked once with `"Error: <message>"` and `done = true`.
 * @return a promise that resolves when streaming finishes, or rejects with a
 *   [RuntimeException] ("Async inference failed: …") if generation could not start.
 */
override fun sendMessageAsync(message: String, onToken: (String, Boolean) -> Unit): Promise<Unit> {
    return Promise.parallel {
        ensureLoaded()

        // Add user message to history
        history.add(Message(Role.USER, message))
        Log.d(TAG, "sendMessageAsync: $message")

        // Counted down when the listener reports the final token (done == true).
        val finished = CountDownLatch(1)
        val listener = StreamingCallbackListener(
            onToken = { token, done ->
                try {
                    onToken(token, done)
                } finally {
                    // Release the waiting worker even if the JS callback throws;
                    // otherwise finished.await() below would block forever.
                    if (done) finished.countDown()
                }
            },
            responseBuilder = StringBuilder(),
            history = history,
            userMessage = message,
            onStatsReady = { stats -> lastStats = stats },
        )

        try {
            val userMsg = LiteRTMessage.user(message)
            val activeConversation = checkNotNull(conversation) { "Conversation is not initialized" }
            activeConversation.sendMessageAsync(message = userMsg, callback = listener)
        } catch (e: Exception) {
            Log.e(TAG, "Failed to initiate async generation", e)
            // Report to the stream consumer, but never let a failing callback
            // hide the original cause from the promise rejection.
            runCatching { onToken("Error: ${e.message}", true) }
                .onFailure { Log.w(TAG, "onToken threw while reporting an initiation failure", it) }
            throw RuntimeException("Async inference failed: ${e.message}", e)
        }

        // Block this worker thread until the stream reports completion.
        finished.await()
    }
}
380
421
 
@@ -464,6 +505,14 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
464
505
  return Promise.parallel {
465
506
  Log.i(TAG, "downloadModel: $url -> $fileName")
466
507
 
508
+ if (!url.startsWith("https://", ignoreCase = true)) {
509
+ throw IllegalArgumentException("Invalid download URL: HTTPS is required for security.")
510
+ }
511
+
512
+ if (fileName.contains("..") || fileName.contains("/") || fileName.contains("\\")) {
513
+ throw IllegalArgumentException("Invalid filename: path traversal or directory separators are not allowed.")
514
+ }
515
+
467
516
  val context = LiteRTLMInitProvider.applicationContext ?: throw RuntimeException("Context not available")
468
517
  val modelsDir = java.io.File(context.filesDir, "models")
469
518
  if (!modelsDir.exists()) {
@@ -545,6 +594,11 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
545
594
  override fun deleteModel(fileName: String): Promise<Unit> {
546
595
  return Promise.parallel {
547
596
  Log.i(TAG, "deleteModel: $fileName")
597
+
598
+ if (fileName.contains("..") || fileName.contains("/") || fileName.contains("\\")) {
599
+ throw IllegalArgumentException("Invalid filename: path traversal or directory separators are not allowed.")
600
+ }
601
+
548
602
  val context = LiteRTLMInitProvider.applicationContext ?: throw RuntimeException("Context not available")
549
603
  val modelsDir = java.io.File(context.filesDir, "models")
550
604
  val modelFile = java.io.File(modelsDir, fileName)
@@ -676,19 +730,9 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
676
730
 
677
731
  private fun cleanupInternal() {
678
732
  try {
733
+ conversation?.close()
679
734
  conversation = null
680
- // Explicitly close engine if it supports it to free native memory immediately
681
- // Assuming Engine implements AutoCloseable or has close()
682
- if (engine is AutoCloseable) {
683
- (engine as AutoCloseable).close()
684
- } else {
685
- // Try reflection or just null it if no close method
686
- try {
687
- engine?.javaClass?.getMethod("close")?.invoke(engine)
688
- } catch (e: Exception) {
689
- // Method not found, rely on GC
690
- }
691
- }
735
+ engine?.close() // Direct call
692
736
  engine = null
693
737
  } catch (e: Exception) {
694
738
  Log.e(TAG, "Error closing resources", e)
@@ -706,41 +750,37 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
706
750
  // v0.10.2 enforces single-session: close existing conversation first
707
751
  conversation?.let { oldConv ->
708
752
  try {
709
- if (oldConv is AutoCloseable) {
710
- oldConv.close()
711
- } else {
712
- oldConv.javaClass.getMethod("close").invoke(oldConv)
713
- }
753
+ oldConv.close()
714
754
  } catch (e: Exception) {
715
755
  Log.w(TAG, "Failed to close old conversation: ${e.message}")
716
756
  }
717
757
  conversation = null
718
758
  }
759
+ // Map tools
760
+ val lmTools: List<ToolProvider>? = tools?.map { tool ->
761
+ val apiTool = object : OpenApiTool {
762
+ override fun getToolDescriptionJsonString(): String {
763
+ return tool.parametersJson
764
+ }
765
+ override fun execute(paramsJsonString: String): String {
766
+ return "{}"
767
+ }
768
+ }
769
+ (apiTool as Any) as ToolProvider
770
+ }
771
+
719
772
  // Create conversation with explicit SamplerConfig (required by Gallery pattern).
720
773
  // GPU backend may fail silently without proper sampler params.
721
774
  val convConfig = ConversationConfig(
722
775
  samplerConfig = SamplerConfig(
723
776
  topK = topK,
724
- topP = topP,
725
- temperature = temperature,
726
- )
777
+ topP = topP.toDouble(),
778
+ temperature = temperature.toDouble(),
779
+ ),
780
+ systemInstruction = systemPrompt?.let { Contents.of(Content.Text(it)) },
781
+ tools = lmTools ?: emptyList()
727
782
  )
728
783
  conversation = engine!!.createConversation(convConfig)
729
- // Apply system prompt/instruction if set
730
- systemPrompt?.let { prompt ->
731
- if (prompt.isNotEmpty()) {
732
- try {
733
- // Send system instruction as the first turn to prime the conversation.
734
- // LiteRT-LM's Conversation API handles chat template formatting,
735
- // including Gemma's <start_of_turn>system block.
736
- val systemMsg = LiteRTMessage.system(prompt)
737
- conversation!!.sendMessage(message = systemMsg)
738
- Log.i(TAG, "System prompt applied (${prompt.length} chars)")
739
- } catch (e: Exception) {
740
- Log.w(TAG, "Failed to apply system prompt: ${e.message}")
741
- }
742
- }
743
- }
744
784
  }
745
785
 
746
786
  /**
@@ -815,5 +855,62 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
815
855
  createNewConversation()
816
856
  }
817
857
 
/**
 * Sends one user turn built from text, image and/or audio [parts], and waits on the
 * Nitro worker thread for the complete model reply.
 *
 * Parts whose payload for their declared [PartType] is `null` are skipped. A
 * textual summary of the turn (the text plus `[Image Buffer]` / `[Audio Buffer]`
 * placeholders) is stored in `history`. `lastStats` is refreshed using
 * character-based estimates of about 4 characters per token.
 *
 * @param parts ordered content parts of the user turn; at least one must carry a payload.
 * @return a promise resolving to the concatenated text content of the model reply;
 *   rejects with [IllegalArgumentException] if no part carries a payload.
 */
override fun sendMultimodalMessage(parts: Array<MultimodalPart>): Promise<String> {
    return Promise.parallel {
        ensureLoaded()

        val contents = mutableListOf<Content>()
        val historyParts = mutableListOf<String>()
        for (part in parts) {
            when (part.type) {
                PartType.TEXT -> part.text?.let { text ->
                    contents.add(Content.Text(text))
                    historyParts.add(text)
                }
                PartType.IMAGE -> part.imageBuffer?.let { buffer ->
                    contents.add(Content.ImageBytes(copyRemainingBytes(buffer.getBuffer(false))))
                    historyParts.add("[Image Buffer]")
                }
                PartType.AUDIO -> part.audioBuffer?.let { buffer ->
                    contents.add(Content.AudioBytes(copyRemainingBytes(buffer.getBuffer(false))))
                    historyParts.add("[Audio Buffer]")
                }
            }
        }
        // Sending an empty turn would only produce an opaque engine-side failure.
        require(contents.isNotEmpty()) { "sendMultimodalMessage requires at least one part with a payload" }

        val userTextRepresentation = historyParts.joinToString(" ").trim()
        history.add(Message(Role.USER, userTextRepresentation))

        val activeConversation = checkNotNull(conversation) { "Conversation is not initialized" }
        val userMsg = LiteRTMessage.user(Contents.of(contents))
        val startTime = System.nanoTime()
        val responseMsg = activeConversation.sendMessage(message = userMsg)
        val elapsedMs = (System.nanoTime() - startTime) / 1_000_000.0

        val response = responseMsg.contents.contents
            .filterIsInstance<Content.Text>()
            .joinToString("") { it.text }

        history.add(Message(Role.MODEL, response))

        // Rough estimate: ~4 characters per token (no tokenizer access here).
        val promptTokens = userTextRepresentation.length / 4.0
        val completionTokens = response.length / 4.0
        lastStats = GenerationStats(
            promptTokens = promptTokens,
            completionTokens = completionTokens,
            totalTokens = promptTokens + completionTokens,
            timeToFirstToken = 0.0,
            totalTime = elapsedMs,
            tokensPerSecond = if (elapsedMs > 0) completionTokens / (elapsedMs / 1000.0) else 0.0,
        )
        response
    }
}

/**
 * Copies the readable bytes of [source] into a new array. The copy is read through a
 * [java.nio.ByteBuffer.duplicate] view, so the position of [source] is left unchanged
 * and the caller's buffer can be read again.
 */
private fun copyRemainingBytes(source: java.nio.ByteBuffer): ByteArray {
    val view = source.duplicate()
    return ByteArray(view.remaining()).also { view.get(it) }
}
818
912
 
/**
 * Token counting is not implemented on Android: [text] is ignored and `-1.0`
 * is always returned.
 *
 * NOTE(review): presumably `-1.0` is a sentinel for "unsupported"; confirm the
 * TS spec documents it so callers do not treat it as a real count.
 */
override fun countTokens(text: String): Double = -1.0
819
916
  }
@@ -1,18 +1,35 @@
1
1
  package dev.litert.litertlm
2
2
 
3
+ import android.os.Build
4
+ import android.util.Log
3
5
  import com.facebook.react.TurboReactPackage
4
6
  import com.facebook.react.bridge.NativeModule
5
7
  import com.facebook.react.bridge.ReactApplicationContext
6
8
  import com.facebook.react.module.model.ReactModuleInfo
7
9
  import com.facebook.react.module.model.ReactModuleInfoProvider
8
- import com.margelo.nitro.core.HybridObject
9
10
 
10
11
 
11
12
  import com.margelo.nitro.dev.litert.litertlm.LiteRTLMOnLoad
12
13
 
13
14
  class LiteRTLMPackage : TurboReactPackage() {
companion object {
    private const val TAG = "LiteRTLMPackage"

    /** ABI that must be the device's first 64-bit ABI for native init to run. */
    private const val REQUIRED_PRIMARY_ABI = "arm64-v8a"

    /**
     * `true` only when the device's first listed 64-bit ABI is arm64-v8a.
     * A device with no 64-bit ABI (empty array) is reported as unsupported.
     */
    private fun isSupportedPrimaryAbi(): Boolean =
        Build.SUPPORTED_64_BIT_ABIS.firstOrNull() == REQUIRED_PRIMARY_ABI
}
// Initialize the LiteRTLM native side only on an arm64-v8a primary ABI.
// Every other outcome (unsupported ABI or a link failure) is logged and swallowed,
// so nothing propagates out of the package constructor.
init {
    if (isSupportedPrimaryAbi()) {
        try {
            LiteRTLMOnLoad.initializeNative()
        } catch (linkError: UnsatisfiedLinkError) {
            // NOTE(review): this only logs; nothing visible here records the failure,
            // so later calls into the native module are not actually disabled.
            Log.e(TAG, "LiteRTLM native init failed; disabling LiteRTLM for this process.", linkError)
        }
    } else {
        Log.w(TAG, "Skipping LiteRTLM native init on unsupported primary ABI: ${Build.SUPPORTED_64_BIT_ABIS.firstOrNull()}")
    }
}
17
34
 
18
35
 
@@ -0,0 +1,46 @@
1
+ package com.margelo.nitro.core
2
+
3
+ import androidx.annotation.Keep
4
+ import com.facebook.proguard.annotations.DoNotStrip
5
+
/**
 * JVM unit-test stand-in for Nitro's `Promise`.
 *
 * [parallel] runs its block synchronously on the calling thread. The returned
 * promise is therefore already settled when the call returns, and tests can read
 * [isCompleted], [result] and [error] directly. A promise settles exactly once;
 * a second [resolve]/[reject] throws [IllegalStateException].
 */
@Keep
@DoNotStrip
class Promise<T> {
    /** Value passed to [resolve]; `null` while pending or after [reject]. */
    @Volatile
    var result: T? = null
        private set

    /** Throwable passed to [reject]; `null` while pending or after [resolve]. */
    @Volatile
    var error: Throwable? = null
        private set

    /** `true` once [resolve] or [reject] has been called. */
    @Volatile
    var isCompleted = false
        private set

    /** Listeners registered through [onComplete] before the promise settled. */
    private val callbacks = mutableListOf<(T?, Throwable?) -> Unit>()

    /** Settles the promise successfully with [value]. */
    fun resolve(value: T) {
        settle(value, null)
    }

    /** Settles the promise as failed with [exception]. */
    fun reject(exception: Throwable) {
        settle(null, exception)
    }

    /**
     * Registers [callback] to receive `(result, error)` once the promise settles.
     * If the promise has already settled, [callback] is invoked immediately.
     */
    fun onComplete(callback: (T?, Throwable?) -> Unit) {
        synchronized(this) {
            if (!isCompleted) {
                callbacks.add(callback)
                return
            }
        }
        callback(result, error)
    }

    private fun settle(value: T?, exception: Throwable?) {
        val listeners = synchronized(this) {
            check(!isCompleted) { "Promise is already settled" }
            result = value
            error = exception
            isCompleted = true
            callbacks.toList().also { callbacks.clear() }
        }
        // Notify outside the lock so a listener touching this promise cannot deadlock.
        listeners.forEach { it(value, exception) }
    }

    companion object {
        /** Runs [block] synchronously and returns a promise already settled with its outcome. */
        @JvmStatic
        fun <T> parallel(block: () -> T): Promise<T> {
            val promise = Promise<T>()
            // Evaluate first, settle afterwards. Otherwise a throwing listener would be
            // mistaken for a failure of `block` and trigger a second settlement.
            runCatching(block).fold(
                onSuccess = { promise.resolve(it) },
                onFailure = { promise.reject(it) },
            )
            return promise
        }
    }
}
@@ -0,0 +1,83 @@
1
+ package com.margelo.nitro.dev.litert.litertlm
2
+
3
+ import org.junit.Assert.*
4
+ import org.junit.Before
5
+ import org.junit.After
6
+ import org.junit.Test
7
+ import org.junit.runner.RunWith
8
+ import org.robolectric.RobolectricTestRunner
9
+ import org.robolectric.RuntimeEnvironment
10
+ import dev.litert.litertlm.LiteRTLMInitProvider
11
+ import java.lang.IllegalArgumentException
12
+
/**
 * Robolectric tests for the Android [HybridLiteRTLM] bridge that need no model:
 * filename/URL validation in the file-management calls, memory telemetry, and the
 * initial generation stats.
 *
 * These rely on the test-source `com.margelo.nitro.core.Promise` stub, whose
 * `parallel` settles synchronously, so outcomes can be asserted right after the call.
 */
@RunWith(RobolectricTestRunner::class)
class HybridLiteRTLMTest {
    private lateinit var bridge: HybridLiteRTLM

    @Before
    fun setUp() {
        // Best-effort: inject the app context into LiteRTLMInitProvider's static
        // `applicationContext` field via reflection. The validation tests below reject
        // their inputs before the context is read, so they do not depend on this.
        try {
            val field = LiteRTLMInitProvider::class.java.getDeclaredField("applicationContext")
            field.isAccessible = true
            field.set(null, RuntimeEnvironment.getApplication())
        } catch (e: Exception) {
            e.printStackTrace()
        }

        bridge = HybridLiteRTLM()
    }

    @After
    fun tearDown() {
        bridge.close()
    }

    @Test
    fun testAndroidPathTraversalPrevention() {
        for (traversal in TRAVERSAL_FILE_NAMES) {
            assertRejectedWith(bridge.deleteModel(traversal), TRAVERSAL_MESSAGE, "filename: $traversal")
        }
    }

    @Test
    fun testAndroidDownloadPathTraversalPrevention() {
        // A valid HTTPS URL gets past the scheme check, so the filename check must reject.
        for (traversal in TRAVERSAL_FILE_NAMES) {
            val promise = bridge.downloadModel("https://example.com/model.bin", traversal, null)
            assertRejectedWith(promise, TRAVERSAL_MESSAGE, "filename: $traversal")
        }
    }

    @Test
    fun testAndroidHTTPSDownloadEnforcement() {
        val promise = bridge.downloadModel("http://insecure.site/model.bin", "model.bin", null)
        assertRejectedWith(promise, HTTPS_MESSAGE, "url: http://insecure.site/model.bin")
    }

    @Test
    fun testAndroidMemoryTelemetry() {
        val mem = bridge.getMemoryUsage()
        assertNotNull(mem)
        assertTrue(mem.nativeHeapBytes >= 0.0)
        assertTrue(mem.residentBytes >= 0.0)
        assertTrue(mem.availableMemoryBytes >= 0.0)
    }

    @Test
    fun testAndroidInitialStats() {
        val stats = bridge.getStats()
        assertNotNull(stats)
        assertEquals(0.0, stats.promptTokens, 0.0)
        assertEquals(0.0, stats.completionTokens, 0.0)
        assertEquals(0.0, stats.totalTokens, 0.0)
        assertEquals(0.0, stats.timeToFirstToken, 0.0)
        assertEquals(0.0, stats.totalTime, 0.0)
        assertEquals(0.0, stats.tokensPerSecond, 0.0)
    }

    /**
     * Asserts that [promise] has already settled with an error. Either the error's
     * message or its cause's message must contain [expectedFragment].
     * [input] describes the offending argument for failure messages.
     */
    private fun assertRejectedWith(
        promise: com.margelo.nitro.core.Promise<*>,
        expectedFragment: String,
        input: String,
    ) {
        assertTrue("Promise should be completed for $input", promise.isCompleted)
        val error = promise.error
        assertNotNull("Promise should have rejected with an error for $input", error)
        val errMsg = error?.message ?: error?.cause?.message ?: ""
        assertTrue(
            "Expected message to contain \"$expectedFragment\" for $input, got: $errMsg",
            errMsg.contains(expectedFragment),
        )
    }

    private companion object {
        /** Filenames that must be rejected as path traversal / directory separators. */
        val TRAVERSAL_FILE_NAMES = arrayOf("../secret", "/etc/hosts", "nested\\..\\file", "..", "../", "..\\")

        const val TRAVERSAL_MESSAGE = "path traversal or directory separators are not allowed"
        const val HTTPS_MESSAGE = "HTTPS is required for security"
    }
}