react-native-litert-lm 0.3.6 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. package/README.md +207 -158
  2. package/android/build.gradle +12 -0
  3. package/android/src/main/AndroidManifest.xml +5 -0
  4. package/android/src/main/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLM.kt +316 -63
  5. package/android/src/main/java/dev/litert/litertlm/LiteRTLMPackage.kt +19 -2
  6. package/android/src/test/java/com/margelo/nitro/core/Promise.kt +46 -0
  7. package/android/src/test/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMTest.kt +83 -0
  8. package/cpp/include/README.md +9 -11
  9. package/ios/HybridLiteRTLM.swift +1058 -0
  10. package/ios/Tests/HybridLiteRTLMTests.swift +67 -0
  11. package/lib/__mocks__/react-native-nitro-modules.d.ts +61 -0
  12. package/lib/__mocks__/react-native-nitro-modules.js +50 -0
  13. package/lib/__tests__/hooks.test.d.ts +1 -0
  14. package/lib/__tests__/hooks.test.js +124 -0
  15. package/lib/__tests__/memoryTracker.test.d.ts +1 -0
  16. package/lib/__tests__/memoryTracker.test.js +74 -0
  17. package/lib/__tests__/modelFactory.test.d.ts +1 -0
  18. package/lib/__tests__/modelFactory.test.js +52 -0
  19. package/lib/hooks.js +1 -1
  20. package/lib/index.d.ts +2 -4
  21. package/lib/index.js +12 -7
  22. package/lib/modelFactory.js +62 -63
  23. package/lib/specs/LiteRTLM.nitro.d.ts +71 -2
  24. package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.cpp +62 -7
  25. package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.hpp +3 -1
  26. package/nitrogen/generated/android/c++/JLLMConfig.hpp +40 -3
  27. package/nitrogen/generated/android/c++/JMultimodalPart.hpp +74 -0
  28. package/nitrogen/generated/android/c++/JPartType.hpp +61 -0
  29. package/nitrogen/generated/android/c++/JToolDefinition.hpp +65 -0
  30. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/GenerationStats.kt +23 -0
  31. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMSpec.kt +10 -2
  32. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/LLMConfig.kt +46 -3
  33. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MemoryUsage.kt +19 -0
  34. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/Message.kt +15 -0
  35. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MultimodalPart.kt +66 -0
  36. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/PartType.kt +24 -0
  37. package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/ToolDefinition.kt +61 -0
  38. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.cpp +57 -1
  39. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.hpp +414 -3
  40. package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Umbrella.hpp +41 -3
  41. package/nitrogen/generated/ios/LiteRTLMAutolinking.mm +4 -6
  42. package/nitrogen/generated/ios/LiteRTLMAutolinking.swift +10 -0
  43. package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.cpp +11 -0
  44. package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.hpp +224 -0
  45. package/nitrogen/generated/ios/swift/Backend.swift +44 -0
  46. package/nitrogen/generated/ios/swift/Func_void.swift +46 -0
  47. package/nitrogen/generated/ios/swift/Func_void_double.swift +46 -0
  48. package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +46 -0
  49. package/nitrogen/generated/ios/swift/Func_void_std__string.swift +46 -0
  50. package/nitrogen/generated/ios/swift/Func_void_std__string_bool.swift +46 -0
  51. package/nitrogen/generated/ios/swift/GenerationStats.swift +54 -0
  52. package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec.swift +69 -0
  53. package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec_cxx.swift +383 -0
  54. package/nitrogen/generated/ios/swift/LLMConfig.swift +203 -0
  55. package/nitrogen/generated/ios/swift/MemoryUsage.swift +44 -0
  56. package/nitrogen/generated/ios/swift/Message.swift +34 -0
  57. package/nitrogen/generated/ios/swift/MultimodalPart.swift +83 -0
  58. package/nitrogen/generated/ios/swift/PartType.swift +44 -0
  59. package/nitrogen/generated/ios/swift/Role.swift +44 -0
  60. package/nitrogen/generated/ios/swift/ToolDefinition.swift +39 -0
  61. package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.cpp +2 -0
  62. package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.hpp +7 -2
  63. package/nitrogen/generated/shared/c++/LLMConfig.hpp +22 -2
  64. package/nitrogen/generated/shared/c++/MultimodalPart.hpp +99 -0
  65. package/nitrogen/generated/shared/c++/PartType.hpp +80 -0
  66. package/nitrogen/generated/shared/c++/ToolDefinition.hpp +91 -0
  67. package/package.json +16 -8
  68. package/react-native-litert-lm.podspec +15 -19
  69. package/scripts/download-ios-frameworks.sh +14 -48
  70. package/scripts/postinstall.js +1 -2
  71. package/src/__mocks__/react-native-nitro-modules.ts +48 -0
  72. package/src/__tests__/hooks.test.ts +153 -0
  73. package/src/__tests__/memoryTracker.test.ts +87 -0
  74. package/src/__tests__/modelFactory.test.ts +68 -0
  75. package/src/hooks.ts +1 -1
  76. package/src/index.ts +12 -9
  77. package/src/modelFactory.ts +82 -80
  78. package/src/specs/LiteRTLM.nitro.ts +80 -2
  79. package/cpp/HybridLiteRTLM.cpp +0 -838
  80. package/cpp/HybridLiteRTLM.hpp +0 -167
  81. package/cpp/IOSDownloadHelper.h +0 -24
  82. package/ios/IOSDownloadHelper.mm +0 -129
  83. package/scripts/build-ios-engine.sh +0 -302
  84. package/scripts/stubs/cxx_bridge_stubs.cc +0 -224
  85. package/scripts/stubs/gemma_model_constraint_provider.cc +0 -46
  86. package/scripts/stubs/llguidance_stubs.c +0 -101
  87. package/src/templates.ts +0 -105
