react-native-litert-lm 0.3.6 → 0.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -6,7 +6,7 @@ High-performance on-device LLM inference for React Native, powered by [LiteRT-LM
 
  - 🚀 **Native Performance** — Kotlin (Android) / C++ (iOS) via Nitro Modules JSI bindings
  - 🧠 **Gemma 4 Ready** — First-class support for Gemma 4 E2B/E4B multimodal models (text + vision + audio)
- - ⚡ **GPU Acceleration** — GPU delegate (Android), Metal/MPS (iOS)
+ - ⚡ **GPU Acceleration** — Metal (iOS), OpenCL GPU delegate (Android, Pixel devices)
  - 🔄 **Streaming Support** — Token-by-token generation callbacks
  - 📱 **Cross-Platform** — Android API 26+ / iOS 15.0+
  - 🖼️ **Multimodal** — Image and audio input support
@@ -15,6 +15,12 @@ High-performance on-device LLM inference for React Native, powered by [LiteRT-LM
  - 🧮 **Zero-Copy Buffers** — Memory snapshots stored in native ArrayBuffers via Nitro Modules
  - 📥 **Automatic Model Download** — Downloads models from URL with progress tracking and local caching
 
+ ## Demo
+
+ > Gemma 4 E2B running on-device on a Samsung Galaxy S22 (Snapdragon 8 Gen 1, 4 GB RAM) — CPU backend, streaming inference.
+
+ <video src="https://github.com/user-attachments/assets/1da527ce-0432-4f8b-8899-474f81b2feea" width="300" controls></video>
+
  ## Installation
 
  ```bash
@@ -94,35 +100,38 @@ The `example/` directory contains a fully functional test app with a dark-themed
 
  ## Model Management
 
- LiteRT-LM models (like Gemma 4) are large files (2–4 GB) and cannot be bundled into your app binary. They are downloaded at runtime.
+ LiteRT-LM models (like Gemma 4) are large files (1–4 GB) and cannot be bundled into your app binary. They are downloaded at runtime.
 
  ### Automatic Downloading
 
- The library handles downloading automatically when you pass a URL to `loadModel` or `useModel`. Downloads include:
+ Pass an HTTPS URL to `useModel()` or `loadModel()` and the library handles the rest:
 
  - **Progress tracking** — real-time download percentage via callbacks
  - **Local caching** — downloaded models are cached and reused across app launches
- - **Android**: app-local temp directory
+ - **Android**: `files/models/` (app-private)
  - **iOS**: `Library/Caches/litert_models/` (survives app relaunch; reclaimable by iOS under storage pressure)
  - **HTTPS enforcement** — only secure URLs are accepted
 
- ### Manual Downloading (Optional)
+ ### Manual Downloading
 
- If you prefer to manage downloads yourself (e.g., using `expo-file-system`), download the `.litertlm` file to a local path and pass that path to the library:
+ If you need custom control over downloads (e.g., authentication headers for private model hosting, resumable downloads, or custom caching), use your preferred HTTP client and pass the local file path:
 
  ```typescript
- import * as FileSystem from "expo-file-system";
- import { GEMMA_4_E2B_IT } from "react-native-litert-lm";
+ import { fetch } from "expo/fetch";
+ import { File, Paths } from "expo-file-system";
+ import { useModel } from "react-native-litert-lm";
 
- const localPath = `${FileSystem.documentDirectory}gemma-4-E2B-it.litertlm`;
+ const MODEL_URL = "https://example.com/private-model.litertlm";
 
- async function downloadModel() {
- const info = await FileSystem.getInfoAsync(localPath);
- if (info.exists) return localPath;
+ // Download with custom headers using expo/fetch
+ const response = await fetch(MODEL_URL, {
+ headers: { Authorization: `Bearer ${token}` },
+ });
+ const modelFile = new File(Paths.cache, "my-model.litertlm");
+ modelFile.write(await response.bytes());
 
- await FileSystem.downloadAsync(GEMMA_4_E2B_IT, localPath);
- return localPath;
- }
+ // Pass the local path — no download occurs
+ const { model, isReady } = useModel(modelFile.uri, { backend: "cpu" });
  ```
 
  ## Usage
@@ -307,19 +316,19 @@ const buffer = tracker.getNativeBuffer();
 
  ## Supported Models
 
- Download `.litertlm` models automatically using the exported URL constants, or manually from [HuggingFace](https://huggingface.co/litert-community):
+ All exported model URLs are **public**; no authentication is required. Pass them directly to `useModel()` or `loadModel()` for automatic downloading with progress tracking and local caching.
 
- | Constant | Model | Size | Min RAM | Auth Required |
- | :--------------------- | :------------------------------ | :------ | :------ | :------------- |
- | `GEMMA_4_E2B_IT` | Gemma 4 E2B (Multimodal, IT) | 2.58 GB | 4 GB+ | No |
- | `GEMMA_4_E4B_IT` | Gemma 4 E4B (Higher Quality) | 3.65 GB | 6 GB+ | No |
- | `GEMMA_3N_E2B_IT_INT4` | Gemma 3n E2B (Int4, Multimodal) | ~1.3 GB | 4 GB+ | ✅ HuggingFace |
+ | Constant | Model | Size | Min RAM | Source |
+ | :--------------------- | :------------------------------ | :------ | :------ | :---------- |
+ | `GEMMA_4_E2B_IT` | Gemma 4 E2B (Multimodal, IT) | 2.58 GB | 4 GB+ | HuggingFace |
+ | `GEMMA_4_E4B_IT` | Gemma 4 E4B (Higher Quality) | 3.65 GB | 6 GB+ | HuggingFace |
+ | `GEMMA_3N_E2B_IT_INT4` | Gemma 3n E2B (Int4, Multimodal) | ~1.3 GB | 4 GB+ | litert.dev |
 
- > **Recommended:** Use `GEMMA_4_E2B_IT` for most use cases. It's multimodal (text + vision + audio) and downloads directly from HuggingFace without requiring an account.
+ > **Recommended:** Use `GEMMA_4_E2B_IT` for most use cases: it is multimodal (text + vision + audio) and has the best quality-to-size ratio.
  >
- > **iOS Note:** Models larger than ~2 GB (like Gemma 4) require the `com.apple.developer.kernel.extended-virtual-addressing` entitlement. See [iOS Entitlements](#ios-entitlements) below.
+ > **iOS Note:** Models larger than ~2 GB require the `com.apple.developer.kernel.extended-virtual-addressing` entitlement. See [iOS Entitlements](#ios-entitlements) below. Gemma 3n E2B (~1.3 GB) works without it.
 
- **Other compatible models** (download manually from HuggingFace):
+ **Other compatible models** (download `.litertlm` files manually from [HuggingFace](https://huggingface.co/litert-community)):
 
  | Model | Size | Min RAM | Notes |
  | ------------- | ------- | ------- | --------------------- |
@@ -352,13 +361,15 @@ Loads a model from a local path or HTTPS URL.
 
  #### Backend Options
 
- | Backend | Engine | Speed | Notes |
- | ------- | ------------------- | ------- | ---------------------------------------------- |
- | `'cpu'` | CPU inference | Slowest | Always available, lower RAM requirement |
- | `'gpu'` | GPU / Metal | Fast | Recommended default |
- | `'npu'` | NPU / Neural Engine | Fastest | Requires supported hardware; falls back to GPU |
+ | Backend | Engine | Speed | Notes |
+ | ------- | ------------------------------ | ------- | ---------------------------------------------------------------------------------- |
+ | `'cpu'` | CPU inference | Slowest | Always available on all devices |
+ | `'gpu'` | Metal (iOS) / OpenCL (Android) | Fast | iOS: always available. Android: requires OpenCL (Pixel only, not Samsung/Qualcomm) |
+ | `'npu'` | NPU / Neural Engine | Fastest | Requires supported hardware; experimental |
 
- > **iOS**: `'cpu'` is the recommended default backend. `'gpu'` (Metal/MPS) is also supported. The engine automatically tries multiple backend combinations if the primary one fails.
+ > **iOS**: Both `'cpu'` and `'gpu'` (Metal) are supported. The engine automatically tries fallback backend combinations if the primary one fails.
+ >
+ > **Android GPU**: The GPU backend requires OpenCL, which is **not available on most Samsung and Qualcomm devices**. Use `checkBackendSupport('gpu')` to check before loading. The engine will throw a clear error if GPU is unsupported.
 
  ### `sendMessage(message): Promise<string>`
 
@@ -383,14 +394,16 @@ Returns performance metrics from the last inference call.
  ```typescript
  interface GenerationStats {
  tokensPerSecond: number;
- totalTime: number; // seconds
- timeToFirstToken: number; // seconds
+ totalTime: number; // milliseconds
+ timeToFirstToken: number; // milliseconds
  promptTokens: number;
  completionTokens: number;
- prefillSpeed: number; // tokens/sec
+ totalTokens: number;
  }
  ```
 
+ > **Note**: Stats are available for both sync (`sendMessage`) and streaming (`sendMessageAsync`) on both platforms. iOS uses real benchmark data from the C API; Android uses heuristic token counts (~4 chars/token) with precise timing.
+
  ### `getMemoryUsage(): MemoryUsage`
 
  Returns real OS-level memory usage.
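To make the Android heuristic concrete: the Kotlin `StreamingCallbackListener` in this release divides character counts by 4 to estimate tokens and reports times in milliseconds. The C++ sketch below reproduces that arithmetic as a standalone function. `heuristicStats` and `kCharsPerToken` are hypothetical names, and the struct only mirrors the README's `GenerationStats` fields; it illustrates the formula and is not the library's code.

```cpp
#include <cassert>
#include <cstddef>

// Field names mirror the README's GenerationStats interface (times in ms).
struct GenerationStats {
  double tokensPerSecond;
  double totalTime;         // milliseconds
  double timeToFirstToken;  // milliseconds
  double promptTokens;
  double completionTokens;
  double totalTokens;
};

// Rough characters-per-token ratio used by the Android listener.
constexpr double kCharsPerToken = 4.0;

// Hypothetical helper reproducing the Kotlin arithmetic from onDone().
GenerationStats heuristicStats(std::size_t promptChars, std::size_t responseChars,
                               double elapsedMs, double ttftMs) {
  assert(elapsedMs >= 0.0 && ttftMs >= 0.0);
  GenerationStats stats{};
  stats.promptTokens = static_cast<double>(promptChars) / kCharsPerToken;
  stats.completionTokens = static_cast<double>(responseChars) / kCharsPerToken;
  stats.totalTokens = stats.promptTokens + stats.completionTokens;
  stats.timeToFirstToken = ttftMs;
  stats.totalTime = elapsedMs;
  // Completion tokens divided by elapsed seconds; 0 when no time elapsed.
  stats.tokensPerSecond =
      elapsedMs > 0.0 ? stats.completionTokens / (elapsedMs / 1000.0) : 0.0;
  return stats;
}
```

With an 8-character prompt, a 40-character reply, 2000 ms total and 500 ms to first token, this gives 2 prompt tokens, 10 completion tokens and 5 tokens/s. Because the counts come from a character ratio, they are fractional estimates rather than tokenizer output.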
@@ -432,10 +445,21 @@ import {
  applyLlamaTemplate,
  } from "react-native-litert-lm";
 
- // Check if a backend is supported
- const warning = checkBackendSupport("npu"); // string | undefined
+ // Check if GPU is supported on this device
+ const gpuWarning = checkBackendSupport("gpu");
+ if (gpuWarning) {
+ console.warn(gpuWarning);
+ // "GPU backend requires OpenCL support, which is unavailable on most Samsung and Qualcomm devices."
+ }
+
+ // Check NPU support
+ const npuWarning = checkBackendSupport("npu"); // string | undefined
+
+ // Check multimodal support
  const mmError = checkMultimodalSupport(); // string | undefined
- const backend = getRecommendedBackend(); // 'gpu' | 'cpu'
+
+ // Get recommended backend
+ const backend = getRecommendedBackend(); // 'cpu'
 
  // Manual prompt formatting (advanced)
  const prompt = applyGemmaTemplate(
@@ -456,10 +480,10 @@ const prompt = applyGemmaTemplate(
 
  ## Platform Support
 
- | Platform | Status | Architecture | Backends |
- | -------- | -------- | ------------ | ---------------- |
- | Android | ✅ Ready | arm64-v8a | CPU, GPU, NPU |
- | iOS | ✅ Ready | arm64 | CPU, GPU (Metal) |
+ | Platform | Status | Architecture | Backends |
+ | -------- | -------- | ------------ | ------------------------------------------------- |
+ | Android | ✅ Ready | arm64-v8a | CPU (all devices), GPU (OpenCL devices only), NPU |
+ | iOS | ✅ Ready | arm64 | CPU, GPU (Metal — always available) |
 
  ### iOS Feature Matrix
 
@@ -552,13 +576,20 @@ Additionally, `PromptTemplate` is patched at build time to use a simplified C++
  ├──────────────────────┬──────────────────────────┤
  │ Android (Kotlin) │ iOS (C++) │
  │ HybridLiteRTLM.kt │ HybridLiteRTLM.cpp │
- │ litertlm-android │ LiteRTLM C API
+ │ litertlm-android │ LiteRT-LM C API
  │ AAR (GPU delegate) │ XCFramework (Metal) │
  └──────────────────────┴──────────────────────────┘
  ```
 
- - **Android**: Kotlin (`HybridLiteRTLM.kt`) interfacing with the `litertlm-android` AAR.
- - **iOS**: C++ (`HybridLiteRTLM.cpp`) interfacing with the LiteRT-LM C API via a prebuilt `LiteRTLM.xcframework`. All engine operations (load, inference, streaming) run on dedicated `pthread` threads with 8 MB stack to accommodate XNNPack's stack requirements. Platform-specific code (model downloading, file management) is in Objective-C++ (`ios/IOSDownloadHelper.mm`).
+ - **Android**: Kotlin (`HybridLiteRTLM.kt`) interfacing with the `litertlm-android` AAR via the **Kotlin SDK**. The SDK handles control token stripping and turn management automatically. Engine validation probes for OpenCL availability before GPU initialization. `ConversationConfig` with `SamplerConfig` is passed for all conversations (matching the Gallery app pattern).
+ - **iOS**: C++ (`HybridLiteRTLM.cpp`) interfacing with the LiteRT-LM **C API** via a prebuilt `LiteRTLM.xcframework`. Unlike the Kotlin SDK, the C API emits raw tokens including control sequences (`<end_of_turn>`, `<start_of_turn>`) and echoed user messages. The C++ layer implements a robust sanitization pipeline:
+ - **Accumulation-and-diff** — buffers the full response and emits only new deltas
+ - **`stripControlTokens()`** — removes all control sequences from the accumulated buffer
+ - **`safeEmitLength()`** — look-ahead buffering that withholds partial control tokens (e.g., `<end_of_tur`) from emission until the full token is received or the stream terminates
+ - **Echo mitigation** — strips echoed user messages from the raw stream
+ - **Final flush** — mandatory clean-and-flush step on stream termination
+
+ Platform-specific code (model downloading, file management) is in Objective-C++ (`ios/IOSDownloadHelper.mm`).
 
  > **For contributors**: Changes to `cpp/HybridLiteRTLM.cpp` do not affect Android. Feature changes must be applied to both the Kotlin and C++ implementations.
 
@@ -17,6 +17,7 @@ import com.google.ai.edge.litertlm.Engine
  import com.google.ai.edge.litertlm.Conversation
  import com.google.ai.edge.litertlm.EngineConfig
  import com.google.ai.edge.litertlm.ConversationConfig
+ import com.google.ai.edge.litertlm.SamplerConfig
  import com.margelo.nitro.dev.litert.litertlm.Backend
  import com.margelo.nitro.dev.litert.litertlm.GenerationStats
  import com.margelo.nitro.dev.litert.litertlm.HybridLiteRTLMSpec
@@ -25,6 +26,11 @@ import com.margelo.nitro.dev.litert.litertlm.Message
  import com.margelo.nitro.dev.litert.litertlm.Role
  import com.margelo.nitro.core.Promise
  import com.google.ai.edge.litertlm.Content
+ import com.google.ai.edge.litertlm.Contents
+ import java.util.concurrent.CountDownLatch
+ import java.util.concurrent.TimeUnit
+ import java.util.concurrent.atomic.AtomicBoolean
+ import java.util.concurrent.atomic.AtomicReference
 
 
  // Alias to avoid confusion with our generated Message type
@@ -42,13 +48,26 @@ internal class StreamingCallbackListener(
  private val onToken: (String, Boolean) -> Unit,
  private val responseBuilder: StringBuilder,
  private val history: MutableList<Message>,
+ private val userMessage: String,
+ private val onStatsReady: (GenerationStats) -> Unit,
  ) : com.google.ai.edge.litertlm.MessageCallback {
 
- override fun onMessage(responseMsg: com.google.ai.edge.litertlm.Message) {
- val chunk = responseMsg.contents.contents
+ private val startTime = System.nanoTime()
+ private var firstTokenTime = 0L
+ private var tokenCount = 0
+
+ override fun onMessage(message: com.google.ai.edge.litertlm.Message) {
+ val chunk = message.contents.contents
  .filterIsInstance<com.google.ai.edge.litertlm.Content.Text>()
  .joinToString("") { it.text }
 
+ if (firstTokenTime == 0L && chunk.isNotEmpty()) {
+ firstTokenTime = System.nanoTime()
+ }
+ if (chunk.isNotEmpty()) {
+ tokenCount++
+ }
+
  onToken(chunk, false)
 
  if (chunk.isNotEmpty()) {
@@ -60,12 +79,27 @@ internal class StreamingCallbackListener(
  onToken("", true)
  val fullResponse = responseBuilder.toString()
  history.add(Message(Role.MODEL, fullResponse))
- Log.d("StreamingCallbackListener", "Streaming done. Length: ${fullResponse.length}")
+
+ // Compute stats using heuristic token counts (~4 chars/token)
+ val elapsedMs = (System.nanoTime() - startTime) / 1_000_000.0
+ val ttftMs = if (firstTokenTime > 0) (firstTokenTime - startTime) / 1_000_000.0 else 0.0
+ val promptTokens = userMessage.length / 4.0
+ val completionTokens = fullResponse.length / 4.0
+ onStatsReady(GenerationStats(
+ promptTokens = promptTokens,
+ completionTokens = completionTokens,
+ totalTokens = promptTokens + completionTokens,
+ timeToFirstToken = ttftMs,
+ totalTime = elapsedMs,
+ tokensPerSecond = if (elapsedMs > 0) completionTokens / (elapsedMs / 1000.0) else 0.0
+ ))
+
+ Log.d("StreamingCallbackListener", "Streaming done. Length: ${fullResponse.length}, TTFT: ${ttftMs.toLong()}ms, Total: ${elapsedMs.toLong()}ms")
  }
 
- override fun onError(t: Throwable) {
- Log.e("StreamingCallbackListener", "Async generation failed", t)
- onToken("Error: ${t.message}", true)
+ override fun onError(throwable: Throwable) {
+ Log.e("StreamingCallbackListener", "Async generation failed", throwable)
+ onToken("Error: ${throwable.message}", true)
  }
  }
 
@@ -80,6 +114,10 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
  companion object {
  private const val TAG = "HybridLiteRTLM"
  private val initLock = Any()
+
+ /** Cached result of OpenCL availability probe (null = not yet checked). */
+ @Volatile
+ private var openCLAvailable: Boolean? = null
 
  /**
  * Initialize the native library.
@@ -161,6 +199,35 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
  }
 
  try {
+ // Early GPU hardware check: probe for OpenCL library before
+ // spending time on engine creation. LiteRT-LM's GPU delegate
+ // requires OpenCL, which is absent on most Samsung/Qualcomm devices.
+ if (backend == Backend.GPU) {
+ val hasOpenCL = openCLAvailable ?: run {
+ val result = try {
+ System.loadLibrary("OpenCL")
+ true
+ } catch (_: UnsatisfiedLinkError) {
+ try {
+ // Some devices have it at a non-standard path
+ System.load("/system/vendor/lib64/libOpenCL.so")
+ true
+ } catch (_: UnsatisfiedLinkError) {
+ false
+ }
+ }
+ openCLAvailable = result
+ result
+ }
+ if (!hasOpenCL) {
+ throw RuntimeException(
+ "GPU backend is not supported on this device (OpenCL library not found). " +
+ "Please use CPU backend instead."
+ )
+ }
+ Log.i(TAG, "OpenCL library found — GPU backend is available")
+ }
+
  // Map our Backend enum to LiteRT-LM Backend sealed class
  val lmBackend = when (backend) {
  Backend.GPU -> com.google.ai.edge.litertlm.Backend.GPU()
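The Kotlin change probes for OpenCL once and caches the answer in a `@Volatile` nullable field, so later GPU loads skip the probe. A rough native-code analogue of that probe-once pattern is sketched below using POSIX `dlopen`. The `OpenCLProbe` class and the candidate paths passed to it are assumptions for illustration, and whether an app process is allowed to `dlopen` a vendor OpenCL library depends on the device. On glibc older than 2.34, link with `-ldl`.

```cpp
#include <dlfcn.h>

#include <cassert>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: try each candidate library once and cache the answer.
class OpenCLProbe {
 public:
  explicit OpenCLProbe(std::vector<std::string> candidates)
      : candidates_(std::move(candidates)) {
    assert(!candidates_.empty());
  }

  // std::call_once runs the probe exactly once, even under concurrent calls.
  bool available() {
    std::call_once(once_, [this] {
      ++probeCount_;
      result_ = probe();
    });
    return result_;
  }

  int probeCount() const { return probeCount_; }

 private:
  bool probe() const {
    for (const std::string& path : candidates_) {
      // RTLD_NOW resolves symbols immediately, so a broken library fails here.
      if (void* handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL)) {
        dlclose(handle);  // only checking availability
        return true;
      }
    }
    return false;
  }

  std::vector<std::string> candidates_;
  std::once_flag once_;
  bool result_ = false;
  int probeCount_ = 0;
};
```

For example, `OpenCLProbe probe({"libOpenCL.so", "/system/vendor/lib64/libOpenCL.so"});` mirrors the two locations the Kotlin code tries. The Kotlin field can be probed twice if two loads race, which should be harmless because probing is idempotent; `std::call_once` is simply one C++ way to rule that out.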
@@ -215,9 +282,15 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
  // Create Conversation
  createNewConversation()
  Log.i(TAG, "Conversation created successfully")
+
+ // Validate the engine actually works with a quick test inference.
+ // GPU backend can initialize without error but silently fail to produce tokens.
+ validateEngine()
 
  } catch (e: Exception) {
  Log.e(TAG, "Failed to load model: ${e.message}", e)
+ // Clean up partial state so isReady() returns false
+ cleanupInternal()
  throw RuntimeException("Failed to load model: ${e.message}", e)
  }
  }
@@ -241,7 +314,7 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
  Log.i(TAG, "sendMessage (Promise): $message")
 
  // Blocking inference (safe here because we are in Promise.parallel worker thread)
- val userMsg = LiteRTMessage.of(text = message)
+ val userMsg = LiteRTMessage.user(message)
  val startTime = System.nanoTime()
  val responseMsg = conversation!!.sendMessage(message = userMsg)
  val elapsedMs = (System.nanoTime() - startTime) / 1_000_000.0
@@ -292,10 +365,12 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
  onToken = onToken,
  responseBuilder = fullResponseBuilder,
  history = history,
+ userMessage = message,
+ onStatsReady = { stats -> lastStats = stats },
  )
 
  try {
- val userMsg = LiteRTMessage.of(text = message)
+ val userMsg = LiteRTMessage.user(message)
  conversation!!.sendMessageAsync(message = userMsg, callback = listener)
  } catch (e: Exception) {
  Log.e(TAG, "Failed to initiate async generation", e)
@@ -359,7 +434,7 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
  // Use factory method Message.of passing a list of Content
  val textContent = Content.Text(message)
 
- val userMsg = LiteRTMessage.of(textContent, Content.ImageFile(processedImagePath))
+ val userMsg = LiteRTMessage.user(Contents.of(textContent, Content.ImageFile(processedImagePath)))
 
  // Add to history
  history.add(Message(Role.USER, "$message [Image]"))
@@ -501,10 +576,10 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
 
  // Load audio
 
- val userMsg = LiteRTMessage.of(
+ val userMsg = LiteRTMessage.user(Contents.of(
  Content.Text(message),
  Content.AudioFile(audioPath)
- )
+ ))
 
  history.add(Message(Role.USER, "$message [Audio]"))
 
@@ -641,7 +716,16 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
  }
  conversation = null
  }
- conversation = engine!!.createConversation()
+ // Create conversation with explicit SamplerConfig (required by Gallery pattern).
+ // GPU backend may fail silently without proper sampler params.
+ val convConfig = ConversationConfig(
+ samplerConfig = SamplerConfig(
+ topK = topK,
+ topP = topP,
+ temperature = temperature,
+ )
+ )
+ conversation = engine!!.createConversation(convConfig)
  // Apply system prompt/instruction if set
  systemPrompt?.let { prompt ->
  if (prompt.isNotEmpty()) {
@@ -649,7 +733,7 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
  // Send system instruction as the first turn to prime the conversation.
  // LiteRT-LM's Conversation API handles chat template formatting,
  // including Gemma's <start_of_turn>system block.
- val systemMsg = LiteRTMessage.of(Content.Text(prompt))
+ val systemMsg = LiteRTMessage.system(prompt)
  conversation!!.sendMessage(message = systemMsg)
  Log.i(TAG, "System prompt applied (${prompt.length} chars)")
  } catch (e: Exception) {
@@ -659,5 +743,77 @@ class HybridLiteRTLM : HybridLiteRTLMSpec() {
  }
  }
 
+ /**
+ * Validate that the engine can actually produce inference output.
+ *
+ * Some GPU backends initialize without error but silently hang during inference.
+ * This sends a minimal test prompt ("Hi") and waits up to 30s for any token.
+ * If no token arrives, we throw so the model does NOT appear as loaded.
+ */
+ private fun validateEngine() {
+ val backendName = when (backend) {
+ Backend.GPU -> "GPU"
+ Backend.NPU -> "NPU"
+ else -> "CPU"
+ }
+ Log.i(TAG, "Validating $backendName backend with test inference...")
+
+ val latch = CountDownLatch(1)
+ val gotToken = AtomicBoolean(false)
+ val errorRef = AtomicReference<String?>(null)
+
+ // Use the existing conversation for validation (single-session constraint).
+ val validationConv = conversation
+ ?: throw RuntimeException("$backendName backend: no conversation available for validation")
+
+ try {
+ val testMsg = LiteRTMessage.user("Hi")
+ validationConv.sendMessageAsync(
+ message = testMsg,
+ callback = object : com.google.ai.edge.litertlm.MessageCallback {
+ override fun onMessage(msg: com.google.ai.edge.litertlm.Message) {
+ gotToken.set(true)
+ latch.countDown()
+ }
+ override fun onDone() {
+ latch.countDown()
+ }
+ override fun onError(t: Throwable) {
+ errorRef.set(t.message)
+ latch.countDown()
+ }
+ }
+ )
+ } catch (e: Exception) {
+ throw RuntimeException(
+ "$backendName backend failed to run inference: ${e.message}. " +
+ "This device may not support the $backendName backend. Please try CPU.",
+ e
+ )
+ }
+
+ // Wait up to 30s for any response
+ val completed = latch.await(30, TimeUnit.SECONDS)
+
+ val error = errorRef.get()
+ if (error != null) {
+ throw RuntimeException(
+ "$backendName backend inference error: $error. " +
+ "This device may not support the $backendName backend. Please try CPU."
+ )
+ }
+ if (!completed || !gotToken.get()) {
+ throw RuntimeException(
+ "$backendName backend produced no response within 30 seconds. " +
+ "This device may not support the $backendName backend. Please try CPU."
+ )
+ }
+
+ Log.i(TAG, "$backendName backend validated successfully")
+
+ // Re-create the real conversation (validation consumed one turn)
+ createNewConversation()
+ }
+
 
  }
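`validateEngine()` blocks the loading thread on a `CountDownLatch(1)` until the first callback arrives or 30 seconds pass, then checks the error first and the token flag second. The same one-shot wait can be sketched in C++ with a mutex and `std::condition_variable::wait_for`. `FirstResponseGate` and the `Validation` enum are hypothetical names, and the three callback methods stand in for whatever the inference API actually invokes.

```cpp
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <optional>
#include <string>

enum class Validation { Ok, Error, NoToken, Timeout };

// Hypothetical one-shot gate playing the role of CountDownLatch(1) plus the
// AtomicBoolean / AtomicReference pair used in validateEngine().
class FirstResponseGate {
 public:
  void onMessage() { signal([this] { gotToken_ = true; }); }
  void onDone() { signal([] {}); }
  void onError(const std::string& message) {
    signal([this, &message] { error_ = message; });
  }

  // Block until the first callback or until `timeout` elapses, then classify
  // the outcome in the same order as validateEngine(): error first.
  Validation await(std::chrono::milliseconds timeout) {
    assert(timeout.count() >= 0);
    std::unique_lock<std::mutex> lock(mutex_);
    const bool completed = cv_.wait_for(lock, timeout, [this] { return signaled_; });
    if (error_) return Validation::Error;
    if (!completed) return Validation::Timeout;
    return gotToken_ ? Validation::Ok : Validation::NoToken;
  }

 private:
  // Apply a state update under the lock, then wake the waiting thread.
  template <typename Update>
  void signal(Update update) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      update();
      signaled_ = true;
    }
    cv_.notify_all();
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  bool signaled_ = false;
  bool gotToken_ = false;
  std::optional<std::string> error_;
};
```

In real use the callbacks would fire on the engine's worker thread while the loader blocks in `await(std::chrono::seconds(30))`. Returning distinct outcomes lets the caller tell a timeout apart from a run that finished without any token, two cases the Kotlin code folds into one message.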
@@ -110,31 +110,67 @@ std::string HybridLiteRTLM::buildAudioMessageJson(const std::string& text, const
  }
 
  /**
- * Strip Gemma / LiteRT-LM control tokens from model output.
- * The iOS C API returns raw model text including stop/turn markers
- * that the Android Kotlin SDK strips automatically.
+ * Gemma / LiteRT-LM control tokens that the iOS C API includes in raw output.
+ * The Android Kotlin SDK strips these automatically.
+ */
+ static const char* kControlTokens[] = {
+ "<end_of_turn>",
+ "<start_of_turn>model",
+ "<start_of_turn>user",
+ "<start_of_turn>",
+ "<eos>",
+ };
+
+ /**
+ * Strip control tokens from model output, preserving whitespace.
+ * Streaming tokens like " the", " is" have meaningful leading spaces
+ * that must not be trimmed.
  */
  static std::string stripControlTokens(const std::string& text) {
- static const char* tokens[] = {
- "<end_of_turn>",
- "<start_of_turn>model",
- "<start_of_turn>user",
- "<start_of_turn>",
- "<eos>",
- };
  std::string result = text;
- for (auto* tok : tokens) {
+ for (auto* tok : kControlTokens) {
  std::string t(tok);
  size_t pos;
  while ((pos = result.find(t)) != std::string::npos) {
  result.erase(pos, t.length());
  }
  }
- // Trim leading/trailing whitespace
- size_t start = result.find_first_not_of(" \t\n\r");
+ return result;
+ }
+
+ /**
+ * Determine how many characters from the start of `text` are safe to emit.
+ * If the tail of `text` could be the beginning of a control token (split
+ * across chunk boundaries), those characters are withheld until the next
+ * chunk confirms whether it's a real token or normal content.
+ */
+ static size_t safeEmitLength(const std::string& text) {
+ // Find the last '<' — it could be the start of a partial control token
+ size_t lastAngle = text.rfind('<');
+ if (lastAngle == std::string::npos) {
+ return text.length(); // No '<' found, safe to emit all
+ }
+
+ std::string suffix = text.substr(lastAngle);
+ // Check if this suffix is a prefix of any control token
+ for (auto* tok : kControlTokens) {
+ std::string t(tok);
+ if (suffix.length() < t.length() && t.compare(0, suffix.length(), suffix) == 0) {
+ // This suffix could be the start of a control token — hold it back
+ return lastAngle;
+ }
+ }
+
+ // The '<' doesn't match any control token prefix, safe to emit all
+ return text.length();
+ }
+
+ /** Trim leading/trailing whitespace from a complete response. */
+ static std::string trimWhitespace(const std::string& text) {
+ size_t start = text.find_first_not_of(" \t\n\r");
  if (start == std::string::npos) return "";
- size_t end = result.find_last_not_of(" \t\n\r");
- return result.substr(start, end - start + 1);
+ size_t end = text.find_last_not_of(" \t\n\r");
+ return text.substr(start, end - start + 1);
  }
 
  std::string HybridLiteRTLM::extractTextFromResponse(const std::string& jsonResponse) {
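The reason trimming moved out of `stripControlTokens()` is easiest to see with a streamed sentence. The sketch below reuses the diff's whitespace set. `joinTrimmedPerChunk` approximates the 0.3.6 streaming path (control-token stripping omitted), and `joinThenTrim` models trimming the assembled text once, as the sync paths now do with `trimWhitespace(extractTextFromResponse(...))`. Both helper names are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Same whitespace set as trimWhitespace() in the diff.
static std::string trimWhitespace(const std::string& text) {
  const std::size_t start = text.find_first_not_of(" \t\n\r");
  if (start == std::string::npos) return "";
  const std::size_t end = text.find_last_not_of(" \t\n\r");
  assert(end != std::string::npos && end >= start);
  return text.substr(start, end - start + 1);
}

// Approximates the 0.3.6 streaming path: every chunk trimmed on its own.
static std::string joinTrimmedPerChunk(const std::vector<std::string>& chunks) {
  std::string out;
  for (const std::string& chunk : chunks) out += trimWhitespace(chunk);
  return out;
}

// Keeps chunk whitespace intact and trims the assembled response once.
static std::string joinThenTrim(const std::vector<std::string>& chunks) {
  std::string out;
  for (const std::string& chunk : chunks) out += chunk;
  return trimWhitespace(out);
}
```

For the chunks " The", " sky", " is", " blue" and ".\n", per-chunk trimming produces `Theskyisblue.`, while trimming once produces `The sky is blue.`.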
@@ -427,7 +463,7 @@ std::string HybridLiteRTLM::sendMessageInternal(const std::string& message) {
 
  const char* responseStr = litert_lm_json_response_get_string(response);
  if (responseStr) {
- result = extractTextFromResponse(std::string(responseStr));
+ result = trimWhitespace(extractTextFromResponse(std::string(responseStr)));
  }
  litert_lm_json_response_delete(response);
 
@@ -485,6 +521,26 @@ void HybridLiteRTLM::streamCallbackFn(void* callback_data, const char* chunk,
  ctx->lastStats->tokensPerSecond = (ctx->tokenCount / durationMs) * 1000.0;
  }
 
+ // Final flush: do one last clean of the full accumulated response
+ // to emit any text that was withheld by safeEmitLength.
+ std::string cleaned = stripControlTokens(ctx->rawResponse);
+ size_t start = cleaned.find_first_not_of(" \t\n\r");
+ if (start != std::string::npos) {
+ cleaned = cleaned.substr(start);
+ // Strip echoed user message
+ if (!ctx->userMessage.empty() && cleaned.find(ctx->userMessage) == 0) {
+ cleaned = cleaned.substr(ctx->userMessage.length());
+ size_t nextStart = cleaned.find_first_not_of(" \t\n\r");
+ cleaned = (nextStart != std::string::npos) ? cleaned.substr(nextStart) : "";
+ }
+ // Emit any remaining text not yet sent
+ if (cleaned.length() > ctx->lastEmittedLength) {
+ std::string remaining = cleaned.substr(ctx->lastEmittedLength);
+ ctx->onToken(remaining, false);
+ }
+ ctx->fullResponse = cleaned;
+ }
+
  // Update history (thread-safe)
  {
  std::lock_guard<std::mutex> lock(*ctx->historyMutex);
@@ -499,12 +555,55 @@ void HybridLiteRTLM::streamCallbackFn(void* callback_data, const char* chunk,
499
555
 
500
556
  if (chunk) {
501
557
  std::string token(chunk);
502
- // Filter out Gemma control tokens from streamed chunks
503
- std::string cleaned = stripControlTokens(token);
504
- ctx->fullResponse += cleaned;
505
- ctx->tokenCount++;
506
- if (!cleaned.empty()) {
507
- ctx->onToken(cleaned, false);
558
+
559
+ // The C API may return JSON-wrapped responses (e.g.
560
+ // {"role":"model","content":[{"type":"text","text":"Hi"}]})
561
+ // instead of raw text tokens. Detect and extract text content.
562
+ std::string raw;
563
+ if (token.size() > 2 && token[0] == '{' && token.find("\"role\"") != std::string::npos) {
564
+ raw = HybridLiteRTLM::extractTextFromResponse(token);
565
+ } else {
566
+ raw = token;
567
+ }
568
+
569
+ // Accumulate raw text, then strip control tokens from the FULL buffer.
570
+ // This correctly handles tokens split across chunk boundaries (e.g.
571
+ // chunk1="<end_of_tu" chunk2="rn>Hello").
572
+ ctx->rawResponse += raw;
573
+ std::string cleaned = stripControlTokens(ctx->rawResponse);
574
+
575
+ // Trim leading whitespace from the overall response
576
+ size_t start = cleaned.find_first_not_of(" \t\n\r");
577
+ if (start == std::string::npos) {
578
+ // Still only whitespace/control tokens — nothing to emit yet
579
+ return;
580
+ }
581
+ cleaned = cleaned.substr(start);
582
+
583
+ // The C API may echo back the user's message before the model response.
584
+ // Strip the echoed user message prefix if present.
585
+ if (!ctx->userMessage.empty()) {
586
+ size_t userPos = cleaned.find(ctx->userMessage);
587
+ if (userPos == 0) {
588
+ cleaned = cleaned.substr(ctx->userMessage.length());
589
+ // Trim any whitespace after the stripped user message
590
+ size_t nextStart = cleaned.find_first_not_of(" \t\n\r");
591
+ if (nextStart == std::string::npos) {
592
+ return; // Only user message so far, nothing to emit
593
+ }
594
+ cleaned = cleaned.substr(nextStart);
595
+ }
596
+ }
597
+
598
+ // Only emit text that is "safe" — withhold any trailing characters
599
+ // that could be the start of a control token split across chunks.
600
+ size_t safe = safeEmitLength(cleaned);
601
+ if (safe > ctx->lastEmittedLength) {
602
+ std::string newText = cleaned.substr(ctx->lastEmittedLength, safe - ctx->lastEmittedLength);
603
+ ctx->fullResponse = cleaned.substr(0, safe);
604
+ ctx->lastEmittedLength = safe;
605
+ ctx->tokenCount++;
606
+ ctx->onToken(newText, false);
508
607
  }
509
608
  }
510
609
  }
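The streaming callback above now does four things per chunk. It accumulates the raw chunks, strips control tokens from the whole buffer, drops an echoed copy of the prompt, and withholds any trailing characters that could be the start of a control token. The package's own `stripControlTokens()` and `safeEmitLength()` are not part of this diff, so what follows is only a minimal, self-contained C++ sketch of that flow under stated assumptions. The control-token set (`<start_of_turn>`, `<end_of_turn>`, `<eos>`) and every `...Sketch` name are hypothetical, and the JSON-unwrapping step is omitted.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical token set for illustration only; the package's real
// stripControlTokens()/safeEmitLength() are not shown in this diff.
static const std::vector<std::string> kControlTokens = {
    "<start_of_turn>", "<end_of_turn>", "<eos>"};

// Remove every complete occurrence of each control token.
std::string stripControlTokensSketch(std::string s) {
  for (const auto& tok : kControlTokens) {
    for (size_t pos = s.find(tok); pos != std::string::npos; pos = s.find(tok)) {
      s.erase(pos, tok.size());
    }
  }
  return s;
}

// Length of the prefix of `s` that is safe to emit: withhold the longest
// suffix that is a proper prefix of some control token (e.g. "<end_of_tu").
size_t safeEmitLengthSketch(const std::string& s) {
  size_t withhold = 0;
  for (const auto& tok : kControlTokens) {
    size_t maxLen = std::min(tok.size() - 1, s.size());
    for (size_t len = maxLen; len > withhold; --len) {
      if (s.compare(s.size() - len, len, tok, 0, len) == 0) {
        withhold = len;
        break;
      }
    }
  }
  return s.size() - withhold;
}

struct StreamStateSketch {
  std::string userMessage;       // prompt that the backend may echo back
  std::string rawResponse;       // every chunk, unmodified
  size_t lastEmittedLength = 0;  // chars of cleaned text already forwarded
};

// Roughly mirrors the callback: accumulate, strip the whole buffer, trim,
// drop an echoed prompt, then return only the new, safe portion.
std::string onChunkSketch(StreamStateSketch& st, const std::string& chunk) {
  st.rawResponse += chunk;
  std::string cleaned = stripControlTokensSketch(st.rawResponse);
  size_t start = cleaned.find_first_not_of(" \t\n\r");
  if (start == std::string::npos) return "";  // only whitespace so far
  cleaned = cleaned.substr(start);

  if (!st.userMessage.empty() &&
      cleaned.compare(0, st.userMessage.size(), st.userMessage) == 0) {
    cleaned = cleaned.substr(st.userMessage.size());
    size_t next = cleaned.find_first_not_of(" \t\n\r");
    if (next == std::string::npos) return "";  // only the echoed prompt
    cleaned = cleaned.substr(next);
  }

  size_t safe = safeEmitLengthSketch(cleaned);
  if (safe <= st.lastEmittedLength) return "";
  std::string newText =
      cleaned.substr(st.lastEmittedLength, safe - st.lastEmittedLength);
  st.lastEmittedLength = safe;
  return newText;
}
```

Feeding the chunks `"<end_of_tu"`, `"rn>Hello"`, `" world<eo"` and `"s>"` through `onChunkSketch` forwards `"Hello"` and then `" world"`; neither split marker reaches the caller. Text still withheld when generation ends (for example a lone trailing `<`) needs a final flush, like the `remaining` emission at the top of this section.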
@@ -520,7 +619,9 @@ void HybridLiteRTLM::sendMessageAsync(
520
619
  // Capture shared state safely — use unique_ptr to prevent leaks
521
620
  auto ctxOwner = std::make_unique<StreamContext>();
522
621
  ctxOwner->onToken = std::move(onTokenCopy);
622
+ ctxOwner->rawResponse = "";
523
623
  ctxOwner->fullResponse = "";
624
+ ctxOwner->lastEmittedLength = 0;
524
625
  ctxOwner->history = &history_;
525
626
  ctxOwner->historyMutex = &mutex_;
526
627
  ctxOwner->userMessage = messageCopy;
@@ -602,7 +703,7 @@ std::string HybridLiteRTLM::sendMessageWithImageInternal(
602
703
 
603
704
  const char* responseStr = litert_lm_json_response_get_string(response);
604
705
  if (responseStr) {
605
- result = extractTextFromResponse(std::string(responseStr));
706
+ result = trimWhitespace(extractTextFromResponse(std::string(responseStr)));
606
707
  }
607
708
  litert_lm_json_response_delete(response);
608
709
  #else
@@ -662,7 +763,7 @@ std::string HybridLiteRTLM::sendMessageWithAudioInternal(
662
763
 
663
764
  const char* responseStr = litert_lm_json_response_get_string(response);
664
765
  if (responseStr) {
665
- result = extractTextFromResponse(std::string(responseStr));
766
+ result = trimWhitespace(extractTextFromResponse(std::string(responseStr)));
666
767
  }
667
768
  litert_lm_json_response_delete(response);
668
769
  #else
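The image and audio paths now pass the extracted text through `trimWhitespace()`. Neither helper's implementation appears in this diff. As a hedged illustration only, the sketch below assumes the response shape quoted in the streaming callback's comment (`{"role":"model","content":[{"type":"text","text":"Hi"}]}`) and concatenates every `"text"` string value. The function names are hypothetical, the pass-through for input without a `"text"` field is an assumption, and the hand-rolled scanner decodes only simple escapes; a production version should use a real JSON parser.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: trim spaces, tabs, CR and LF from both ends.
std::string trimWhitespaceSketch(const std::string& s) {
  const char* ws = " \t\n\r";
  size_t start = s.find_first_not_of(ws);
  if (start == std::string::npos) return "";
  size_t end = s.find_last_not_of(ws);
  return s.substr(start, end - start + 1);
}

// Hypothetical, naive extractor: concatenates every "text":"..." value.
// Decodes backslash-n and backslash-t; any other backslash escape keeps
// the following character as-is, so unicode escapes are NOT decoded.
std::string extractTextSketch(const std::string& json) {
  static const std::string key = "\"text\":\"";
  std::string out;
  bool found = false;
  size_t pos = 0;
  while ((pos = json.find(key, pos)) != std::string::npos) {
    found = true;
    pos += key.size();
    while (pos < json.size() && json[pos] != '"') {
      if (json[pos] == '\\' && pos + 1 < json.size()) {
        char next = json[++pos];
        if (next == 'n') out += '\n';
        else if (next == 't') out += '\t';
        else out += next;  // escaped quote, backslash, slash: keep the char
      } else {
        out += json[pos];
      }
      ++pos;
    }
  }
  return found ? out : json;  // assumption: non-JSON input passes through
}
```

Applied to the quoted example response, `extractTextSketch` returns `Hi`.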
@@ -149,7 +149,9 @@ private:
149
149
  // Streaming callback context (must be a plain struct for C function pointer)
150
150
  struct StreamContext {
151
151
  std::function<void(const std::string&, bool)> onToken;
152
- std::string fullResponse;
152
+ std::string rawResponse; // Raw accumulated chunks (before stripping)
153
+ std::string fullResponse; // Clean accumulated text (after stripping)
154
+ size_t lastEmittedLength; // Length of fullResponse already emitted to JS
153
155
  std::vector<Message>* history;
154
156
  std::mutex* historyMutex;
155
157
  std::string userMessage;
@@ -1,34 +1,32 @@
1
1
  # LiteRT-LM Headers Fallback
2
2
 
3
- This directory is a fallback location for LiteRT-LM C++ headers when Prefab doesn't expose them correctly.
3
+ This directory contains the LiteRT-LM C API header (`litert_lm_engine.h`) used by the iOS C++ implementation.
4
4
 
5
5
  ## If Headers Are Missing
6
6
 
7
- If you get compilation errors like `litert/lm/engine.h: No such file or directory`, you need to manually copy LiteRT-LM headers here:
7
+ If you get compilation errors like `litert_lm_engine.h: No such file or directory`, you need to manually copy the LiteRT-LM C API header here:
8
8
 
9
9
  1. Clone LiteRT-LM repository:
10
10
 
11
11
  ```bash
12
12
  git clone https://github.com/google-ai-edge/LiteRT-LM.git /tmp/LiteRT-LM
13
+ cd /tmp/LiteRT-LM && git checkout v0.10.2
13
14
  ```
14
15
 
15
- 2. Copy the headers:
16
+ 2. Copy the header:
16
17
  ```bash
17
- cp -r /tmp/LiteRT-LM/runtime/include/litert ./
18
+ cp /tmp/LiteRT-LM/c/litert_lm_engine.h ./
18
19
  ```
19
20
 
20
21
  The expected directory structure after copying:
21
22
 
22
23
  ```
23
24
  cpp/include/
24
- └── litert/
25
- └── lm/
26
- ├── engine.h
27
- ├── conversation.h
28
- ├── types.h
29
- └── ...
25
+ ├── litert_lm_engine.h # LiteRT-LM C API header
26
+ ├── stb_image.h # Image loading for multimodal
27
+ └── README.md
30
28
  ```
31
29
 
32
30
  ## Note
33
31
 
34
- The ideal scenario is that the Maven package `litertlm-android:0.9.0-alpha01` exposes headers via Prefab, making this directory unnecessary. This is just a fallback.
32
+ On **Android**, headers are provided by the `litertlm-android` AAR via Prefab; this directory is only needed for the **iOS** build, which uses the raw C API via the prebuilt XCFramework.
package/lib/index.d.ts CHANGED
@@ -110,8 +110,8 @@ export declare function checkBackendSupport(backend: Backend): string | undefine
110
110
  */
111
111
  export declare function checkMultimodalSupport(): string | undefined;
112
112
  /**
113
- * Download URL for the Gemma 3n E2B IT INT4 model.
114
- * Note: Requires a HuggingFace account (gated model).
113
+ * Download URL for the Gemma 3n E2B IT INT4 model (~1.3 GB).
114
+ * Publicly hosted on litert.dev; no authentication required.
115
115
  */
116
116
  export declare const GEMMA_3N_E2B_IT_INT4 = "https://litert.dev/gemma-3n-E2B-it-int4.litertlm";
117
117
  /**
package/lib/index.js CHANGED
@@ -116,6 +116,15 @@ function getRecommendedBackend() {
116
116
  * ```
117
117
  */
118
118
  function checkBackendSupport(backend) {
119
+ if (backend === "gpu") {
120
+ if (react_native_1.Platform.OS === "android") {
121
+ // LiteRT-LM GPU delegate requires OpenCL, which is unavailable
122
+ // on most Samsung/Qualcomm devices. Only Pixel devices reliably expose it.
123
+ return "GPU backend requires OpenCL support, which is unavailable on most Samsung and Qualcomm devices.";
124
+ }
125
+ // iOS always supports GPU via Metal
126
+ return undefined;
127
+ }
119
128
  if (backend === "npu") {
120
129
  if (react_native_1.Platform.OS === "android") {
121
130
  return "NPU backend requires compatible hardware (Qualcomm Hexagon, MediaTek APU, etc.). Will fall back to GPU if unavailable.";
@@ -150,8 +159,8 @@ function checkMultimodalSupport() {
150
159
  return undefined;
151
160
  }
152
161
  /**
153
- * Download URL for the Gemma 3n E2B IT INT4 model.
154
- * Note: Requires a HuggingFace account (gated model).
162
+ * Download URL for the Gemma 3n E2B IT INT4 model (~1.3 GB).
163
+ * Publicly hosted on litert.dev; no authentication required.
155
164
  */
156
165
  exports.GEMMA_3N_E2B_IT_INT4 = "https://litert.dev/gemma-3n-E2B-it-int4.litertlm";
157
166
  /**
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "react-native-litert-lm",
3
- "version": "0.3.6",
3
+ "version": "0.3.7",
4
4
  "litertLm": {
5
5
  "version": "0.10.2",
6
6
  "androidMavenVersion": "0.10.2",
package/src/index.ts CHANGED
@@ -132,6 +132,16 @@ export function getRecommendedBackend(): Backend {
132
132
  * ```
133
133
  */
134
134
  export function checkBackendSupport(backend: Backend): string | undefined {
135
+ if (backend === "gpu") {
136
+ if (Platform.OS === "android") {
137
+ // LiteRT-LM GPU delegate requires OpenCL, which is unavailable
138
+ // on most Samsung/Qualcomm devices. Only Pixel devices reliably expose it.
139
+ return "GPU backend requires OpenCL support, which is unavailable on most Samsung and Qualcomm devices.";
140
+ }
141
+ // iOS always supports GPU via Metal
142
+ return undefined;
143
+ }
144
+
135
145
  if (backend === "npu") {
136
146
  if (Platform.OS === "android") {
137
147
  return "NPU backend requires compatible hardware (Qualcomm Hexagon, MediaTek APU, etc.). Will fall back to GPU if unavailable.";
@@ -169,8 +179,8 @@ export function checkMultimodalSupport(): string | undefined {
169
179
  }
170
180
 
171
181
  /**
172
- * Download URL for the Gemma 3n E2B IT INT4 model.
173
- * Note: Requires a HuggingFace account (gated model).
182
+ * Download URL for the Gemma 3n E2B IT INT4 model (~1.3 GB).
183
+ * Publicly hosted on litert.dev; no authentication required.
174
184
  */
175
185
  export const GEMMA_3N_E2B_IT_INT4 =
176
186
  "https://litert.dev/gemma-3n-E2B-it-int4.litertlm";