@@ -17,6 +17,7 @@ import com.google.ai.edge.litertlm.Engine
17
17
  import com.google.ai.edge.litertlm.Conversation
18
18
  import com.google.ai.edge.litertlm.EngineConfig
19
19
  import com.google.ai.edge.litertlm.ConversationConfig
20
+ import com.google.ai.edge.litertlm.SamplerConfig
20
21
  import com.margelo.nitro.dev.litert.litertlm.Backend
21
22
  import com.margelo.nitro.dev.litert.litertlm.GenerationStats
22
23
  import com.margelo.nitro.dev.litert.litertlm.HybridLiteRTLMSpec
@@ -25,6 +26,15 @@ import com.margelo.nitro.dev.litert.litertlm.Message
25
26
  import com.margelo.nitro.dev.litert.litertlm.Role
26
27
  import com.margelo.nitro.core.Promise
27
28
  import com.google.ai.edge.litertlm.Content
29
+ import com.google.ai.edge.litertlm.Contents
30
+ import com.google.ai.edge.litertlm.ExperimentalApi
31
+ import com.google.ai.edge.litertlm.ExperimentalFlags
32
+ import com.google.ai.edge.litertlm.OpenApiTool
33
+ import com.google.ai.edge.litertlm.ToolProvider
34
+ import java.util.concurrent.CountDownLatch
35
+ import java.util.concurrent.TimeUnit
36
+ import java.util.concurrent.atomic.AtomicBoolean
37
+ import java.util.concurrent.atomic.AtomicReference
28
38
 
29
39
 
30
40
  // Alias to avoid confusion with our generated Message type
@@ -42,13 +52,26 @@ internal class StreamingCallbackListener(
42
52
  private val onToken: (String, Boolean) -> Unit,
43
53
  private val responseBuilder: StringBuilder,
44
54
  private val history: MutableList<Message>,
55
+ private val userMessage: String,
56
+ private val onStatsReady: (GenerationStats) -> Unit,
45
57
  ) : com.google.ai.edge.litertlm.MessageCallback {
46
58
 
47
- override fun onMessage(responseMsg: com.google.ai.edge.litertlm.Message) {
48
- val chunk = responseMsg.contents.contents
59
+ private val startTime = System.nanoTime()
60
+ private var firstTokenTime = 0L
61
+ private var tokenCount = 0
62
+
63
+ override fun onMessage(message: com.google.ai.edge.litertlm.Message) {
64
+ val chunk = message.contents.contents
49
65
  .filterIsInstance<com.google.ai.edge.litertlm.Content.Text>()
50
66
  .joinToString("") { it.text }
51
67
 
68
+ if (firstTokenTime == 0L && chunk.isNotEmpty()) {
69
+ firstTokenTime = System.nanoTime()
70
+ }
71
+ if (chunk.isNotEmpty()) {
72
+ tokenCount++
73
+ }
74
+
52
75
  onToken(chunk, false)
53
76
 
54
77
  if (chunk.isNotEmpty()) {
@@ -60,12 +83,27 @@ internal class StreamingCallbackListener(
60
83
  onToken("", true)
61
84
  val fullResponse = responseBuilder.toString()
62
85
  history.add(Message(Role.MODEL, fullResponse))
63
- Log.d("StreamingCallbackListener", "Streaming done. Length: ${fullResponse.length}")
86
+
87
+ // Compute stats using heuristic token counts (~4 chars/token)
88
+ val elapsedMs = (System.nanoTime() - startTime) / 1_000_000.0
89
+ val ttftMs = if (firstTokenTime > 0) (firstTokenTime - startTime) / 1_000_000.0 else 0.0
90
+ val promptTokens = userMessage.length / 4.0
91
+ val completionTokens = fullResponse.length / 4.0
92
+ onStatsReady(GenerationStats(
93
+ promptTokens = promptTokens,
94
+ completionTokens = completionTokens,
95
+ totalTokens = promptTokens + completionTokens,
96
+ timeToFirstToken = ttftMs,
97
+ totalTime = elapsedMs,
98
+ tokensPerSecond = if (elapsedMs > 0) completionTokens / (elapsedMs / 1000.0) else 0.0
99
+ ))
100
+
101
+ Log.d("StreamingCallbackListener", "Streaming done. Length: ${fullResponse.length}, TTFT: ${ttftMs.toLong()}ms, Total: ${elapsedMs.toLong()}ms")
64
102
  }
65
103
 
66
- override fun onError(t: Throwable) {
67
- Log.e("StreamingCallbackListener", "Async generation failed", t)
68
- onToken("Error: ${t.message}", true)
104
+ override fun onError(throwable: Throwable) {
105
+ Log.e("StreamingCallbackListener", "Async generation failed", throwable)
106
+ onToken("Error: ${throwable.message}", true)
69
107
  }
70
108
  }
71
109
 
@@ -80,6 +118,10 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
80
118
  companion object {
81
119
  private const val TAG = "HybridLiteRTLM"
82
120
  private val initLock = Any()
121
+
122
+ /** Cached result of OpenCL availability probe (null = not yet checked). */
123
+ @Volatile
124
+ private var openCLAvailable: Boolean? = null
83
125
 
84
126
  /**
85
127
  * Initialize the native library.
@@ -129,6 +171,8 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
129
171
  private var topP: Double = 0.95
130
172
  private var maxTokens: Int = 1024
131
173
  private var systemPrompt: String? = null
174
+ private var tools: Array<ToolDefinition>? = null
175
+ private var enableSpeculativeDecoding: Boolean = false
132
176
 
133
177
  override val memorySize: Long
134
178
  get() = 1024L * 1024L * 1024L // ~1GB (models are large)
@@ -158,9 +202,43 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
158
202
  cfg.topP?.let { topP = it }
159
203
  cfg.maxTokens?.let { maxTokens = it.toInt() }
160
204
  cfg.systemPrompt?.let { systemPrompt = it }
205
+ tools = cfg.tools
206
+ enableSpeculativeDecoding = cfg.enableSpeculativeDecoding ?: false
161
207
  }
162
208
 
209
+ // Whether to run engine validation after loading
210
+ val shouldValidate = config?.validate?: false
211
+
163
212
  try {
213
+ // Early GPU hardware check: probe for OpenCL library before
214
+ // spending time on engine creation. LiteRT-LM's GPU delegate
215
+ // requires OpenCL, which is absent on most Samsung/Qualcomm devices.
216
+ if (backend == Backend.GPU) {
217
+ val hasOpenCL = openCLAvailable ?: run {
218
+ val result = try {
219
+ System.loadLibrary("OpenCL")
220
+ true
221
+ } catch (_: UnsatisfiedLinkError) {
222
+ try {
223
+ // Some devices have it at a non-standard path
224
+ System.load("/system/vendor/lib64/libOpenCL.so")
225
+ true
226
+ } catch (_: UnsatisfiedLinkError) {
227
+ false
228
+ }
229
+ }
230
+ openCLAvailable = result
231
+ result
232
+ }
233
+ if (!hasOpenCL) {
234
+ throw RuntimeException(
235
+ "GPU backend is not supported on this device (OpenCL library not found). " +
236
+ "Please use CPU backend instead."
237
+ )
238
+ }
239
+ Log.i(TAG, "OpenCL library found — GPU backend is available")
240
+ }
241
+
164
242
  // Map our Backend enum to LiteRT-LM Backend sealed class
165
243
  val lmBackend = when (backend) {
166
244
  Backend.GPU -> com.google.ai.edge.litertlm.Backend.GPU()
@@ -171,12 +249,12 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
171
249
  else -> com.google.ai.edge.litertlm.Backend.CPU()
172
250
  }
173
251
 
174
- // Detect multimodal support from model filename.
252
+ // Detect multimodal support. Check config.multimodal flag first, then fall back to filename sniffing.
175
253
  // Only Gemma 3n bundles vision/audio executors; Gemma 4 E2B is text-only.
176
254
  // Passing vision/audio backends to a text-only model causes
177
255
  // vision_litert_compiled_model_executor init failures.
178
256
  val modelFileName = modelPath.substringAfterLast("/").lowercase()
179
- val isMultimodal = modelFileName.contains("3n") || modelFileName.contains("gemma3")
257
+ val isMultimodal = config?.multimodal ?: (modelFileName.contains("3n") || modelFileName.contains("gemma3"))
180
258
 
181
259
  val lmVisionBackend = if (isMultimodal) com.google.ai.edge.litertlm.Backend.GPU() else null
182
260
  val lmAudioBackend = if (isMultimodal) com.google.ai.edge.litertlm.Backend.CPU() else null
@@ -208,6 +286,11 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
208
286
 
209
287
  if (isClosed) return@synchronized
210
288
 
289
+ if (enableSpeculativeDecoding) {
290
+ @OptIn(ExperimentalApi::class)
291
+ ExperimentalFlags.enableSpeculativeDecoding = true
292
+ }
293
+
211
294
  // Initialize Engine
212
295
  engine = Engine(engineConfig).also { it.initialize() }
213
296
  Log.i(TAG, "Engine created and initialized successfully")
@@ -215,9 +298,24 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
215
298
  // Create Conversation
216
299
  createNewConversation()
217
300
  Log.i(TAG, "Conversation created successfully")
301
+
302
+ // Validate the engine actually works with a quick test inference.
303
+ // GPU/NPU backends can initialize without error but silently fail to
304
+ // produce tokens — enabling this catches those failures at load time.
305
+ // CPU is always reliable so validation is never run on it, even when
306
+ // the `validate` flag is set.
307
+ if (shouldValidate) {
308
+ if (backend == Backend.GPU || backend == Backend.NPU) {
309
+ validateEngine()
310
+ } else {
311
+ Log.i(TAG, "Validation skipped: CPU backend is always reliable")
312
+ }
313
+ }
218
314
 
219
315
  } catch (e: Exception) {
220
316
  Log.e(TAG, "Failed to load model: ${e.message}", e)
317
+ // Clean up partial state so isReady() returns false
318
+ cleanupInternal()
221
319
  throw RuntimeException("Failed to load model: ${e.message}", e)
222
320
  }
223
321
  }
@@ -241,7 +339,7 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
241
339
  Log.i(TAG, "sendMessage (Promise): $message")
242
340
 
243
341
  // Blocking inference (safe here because we are in Promise.parallel worker thread)
244
- val userMsg = LiteRTMessage.of(text = message)
342
+ val userMsg = LiteRTMessage.user(message)
245
343
  val startTime = System.nanoTime()
246
344
  val responseMsg = conversation!!.sendMessage(message = userMsg)
247
345
  val elapsedMs = (System.nanoTime() - startTime) / 1_000_000.0
@@ -276,30 +374,48 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
276
374
  // -------------------------------------------------------------------------
277
375
  // sendMessageAsync - Streaming inference
278
376
  // -------------------------------------------------------------------------
279
- override fun sendMessageAsync(message: String, onToken: (String, Boolean) -> Unit) {
280
- // This is already async (void return), so we execute immediately on the calling thread
281
- // (which is the Nitro specialized thread, not Main).
282
- // The SDK's sendMessageAsync is non-blocking anyway.
283
- ensureLoaded()
377
+ override fun sendMessageAsync(message: String, onToken: (String, Boolean) -> Unit): Promise<Unit> {
378
+ return Promise.parallel {
379
+ val latch = CountDownLatch(1)
380
+ val errorRef = AtomicReference<Throwable?>(null)
284
381
 
285
- // Add user message to history
286
- history.add(Message(Role.USER, message))
287
- Log.d(TAG, "sendMessageAsync: $message")
382
+ ensureLoaded()
288
383
 
289
- val fullResponseBuilder = StringBuilder()
290
-
291
- val listener = StreamingCallbackListener(
292
- onToken = onToken,
293
- responseBuilder = fullResponseBuilder,
294
- history = history,
295
- )
384
+ // Add user message to history
385
+ history.add(Message(Role.USER, message))
386
+ Log.d(TAG, "sendMessageAsync: $message")
296
387
 
297
- try {
298
- val userMsg = LiteRTMessage.of(text = message)
299
- conversation!!.sendMessageAsync(message = userMsg, callback = listener)
300
- } catch (e: Exception) {
301
- Log.e(TAG, "Failed to initiate async generation", e)
302
- onToken("Error: ${e.message}", true)
388
+ val fullResponseBuilder = StringBuilder()
389
+
390
+ val listener = StreamingCallbackListener(
391
+ onToken = { token, done ->
392
+ onToken(token, done)
393
+ if (done) {
394
+ latch.countDown()
395
+ }
396
+ },
397
+ responseBuilder = fullResponseBuilder,
398
+ history = history,
399
+ userMessage = message,
400
+ onStatsReady = { stats -> lastStats = stats },
401
+ )
402
+
403
+ try {
404
+ val userMsg = LiteRTMessage.user(message)
405
+ conversation!!.sendMessageAsync(message = userMsg, callback = listener)
406
+ } catch (e: Exception) {
407
+ Log.e(TAG, "Failed to initiate async generation", e)
408
+ errorRef.set(e)
409
+ onToken("Error: ${e.message}", true)
410
+ latch.countDown()
411
+ }
412
+
413
+ // Wait for completion or error
414
+ latch.await()
415
+ val err = errorRef.get()
416
+ if (err != null) {
417
+ throw RuntimeException("Async inference failed: ${err.message}", err)
418
+ }
303
419
  }
304
420
  }
305
421
 
@@ -359,7 +475,7 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
359
475
  // Use factory method Message.of passing a list of Content
360
476
  val textContent = Content.Text(message)
361
477
 
362
- val userMsg = LiteRTMessage.of(textContent, Content.ImageFile(processedImagePath))
478
+ val userMsg = LiteRTMessage.user(Contents.of(textContent, Content.ImageFile(processedImagePath)))
363
479
 
364
480
  // Add to history
365
481
  history.add(Message(Role.USER, "$message [Image]"))
@@ -389,6 +505,14 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
389
505
  return Promise.parallel {
390
506
  Log.i(TAG, "downloadModel: $url -> $fileName")
391
507
 
508
+ if (!url.startsWith("https://", ignoreCase = true)) {
509
+ throw IllegalArgumentException("Invalid download URL: HTTPS is required for security.")
510
+ }
511
+
512
+ if (fileName.contains("..") || fileName.contains("/") || fileName.contains("\\")) {
513
+ throw IllegalArgumentException("Invalid filename: path traversal or directory separators are not allowed.")
514
+ }
515
+
392
516
  val context = LiteRTLMInitProvider.applicationContext ?: throw RuntimeException("Context not available")
393
517
  val modelsDir = java.io.File(context.filesDir, "models")
394
518
  if (!modelsDir.exists()) {
@@ -470,6 +594,11 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
470
594
  override fun deleteModel(fileName: String): Promise<Unit> {
471
595
  return Promise.parallel {
472
596
  Log.i(TAG, "deleteModel: $fileName")
597
+
598
+ if (fileName.contains("..") || fileName.contains("/") || fileName.contains("\\")) {
599
+ throw IllegalArgumentException("Invalid filename: path traversal or directory separators are not allowed.")
600
+ }
601
+
473
602
  val context = LiteRTLMInitProvider.applicationContext ?: throw RuntimeException("Context not available")
474
603
  val modelsDir = java.io.File(context.filesDir, "models")
475
604
  val modelFile = java.io.File(modelsDir, fileName)
@@ -501,10 +630,10 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
501
630
 
502
631
  // Load audio
503
632
 
504
- val userMsg = LiteRTMessage.of(
633
+ val userMsg = LiteRTMessage.user(Contents.of(
505
634
  Content.Text(message),
506
635
  Content.AudioFile(audioPath)
507
- )
636
+ ))
508
637
 
509
638
  history.add(Message(Role.USER, "$message [Audio]"))
510
639
 
@@ -601,19 +730,9 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
601
730
 
602
731
  private fun cleanupInternal() {
603
732
  try {
733
+ conversation?.close()
604
734
  conversation = null
605
- // Explicitly close engine if it supports it to free native memory immediately
606
- // Assuming Engine implements AutoCloseable or has close()
607
- if (engine is AutoCloseable) {
608
- (engine as AutoCloseable).close()
609
- } else {
610
- // Try reflection or just null it if no close method
611
- try {
612
- engine?.javaClass?.getMethod("close")?.invoke(engine)
613
- } catch (e: Exception) {
614
- // Method not found, rely on GC
615
- }
616
- }
735
+ engine?.close() // Direct call
617
736
  engine = null
618
737
  } catch (e: Exception) {
619
738
  Log.e(TAG, "Error closing resources", e)
@@ -631,33 +750,167 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
631
750
  // v0.10.2 enforces single-session: close existing conversation first
632
751
  conversation?.let { oldConv ->
633
752
  try {
634
- if (oldConv is AutoCloseable) {
635
- oldConv.close()
636
- } else {
637
- oldConv.javaClass.getMethod("close").invoke(oldConv)
638
- }
753
+ oldConv.close()
639
754
  } catch (e: Exception) {
640
755
  Log.w(TAG, "Failed to close old conversation: ${e.message}")
641
756
  }
642
757
  conversation = null
643
758
  }
644
- conversation = engine!!.createConversation()
645
- // Apply system prompt/instruction if set
646
- systemPrompt?.let { prompt ->
647
- if (prompt.isNotEmpty()) {
648
- try {
649
- // Send system instruction as the first turn to prime the conversation.
650
- // LiteRT-LM's Conversation API handles chat template formatting,
651
- // including Gemma's <start_of_turn>system block.
652
- val systemMsg = LiteRTMessage.of(Content.Text(prompt))
653
- conversation!!.sendMessage(message = systemMsg)
654
- Log.i(TAG, "System prompt applied (${prompt.length} chars)")
655
- } catch (e: Exception) {
656
- Log.w(TAG, "Failed to apply system prompt: ${e.message}")
759
+ // Map tools
760
+ val lmTools: List<ToolProvider>? = tools?.map { tool ->
761
+ val apiTool = object : OpenApiTool {
762
+ override fun getToolDescriptionJsonString(): String {
763
+ return tool.parametersJson
764
+ }
765
+ override fun execute(paramsJsonString: String): String {
766
+ return "{}"
657
767
  }
658
768
  }
769
+ (apiTool as Any) as ToolProvider
659
770
  }
771
+
772
+ // Create conversation with explicit SamplerConfig (required by Gallery pattern).
773
+ // GPU backend may fail silently without proper sampler params.
774
+ val convConfig = ConversationConfig(
775
+ samplerConfig = SamplerConfig(
776
+ topK = topK,
777
+ topP = topP.toDouble(),
778
+ temperature = temperature.toDouble(),
779
+ ),
780
+ systemInstruction = systemPrompt?.let { Contents.of(Content.Text(it)) },
781
+ tools = lmTools ?: emptyList()
782
+ )
783
+ conversation = engine!!.createConversation(convConfig)
660
784
  }
661
785
 
786
/**
 * Runs a one-shot test inference to confirm the engine can actually produce output.
 *
 * Sends the prompt "Hi" on the current [conversation] and blocks for up to 30 seconds
 * until the first callback of any kind arrives. Validation passes only if at least one
 * `onMessage` callback was observed. An `onError` callback, a stream that finishes
 * without any message, or a timeout each raise a [RuntimeException] so the caller can
 * treat the model as not loaded.
 *
 * On success the conversation is re-created through [createNewConversation], because the
 * probe consumed one turn of the existing one.
 *
 * @throws RuntimeException if no conversation exists, the request cannot be started,
 *   the callback reports an error, or no message arrives within the timeout.
 */
private fun validateEngine() {
    val backendName = when (backend) {
        Backend.GPU -> "GPU"
        Backend.NPU -> "NPU"
        else -> "CPU"
    }
    Log.i(TAG, "Validating $backendName backend with test inference...")

    val latch = CountDownLatch(1)
    val gotToken = AtomicBoolean(false)
    // Hold the Throwable itself rather than only its message: a failure whose message is
    // null must still be recognised as a failure, and the cause should reach the caller.
    val errorRef = AtomicReference<Throwable?>(null)

    // Use the existing conversation for validation (single-session constraint).
    val validationConv = conversation
        ?: throw RuntimeException("$backendName backend: no conversation available for validation")

    try {
        val testMsg = LiteRTMessage.user("Hi")
        validationConv.sendMessageAsync(
            message = testMsg,
            callback = object : com.google.ai.edge.litertlm.MessageCallback {
                override fun onMessage(msg: com.google.ai.edge.litertlm.Message) {
                    // The first message is enough proof; release the waiter immediately.
                    gotToken.set(true)
                    latch.countDown()
                }

                override fun onDone() {
                    latch.countDown()
                }

                override fun onError(t: Throwable) {
                    errorRef.set(t)
                    latch.countDown()
                }
            }
        )
    } catch (e: Exception) {
        throw RuntimeException(
            "$backendName backend failed to run inference: ${e.message}. " +
                "This device may not support the $backendName backend. Please try CPU.",
            e
        )
    }

    // Wait up to 30s for any response
    val completed = latch.await(30, TimeUnit.SECONDS)

    val failure = errorRef.get()
    if (failure != null) {
        throw RuntimeException(
            "$backendName backend inference error: ${failure.message}. " +
                "This device may not support the $backendName backend. Please try CPU.",
            failure
        )
    }
    if (!completed || !gotToken.get()) {
        throw RuntimeException(
            "$backendName backend produced no response within 30 seconds. " +
                "This device may not support the $backendName backend. Please try CPU."
        )
    }

    Log.i(TAG, "$backendName backend validated successfully")

    // Re-create the real conversation (validation consumed one turn)
    createNewConversation()
}
857
+
858
/**
 * Sends a single user turn assembled from text, image and audio parts and resolves with
 * the text portion of the model's reply.
 *
 * Parts whose payload for the declared type is null are skipped. A plain-text summary of
 * the turn (text parts plus "[Image Buffer]" / "[Audio Buffer]" placeholders) is appended
 * to [history] before inference, and the reply is appended afterwards. [lastStats] is
 * refreshed using a length / 4 token estimate; time-to-first-token is recorded as 0.0.
 *
 * @param parts ordered parts of the user message.
 * @return a promise resolving to the concatenated text contents of the response.
 */
override fun sendMultimodalMessage(parts: Array<MultimodalPart>): Promise<String> {
    return Promise.parallel {
        ensureLoaded()

        val payload = mutableListOf<Content>()
        val summary = StringBuilder()

        parts.forEach { part ->
            when (part.type) {
                PartType.TEXT -> {
                    val text = part.text
                    if (text != null) {
                        payload.add(Content.Text(text))
                        summary.append(text).append(' ')
                    }
                }
                PartType.IMAGE -> {
                    val image = part.imageBuffer
                    if (image != null) {
                        // Copy the remaining bytes out of the buffer into a plain array.
                        val raw = image.getBuffer(false)
                        val imageBytes = ByteArray(raw.remaining())
                        raw.get(imageBytes)
                        payload.add(Content.ImageBytes(imageBytes))
                        summary.append("[Image Buffer] ")
                    }
                }
                PartType.AUDIO -> {
                    val audio = part.audioBuffer
                    if (audio != null) {
                        val raw = audio.getBuffer(false)
                        val audioBytes = ByteArray(raw.remaining())
                        raw.get(audioBytes)
                        payload.add(Content.AudioBytes(audioBytes))
                        summary.append("[Audio Buffer] ")
                    }
                }
            }
        }

        val userText = summary.toString().trim()
        history.add(Message(Role.USER, userText))

        val outgoing = LiteRTMessage.user(Contents.of(payload))
        val startedAt = System.nanoTime()
        val reply = conversation!!.sendMessage(message = outgoing)
        val elapsedMs = (System.nanoTime() - startedAt) / 1_000_000.0

        val replyText = reply.contents.contents
            .filterIsInstance<Content.Text>()
            .joinToString("") { it.text }

        history.add(Message(Role.MODEL, replyText))

        // Heuristic token estimate: roughly four characters per token.
        val promptTokens = userText.length / 4.0
        val completionTokens = replyText.length / 4.0
        val throughput = if (elapsedMs > 0) completionTokens / (elapsedMs / 1000.0) else 0.0
        lastStats = GenerationStats(
            promptTokens = promptTokens,
            completionTokens = completionTokens,
            totalTokens = promptTokens + completionTokens,
            timeToFirstToken = 0.0,
            totalTime = elapsedMs,
            tokensPerSecond = throughput
        )

        replyText
    }
}
912
+
913
/**
 * Token-count entry point. This implementation ignores [text] and always returns -1.0.
 *
 * NOTE(review): -1.0 looks like a "not supported on this platform" sentinel — confirm
 * against the shared spec before relying on it.
 */
override fun countTokens(text: String): Double = -1.0
663
916
  }
@@ -1,18 +1,35 @@
1
1
  package dev.litert.litertlm
2
2
 
3
+ import android.os.Build
4
+ import android.util.Log
3
5
  import com.facebook.react.TurboReactPackage
4
6
  import com.facebook.react.bridge.NativeModule
5
7
  import com.facebook.react.bridge.ReactApplicationContext
6
8
  import com.facebook.react.module.model.ReactModuleInfo
7
9
  import com.facebook.react.module.model.ReactModuleInfoProvider
8
- import com.margelo.nitro.core.HybridObject
9
10
 
10
11
 
11
12
  import com.margelo.nitro.dev.litert.litertlm.LiteRTLMOnLoad
12
13
 
13
14
  class LiteRTLMPackage : TurboReactPackage() {
15
+ companion object {
16
+ private const val TAG = "LiteRTLMPackage"
17
+
18
/**
 * Reports whether the first entry of [Build.SUPPORTED_64_BIT_ABIS] is "arm64-v8a".
 *
 * Returns false when that list is empty, i.e. no 64-bit ABI is reported.
 */
private fun isSupportedPrimaryAbi(): Boolean =
    Build.SUPPORTED_64_BIT_ABIS.firstOrNull() == "arm64-v8a"
22
+ }
14
23
  init {
15
- LiteRTLMOnLoad.initializeNative()
24
+ if (!isSupportedPrimaryAbi()) {
25
+ Log.w(TAG, "Skipping LiteRTLM native init on unsupported primary ABI: ${Build.SUPPORTED_64_BIT_ABIS.firstOrNull()}")
26
+ } else {
27
+ try {
28
+ LiteRTLMOnLoad.initializeNative()
29
+ } catch (e: UnsatisfiedLinkError) {
30
+ Log.e(TAG, "LiteRTLM native init failed; disabling LiteRTLM for this process.", e)
31
+ }
32
+ }
16
33
  }
17
34
 
18
35
 
@@ -0,0 +1,46 @@
1
+ package com.margelo.nitro.core
2
+
3
+ import androidx.annotation.Keep
4
+ import com.facebook.proguard.annotations.DoNotStrip
5
+
6
/**
 * Minimal promise holder exposing the settled value, the failure and a completion flag.
 *
 * NOTE(review): this appears to be the stub placed under the test source set in place of
 * the real Nitro `Promise` — confirm before using it outside tests.
 *
 * [resolve] and [reject] each record their argument, mark the promise completed and
 * notify the registered callbacks while holding the instance monitor. Nothing in this
 * class registers a callback, and a second settlement simply overwrites the first.
 */
@Keep
@DoNotStrip
class Promise<T> {
    /** Value passed to the most recent [resolve] call, or null if none. */
    var result: T? = null
        private set

    /** Throwable passed to the most recent [reject] call, or null if none. */
    var error: Throwable? = null
        private set

    /** True once [resolve] or [reject] has been called. */
    var isCompleted = false
        private set

    // Settlement listeners; invoked with (value, null) on resolve and (null, error) on reject.
    private val callbacks = mutableListOf<(T?, Throwable?) -> Unit>()

    /** Settles the promise successfully with [value]. */
    fun resolve(value: T) {
        synchronized(this) {
            result = value
            isCompleted = true
            callbacks.forEach { listener -> listener(value, null) }
        }
    }

    /** Settles the promise as failed with [exception]. */
    fun reject(exception: Throwable) {
        synchronized(this) {
            error = exception
            isCompleted = true
            callbacks.forEach { listener -> listener(null, exception) }
        }
    }

    companion object {
        /**
         * Runs [block] immediately on the calling thread and returns a promise that is
         * already settled: resolved with the block's result, or rejected with whatever
         * it threw.
         */
        @JvmStatic
        fun <T> parallel(block: () -> T): Promise<T> {
            val settled = Promise<T>()
            try {
                settled.resolve(block())
            } catch (failure: Throwable) {
                settled.reject(failure)
            }
            return settled
        }
    }
}