react-native-ai-core 0.1.0 → 0.2.0

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
1
1
  package com.aicore
2
2
 
3
+ import android.content.Intent
3
4
  import android.os.Build
4
5
  import com.facebook.react.bridge.Arguments
5
6
  import com.facebook.react.bridge.Promise
@@ -11,6 +12,7 @@ import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
11
12
  import com.google.mediapipe.tasks.genai.llminference.ProgressListener
12
13
  import com.google.mlkit.genai.common.DownloadStatus
13
14
  import com.google.mlkit.genai.common.FeatureStatus
15
+ import com.google.mlkit.genai.common.GenAiException
14
16
  import com.google.mlkit.genai.prompt.Generation
15
17
  import com.google.mlkit.genai.prompt.GenerativeModel
16
18
  import com.google.mlkit.genai.prompt.TextPart
@@ -19,9 +21,11 @@ import kotlinx.coroutines.CoroutineScope
19
21
  import kotlinx.coroutines.Dispatchers
20
22
  import kotlinx.coroutines.SupervisorJob
21
23
  import kotlinx.coroutines.cancel
24
+ import kotlinx.coroutines.delay
22
25
  import kotlinx.coroutines.flow.catch
23
26
  import kotlinx.coroutines.launch
24
27
  import java.io.File
28
+ import java.util.concurrent.CountDownLatch
25
29
  import java.util.concurrent.ExecutorService
26
30
  import java.util.concurrent.Executors
27
31
 
@@ -34,7 +38,6 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
34
38
  private val executor: ExecutorService = Executors.newSingleThreadExecutor()
35
39
  private val coroutineScope = CoroutineScope(Dispatchers.IO + SupervisorJob())
36
40
 
37
- // Historial de conversación: lista de pares (mensaje usuario, respuesta asistente)
38
41
  private val conversationHistory = mutableListOf<Pair<String, String>>()
39
42
 
40
43
  companion object {
@@ -43,34 +46,38 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
43
46
  const val EVENT_STREAM_COMPLETE = "AICore_streamComplete"
44
47
  const val EVENT_STREAM_ERROR = "AICore_streamError"
45
48
  private const val DEFAULT_TEMPERATURE = 0.7f
46
- private const val DEFAULT_MAX_TOKENS = 2048 // ventana MediaPipe (entrada+salida)
47
- private const val DEFAULT_MAX_OUTPUT_TOKENS = 256 // salida ML Kit/AICore (límite API: 1–256)
49
+ private const val DEFAULT_MAX_TOKENS = 4096 // MediaPipe context window (input+output)
50
+ private const val REQUESTED_MAX_OUTPUT_TOKENS = 256
51
+ private const val FALLBACK_MAX_OUTPUT_TOKENS = 256
48
52
  private const val DEFAULT_TOP_K = 40
49
- // Límite de caracteres del historial para no superar ~3000 tokens de entrada
53
+ private const val PROMPT_CHAR_BUDGET = 4000
50
54
  private const val HISTORY_MAX_CHARS = 9000
55
+ private const val MAX_CONTINUATIONS = 12
56
+ private const val MAX_STREAM_IDLE_RETRIES = 3
57
+ private const val QUOTA_ERROR_CODE = 9 // AICore NPU quota exceeded error code
58
+ private const val CONTINUATION_DELAY_MS = 1200L
59
+ private const val QUOTA_RETRY_DELAY_MS = 1800L
60
+ private const val MAX_QUOTA_RETRIES = 2
61
+ private const val MAX_NON_STREAM_QUOTA_RETRIES = 6
62
+ private const val BACKGROUND_RETRY_DELAY_MS = 2000L
63
+ private const val MAX_BACKGROUND_RETRIES = 1
64
+ private const val CONTINUATION_PROMPT = "Continue exactly from the last generated character. Do not repeat, restart, summarize, explain, or add headings. Output only the direct continuation."
65
+ private const val END_MARKER = "-E-"
66
+ private const val END_MARKER_INSTRUCTION = "[INTERNAL] Append $END_MARKER only once at the true end of the final answer. Never use it before the end."
51
67
  private val STANDARD_MODEL_PATHS = listOf(
52
68
  "/data/local/tmp/gemini-nano.bin",
53
69
  "/sdcard/Download/gemini-nano.bin"
54
70
  )
55
71
  }
56
72
 
57
- // ── Historial ──────────────────────────────────────────────────────────────
58
-
59
73
  @Synchronized
60
74
  private fun buildContextualPrompt(userPrompt: String): String {
61
- if (conversationHistory.isEmpty()) return userPrompt
62
- val sb = StringBuilder()
63
- for ((u, a) in conversationHistory) {
64
- sb.append("User: ").append(u).append("\nAssistant: ").append(a).append("\n")
65
- }
66
- sb.append("User: ").append(userPrompt).append("\nAssistant:")
67
- return sb.toString()
75
+ return buildPromptWithBudget(userPrompt, null, END_MARKER_INSTRUCTION)
68
76
  }
69
77
 
70
78
  @Synchronized
71
79
  private fun saveToHistory(userPrompt: String, assistantResponse: String) {
72
80
  conversationHistory.add(Pair(userPrompt, assistantResponse))
73
- // Eliminar turnos más antiguos si el historial supera el límite de caracteres
74
81
  var total = conversationHistory.sumOf { it.first.length + it.second.length }
75
82
  while (total > HISTORY_MAX_CHARS && conversationHistory.size > 1) {
76
83
  val removed = conversationHistory.removeAt(0)
@@ -83,12 +90,250 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
83
90
  conversationHistory.clear()
84
91
  }
85
92
 
93
+ private fun trimFromStart(text: String, maxChars: Int): String {
94
+ if (text.length <= maxChars) return text
95
+ return text.takeLast(maxChars)
96
+ }
97
+
98
+ private fun trimFromEnd(text: String, maxChars: Int): String {
99
+ if (text.length <= maxChars) return text
100
+ return text.take(maxChars)
101
+ }
102
+
103
+ @Synchronized
104
+ private fun buildPromptWithBudget(
105
+ userPrompt: String,
106
+ assistantPrefix: String?,
107
+ hiddenUserPrompt: String?
108
+ ): String {
109
+ val hiddenInstruction = hiddenUserPrompt?.let { "\n$it\nAssistant:" } ?: ""
110
+ val assistantBase = "\nAssistant:"
111
+ val normalizedUserPrompt = trimFromEnd(userPrompt, PROMPT_CHAR_BUDGET)
112
+ val historySnapshot = conversationHistory.toMutableList()
113
+
114
+ while (true) {
115
+ val sb = StringBuilder()
116
+ for ((u, a) in historySnapshot) {
117
+ sb.append("User: ").append(u).append("\nAssistant: ").append(a).append("\n")
118
+ }
119
+ sb.append("User: ").append(normalizedUserPrompt).append(assistantBase)
120
+ if (assistantPrefix != null) {
121
+ sb.append(' ').append(assistantPrefix)
122
+ }
123
+ sb.append(hiddenInstruction)
124
+
125
+ val candidate = sb.toString()
126
+ if (candidate.length < PROMPT_CHAR_BUDGET || historySnapshot.isEmpty()) {
127
+ if (candidate.length <= PROMPT_CHAR_BUDGET) return candidate
128
+
129
+ val fixedPrefix = "User: $normalizedUserPrompt$assistantBase"
130
+ val availableForAssistant = (PROMPT_CHAR_BUDGET - fixedPrefix.length - hiddenInstruction.length - 1)
131
+ .coerceAtLeast(0)
132
+ val trimmedAssistantPrefix = assistantPrefix?.let { trimFromStart(it, availableForAssistant) }
133
+
134
+ return buildString {
135
+ append(fixedPrefix)
136
+ if (trimmedAssistantPrefix != null && trimmedAssistantPrefix.isNotEmpty()) {
137
+ append(' ').append(trimmedAssistantPrefix)
138
+ }
139
+ append(hiddenInstruction)
140
+ }
141
+ }
142
+
143
+ historySnapshot.removeAt(0)
144
+ }
145
+ }
146
+
147
+ private fun shouldContinueResponse(text: String): Boolean {
148
+ if (text.isBlank()) return false
149
+ val trimmed = text.trimEnd()
150
+ return !(trimmed.endsWith('.') || trimmed.endsWith('!') || trimmed.endsWith('?'))
151
+ }
152
+
153
+ private fun buildContinuationPrompt(originalUserPrompt: String, partialResponse: String): String {
154
+ return buildPromptWithBudget(
155
+ originalUserPrompt,
156
+ partialResponse,
157
+ "$CONTINUATION_PROMPT\n$END_MARKER_INSTRUCTION"
158
+ )
159
+ }
160
+
161
+ private fun containsEndMarker(text: String): Boolean {
162
+ return text.contains(END_MARKER)
163
+ }
164
+
165
+ private fun stripEndMarker(text: String): String {
166
+ return text.replace(END_MARKER, "")
167
+ }
168
+
169
+ private fun sanitizeVisibleText(text: String): String {
170
+ var cleaned = text
171
+ cleaned = cleaned.replace(Regex("(?i)\\[\\s*internal\\s*\\][^\\n\\r]*"), "")
172
+ cleaned = cleaned.replace(Regex("(?i)append\\s+\\Q$END_MARKER\\E[^\\n\\r]*"), "")
173
+ cleaned = cleaned.replace(Regex("(?i)never use it before the end\\.?"), "")
174
+ return cleaned
175
+ }
176
+
177
+ private fun adjustChunkBoundary(existing: String, incoming: String): String {
178
+ if (incoming.isEmpty() || existing.isEmpty()) return incoming
179
+
180
+ var chunk = incoming
181
+ val last = existing.last()
182
+ val first = chunk.first()
183
+
184
+ if (last.isLetterOrDigit() && first.isLetterOrDigit()) {
185
+ chunk = " $chunk"
186
+ }
187
+
188
+ if (existing.last() == ' ' && chunk.firstOrNull() == ' ') {
189
+ chunk = chunk.trimStart(' ')
190
+ }
191
+
192
+ return chunk
193
+ }
194
+
195
+ private fun emitStreamToken(token: String, done: Boolean) {
196
+ sendEvent(EVENT_STREAM_TOKEN, Arguments.createMap().apply {
197
+ putString("token", token)
198
+ putBoolean("done", done)
199
+ })
200
+ }
201
+
202
+ private fun isOutOfRangeError(error: Throwable): Boolean {
203
+ return error.message?.contains("out of range", ignoreCase = true) == true
204
+ }
205
+
206
+ private fun isQuotaError(error: Throwable): Boolean {
207
+ return error is GenAiException && error.errorCode == QUOTA_ERROR_CODE
208
+ }
209
+
210
+ private fun isBackgroundError(error: Throwable): Boolean {
211
+ val msg = error.message?.lowercase() ?: ""
212
+ return msg.contains("background usage is blocked") ||
213
+ msg.contains("use the api when your app is in the foreground")
214
+ }
215
+
216
+ private suspend fun generateMlKitChunk(model: GenerativeModel, prompt: String): String {
217
+ var quotaRetries = 0
218
+ var backgroundRetries = 0
219
+ while (true) {
220
+ try {
221
+ val request = generateContentRequest(TextPart(prompt)) {
222
+ maxOutputTokens = REQUESTED_MAX_OUTPUT_TOKENS
223
+ }
224
+ return model.generateContent(request).candidates.firstOrNull()?.text ?: ""
225
+ } catch (error: Exception) {
226
+ if (isQuotaError(error) && quotaRetries < MAX_QUOTA_RETRIES) {
227
+ quotaRetries++
228
+ delay(QUOTA_RETRY_DELAY_MS)
229
+ continue
230
+ }
231
+ if (isBackgroundError(error) && backgroundRetries < MAX_BACKGROUND_RETRIES) {
232
+ backgroundRetries++
233
+ delay(BACKGROUND_RETRY_DELAY_MS)
234
+ continue
235
+ }
236
+ if (!isOutOfRangeError(error)) throw error
237
+ while (true) {
238
+ try {
239
+ val fallbackRequest = generateContentRequest(TextPart(prompt)) {
240
+ maxOutputTokens = FALLBACK_MAX_OUTPUT_TOKENS
241
+ }
242
+ return model.generateContent(fallbackRequest).candidates.firstOrNull()?.text ?: ""
243
+ } catch (fallbackError: Exception) {
244
+ if (isQuotaError(fallbackError) && quotaRetries < MAX_QUOTA_RETRIES) {
245
+ quotaRetries++
246
+ delay(QUOTA_RETRY_DELAY_MS)
247
+ continue
248
+ }
249
+ throw fallbackError
250
+ }
251
+ }
252
+ }
253
+ }
254
+ }
255
+
256
+ private suspend fun streamMlKitChunk(
257
+ model: GenerativeModel,
258
+ prompt: String,
259
+ onToken: (String) -> Unit
260
+ ): Boolean {
261
+ suspend fun collectWithLimit(limit: Int): Boolean {
262
+ var quotaRetries = 0
263
+ var backgroundRetries = 0
264
+ while (true) {
265
+ var quotaHit = false
266
+ var backgroundHit = false
267
+ val request = generateContentRequest(TextPart(prompt)) {
268
+ maxOutputTokens = limit
269
+ }
270
+ try {
271
+ model.generateContentStream(request)
272
+ .catch { error ->
273
+ if (isQuotaError(error)) {
274
+ quotaHit = true
275
+ } else if (isBackgroundError(error)) {
276
+ backgroundHit = true
277
+ } else {
278
+ throw error
279
+ }
280
+ }
281
+ .collect { chunk ->
282
+ val token = chunk.candidates.firstOrNull()?.text ?: ""
283
+ onToken(token)
284
+ }
285
+ } catch (error: Exception) {
286
+ if (isQuotaError(error)) {
287
+ quotaHit = true
288
+ } else if (isBackgroundError(error)) {
289
+ backgroundHit = true
290
+ } else {
291
+ throw error
292
+ }
293
+ }
294
+ if (!quotaHit && !backgroundHit) return false
295
+ if (quotaHit) {
296
+ if (quotaRetries >= MAX_QUOTA_RETRIES) return true
297
+ quotaRetries++
298
+ delay(QUOTA_RETRY_DELAY_MS)
299
+ } else {
300
+ if (backgroundRetries >= MAX_BACKGROUND_RETRIES) return false
301
+ backgroundRetries++
302
+ delay(BACKGROUND_RETRY_DELAY_MS)
303
+ }
304
+ }
305
+ }
306
+
307
+ return try {
308
+ collectWithLimit(REQUESTED_MAX_OUTPUT_TOKENS)
309
+ } catch (error: Exception) {
310
+ if (!isOutOfRangeError(error)) throw error
311
+ collectWithLimit(FALLBACK_MAX_OUTPUT_TOKENS)
312
+ }
313
+ }
314
+
86
315
  private fun sendEvent(name: String, params: WritableMap?) {
87
316
  reactApplicationContext
88
317
  .getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter::class.java)
89
318
  .emit(name, params)
90
319
  }
91
320
 
321
+ private fun startInferenceService() {
322
+ val intent = Intent(reactApplicationContext, InferenceService::class.java)
323
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
324
+ reactApplicationContext.startForegroundService(intent)
325
+ } else {
326
+ reactApplicationContext.startService(intent)
327
+ }
328
+ }
329
+
330
+ private fun stopInferenceService() {
331
+ val intent = Intent(reactApplicationContext, InferenceService::class.java).apply {
332
+ action = InferenceService.ACTION_STOP
333
+ }
334
+ reactApplicationContext.startService(intent)
335
+ }
336
+
92
337
  private fun createErrorMap(code: String, message: String): WritableMap =
93
338
  Arguments.createMap().apply {
94
339
  putString("code", code)
@@ -96,7 +341,7 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
96
341
  }
97
342
 
98
343
  private fun createMediaPipeSession(): LlmInferenceSession {
99
- val inference = llmInference ?: throw IllegalStateException("LLM no inicializado.")
344
+ val inference = llmInference ?: throw IllegalStateException("LLM not initialized.")
100
345
  val opts = LlmInferenceSession.LlmInferenceSessionOptions.builder()
101
346
  .setTemperature(DEFAULT_TEMPERATURE)
102
347
  .setTopK(DEFAULT_TOP_K)
@@ -104,6 +349,20 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
104
349
  return LlmInferenceSession.createFromOptions(inference, opts)
105
350
  }
106
351
 
352
+ private fun buildPrompt(userPrompt: String, useConversationHistory: Boolean): String {
353
+ return if (useConversationHistory) buildContextualPrompt(userPrompt) else userPrompt
354
+ }
355
+
356
+ private fun maybeSaveToHistory(
357
+ userPrompt: String,
358
+ assistantResponse: String,
359
+ useConversationHistory: Boolean
360
+ ) {
361
+ if (useConversationHistory) {
362
+ saveToHistory(userPrompt, assistantResponse)
363
+ }
364
+ }
365
+
107
366
  override fun initialize(modelPath: String, promise: Promise) {
108
367
  mlkitModel = null
109
368
  llmInference?.close()
@@ -131,9 +390,9 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
131
390
  }
132
391
  }
133
392
  }
134
- FeatureStatus.DOWNLOADING -> promise.reject("ALREADY_DOWNLOADING", "Ya se está descargando.")
135
- FeatureStatus.UNAVAILABLE -> promise.reject("AICORE_UNAVAILABLE", "Gemini Nano no disponible en este dispositivo.")
136
- else -> promise.reject("AICORE_UNKNOWN", "Estado desconocido.")
393
+ FeatureStatus.DOWNLOADING -> promise.reject("ALREADY_DOWNLOADING", "Model download already in progress.")
394
+ FeatureStatus.UNAVAILABLE -> promise.reject("AICORE_UNAVAILABLE", "Gemini Nano is not available on this device.")
395
+ else -> promise.reject("AICORE_UNKNOWN", "Unknown AICore status.")
137
396
  }
138
397
  } catch (e: Exception) {
139
398
  promise.reject("AICORE_ERROR", e.message, e)
@@ -143,7 +402,7 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
143
402
  executor.execute {
144
403
  try {
145
404
  if (!File(modelPath).exists()) {
146
- promise.reject("MODEL_NOT_FOUND", "Modelo no encontrado en: $modelPath")
405
+ promise.reject("MODEL_NOT_FOUND", "Model file not found at: $modelPath")
147
406
  return@execute
148
407
  }
149
408
  val options = LlmInference.LlmInferenceOptions.builder()
@@ -165,105 +424,287 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
165
424
  }
166
425
 
167
426
  override fun generateResponse(prompt: String, promise: Promise) {
427
+ generateResponseInternal(prompt, true, promise)
428
+ }
429
+
430
+ override fun generateResponseStateless(prompt: String, promise: Promise) {
431
+ generateResponseInternal(prompt, false, promise)
432
+ }
433
+
434
+ private fun generateResponseInternal(
435
+ prompt: String,
436
+ useConversationHistory: Boolean,
437
+ promise: Promise
438
+ ) {
168
439
  val mlkit = mlkitModel
169
440
  val mediapipe = llmInference
170
- val contextualPrompt = buildContextualPrompt(prompt)
441
+ startInferenceService()
171
442
  when {
172
443
  mlkit != null -> coroutineScope.launch {
173
444
  try {
174
- val request = generateContentRequest(TextPart(contextualPrompt)) {
175
- maxOutputTokens = DEFAULT_MAX_OUTPUT_TOKENS
445
+ val rawTotal = StringBuilder()
446
+ val visibleTotal = StringBuilder()
447
+ var currentPrompt = buildPrompt(prompt, useConversationHistory)
448
+ var continuations = 0
449
+ var continuationJoinPending = false
450
+ var quotaRetries = 0
451
+ while (true) {
452
+ val part = try {
453
+ generateMlKitChunk(mlkit, currentPrompt)
454
+ } catch (e: GenAiException) {
455
+ if (e.errorCode == QUOTA_ERROR_CODE) {
456
+ if (quotaRetries < MAX_NON_STREAM_QUOTA_RETRIES) {
457
+ quotaRetries++
458
+ delay(QUOTA_RETRY_DELAY_MS)
459
+ continue
460
+ }
461
+ if (rawTotal.isNotEmpty()) break
462
+ }
463
+ throw e
464
+ }
465
+ quotaRetries = 0
466
+ rawTotal.append(part)
467
+ if (containsEndMarker(rawTotal.toString())) break
468
+
469
+ val cleanPart = sanitizeVisibleText(stripEndMarker(part))
470
+ val partForUi = if (continuationJoinPending) {
471
+ continuationJoinPending = false
472
+ adjustChunkBoundary(visibleTotal.toString(), cleanPart)
473
+ } else {
474
+ cleanPart
475
+ }
476
+ visibleTotal.append(partForUi)
477
+ val visible = visibleTotal.toString()
478
+ if (
479
+ useConversationHistory &&
480
+ shouldContinueResponse(visible) &&
481
+ continuations < MAX_CONTINUATIONS
482
+ ) {
483
+ currentPrompt = buildContinuationPrompt(prompt, visible)
484
+ continuations++
485
+ continuationJoinPending = true
486
+ delay(CONTINUATION_DELAY_MS)
487
+ } else break
488
+ }
489
+ val full = if (visibleTotal.isNotEmpty()) {
490
+ sanitizeVisibleText(visibleTotal.toString())
491
+ } else {
492
+ sanitizeVisibleText(stripEndMarker(rawTotal.toString()))
176
493
  }
177
- val response = mlkit.generateContent(request).candidates.firstOrNull()?.text ?: ""
178
- saveToHistory(prompt, response)
179
- promise.resolve(response)
494
+ maybeSaveToHistory(prompt, full, useConversationHistory)
495
+ stopInferenceService()
496
+ promise.resolve(full)
180
497
  } catch (e: Exception) {
498
+ stopInferenceService()
181
499
  promise.reject("GENERATION_ERROR", e.message, e)
182
500
  }
183
501
  }
184
502
  mediapipe != null -> executor.execute {
185
503
  var session: LlmInferenceSession? = null
186
504
  try {
187
- session = createMediaPipeSession()
188
- session.addQueryChunk(contextualPrompt)
189
- val response = session.generateResponse()
190
- saveToHistory(prompt, response)
191
- promise.resolve(response)
505
+ val rawTotal = StringBuilder()
506
+ val visibleTotal = StringBuilder()
507
+ var currentPrompt = buildPrompt(prompt, useConversationHistory)
508
+ var continuations = 0
509
+ var continuationJoinPending = false
510
+ while (true) {
511
+ session = createMediaPipeSession()
512
+ session.addQueryChunk(currentPrompt)
513
+ val part = session.generateResponse()
514
+ session.close()
515
+ session = null
516
+ rawTotal.append(part)
517
+ if (containsEndMarker(rawTotal.toString())) break
518
+
519
+ val cleanPart = sanitizeVisibleText(stripEndMarker(part))
520
+ val partForUi = if (continuationJoinPending) {
521
+ continuationJoinPending = false
522
+ adjustChunkBoundary(visibleTotal.toString(), cleanPart)
523
+ } else {
524
+ cleanPart
525
+ }
526
+ visibleTotal.append(partForUi)
527
+ val visible = visibleTotal.toString()
528
+ if (
529
+ useConversationHistory &&
530
+ shouldContinueResponse(visible) &&
531
+ continuations < MAX_CONTINUATIONS
532
+ ) {
533
+ currentPrompt = buildContinuationPrompt(prompt, visible)
534
+ continuations++
535
+ continuationJoinPending = true
536
+ } else break
537
+ }
538
+ val full = if (visibleTotal.isNotEmpty()) {
539
+ sanitizeVisibleText(visibleTotal.toString())
540
+ } else {
541
+ sanitizeVisibleText(stripEndMarker(rawTotal.toString()))
542
+ }
543
+ maybeSaveToHistory(prompt, full, useConversationHistory)
544
+ stopInferenceService()
545
+ promise.resolve(full)
192
546
  } catch (e: Exception) {
547
+ stopInferenceService()
193
548
  promise.reject("GENERATION_ERROR", e.message, e)
194
549
  } finally {
195
550
  session?.close()
196
551
  }
197
552
  }
198
- else -> promise.reject("NOT_INITIALIZED", "LLM no inicializado.")
553
+ else -> promise.reject("NOT_INITIALIZED", "LLM not initialized.")
199
554
  }
200
555
  }
201
556
 
202
557
  override fun generateResponseStream(prompt: String) {
203
558
  val mlkit = mlkitModel
204
559
  val mediapipe = llmInference
205
- val contextualPrompt = buildContextualPrompt(prompt)
560
+ startInferenceService()
206
561
  when {
207
562
  mlkit != null -> coroutineScope.launch {
563
+ val total = StringBuilder()
564
+ val rawTotal = StringBuilder()
565
+ var currentPrompt = buildContextualPrompt(prompt)
566
+ var continuations = 0
567
+ var continuationJoinPending = false
568
+ var firstDeltaInPass = false
569
+ var idleRetries = 0
208
570
  var streamError = false
209
- val fullResponse = StringBuilder()
571
+ var markerReached = false
210
572
  try {
211
- val request = generateContentRequest(TextPart(contextualPrompt)) {
212
- maxOutputTokens = DEFAULT_MAX_OUTPUT_TOKENS
213
- }
214
- mlkit.generateContentStream(request)
215
- .catch { e ->
573
+ while (true) {
574
+ val beforeLength = total.length
575
+ firstDeltaInPass = true
576
+ var quotaHit = false
577
+ try {
578
+ quotaHit = streamMlKitChunk(mlkit, currentPrompt) { token ->
579
+ rawTotal.append(token)
580
+ if (containsEndMarker(rawTotal.toString())) {
581
+ markerReached = true
582
+ }
583
+
584
+ val visibleNow = sanitizeVisibleText(stripEndMarker(rawTotal.toString()))
585
+ if (visibleNow.length > total.length) {
586
+ val delta = visibleNow.substring(total.length)
587
+ val adjustedDelta = if (continuationJoinPending && firstDeltaInPass) {
588
+ continuationJoinPending = false
589
+ firstDeltaInPass = false
590
+ adjustChunkBoundary(total.toString(), delta)
591
+ } else {
592
+ firstDeltaInPass = false
593
+ delta
594
+ }
595
+ if (adjustedDelta.isNotEmpty()) {
596
+ total.append(adjustedDelta)
597
+ emitStreamToken(adjustedDelta, false)
598
+ }
599
+ }
600
+ }
601
+ } catch (e: GenAiException) {
602
+ if (e.errorCode == QUOTA_ERROR_CODE && total.isNotEmpty()) {
603
+ quotaHit = true
604
+ } else {
605
+ throw e
606
+ }
607
+ } catch (e: Exception) {
216
608
  streamError = true
609
+ stopInferenceService()
217
610
  sendEvent(EVENT_STREAM_ERROR, createErrorMap("STREAM_ERROR", e.message ?: "Error"))
611
+ break
218
612
  }
219
- .collect { chunk ->
220
- val token = chunk.candidates.firstOrNull()?.text ?: ""
221
- fullResponse.append(token)
222
- sendEvent(EVENT_STREAM_TOKEN, Arguments.createMap().apply {
223
- putString("token", token)
224
- putBoolean("done", false)
225
- })
613
+ if (streamError) return@launch
614
+ if (markerReached) break
615
+
616
+ val appendedNewText = total.length > beforeLength
617
+ if (!appendedNewText && shouldContinueResponse(total.toString()) && idleRetries < MAX_STREAM_IDLE_RETRIES) {
618
+ idleRetries++
619
+ currentPrompt = buildContinuationPrompt(prompt, total.toString())
620
+ continuationJoinPending = true
621
+ delay(if (quotaHit) QUOTA_RETRY_DELAY_MS else CONTINUATION_DELAY_MS)
622
+ continue
226
623
  }
227
- if (!streamError) {
228
- saveToHistory(prompt, fullResponse.toString())
229
- sendEvent(EVENT_STREAM_TOKEN, Arguments.createMap().apply {
230
- putString("token", "")
231
- putBoolean("done", true)
232
- })
233
- sendEvent(EVENT_STREAM_COMPLETE, Arguments.createMap())
624
+
625
+ idleRetries = 0
626
+ if (shouldContinueResponse(total.toString()) && continuations < MAX_CONTINUATIONS) {
627
+ currentPrompt = buildContinuationPrompt(prompt, total.toString())
628
+ continuations++
629
+ continuationJoinPending = true
630
+ delay(if (quotaHit) QUOTA_RETRY_DELAY_MS else CONTINUATION_DELAY_MS)
631
+ } else break
234
632
  }
633
+ saveToHistory(prompt, sanitizeVisibleText(total.toString()))
634
+ emitStreamToken("", true)
635
+ stopInferenceService()
636
+ sendEvent(EVENT_STREAM_COMPLETE, Arguments.createMap())
235
637
  } catch (e: Exception) {
236
638
  if (!streamError) {
639
+ stopInferenceService()
237
640
  sendEvent(EVENT_STREAM_ERROR, createErrorMap("STREAM_ERROR", e.message ?: "Error"))
238
641
  }
239
642
  }
240
643
  }
241
644
  mediapipe != null -> executor.execute {
242
- val fullResponse = StringBuilder()
645
+ val total = StringBuilder()
646
+ val rawTotal = StringBuilder()
647
+ var currentPrompt = buildContextualPrompt(prompt)
648
+ var continuations = 0
649
+ var continuationJoinPending = false
650
+ var firstDeltaInPass = false
243
651
  var session: LlmInferenceSession? = null
652
+ var markerReached = false
244
653
  try {
245
- session = createMediaPipeSession()
246
- val capturedSession = session
247
- session.addQueryChunk(contextualPrompt)
248
- session.generateResponseAsync(ProgressListener<String> { partial, done ->
249
- val token = partial ?: ""
250
- fullResponse.append(token)
251
- sendEvent(EVENT_STREAM_TOKEN, Arguments.createMap().apply {
252
- putString("token", token)
253
- putBoolean("done", done)
654
+ while (true) {
655
+ val latch = CountDownLatch(1)
656
+ firstDeltaInPass = true
657
+ session = createMediaPipeSession()
658
+ val capturedSession = session
659
+ session.addQueryChunk(currentPrompt)
660
+ session.generateResponseAsync(ProgressListener<String> { partial, done ->
661
+ val token = partial ?: ""
662
+ rawTotal.append(token)
663
+ if (containsEndMarker(rawTotal.toString())) {
664
+ markerReached = true
665
+ }
666
+
667
+ val visibleNow = sanitizeVisibleText(stripEndMarker(rawTotal.toString()))
668
+ if (visibleNow.length > total.length) {
669
+ val delta = visibleNow.substring(total.length)
670
+ val adjustedDelta = if (continuationJoinPending && firstDeltaInPass) {
671
+ continuationJoinPending = false
672
+ firstDeltaInPass = false
673
+ adjustChunkBoundary(total.toString(), delta)
674
+ } else {
675
+ firstDeltaInPass = false
676
+ delta
677
+ }
678
+ if (adjustedDelta.isNotEmpty()) {
679
+ total.append(adjustedDelta)
680
+ emitStreamToken(adjustedDelta, false)
681
+ }
682
+ }
683
+ if (done) {
684
+ capturedSession.close()
685
+ latch.countDown()
686
+ }
254
687
  })
255
- if (done) {
256
- saveToHistory(prompt, fullResponse.toString())
257
- sendEvent(EVENT_STREAM_COMPLETE, Arguments.createMap())
258
- capturedSession.close()
259
- }
260
- })
688
+ latch.await()
689
+ session = null
690
+ if (markerReached) break
691
+ if (shouldContinueResponse(total.toString()) && continuations < MAX_CONTINUATIONS) {
692
+ currentPrompt = buildContinuationPrompt(prompt, total.toString())
693
+ continuations++
694
+ continuationJoinPending = true
695
+ } else break
696
+ }
697
+ saveToHistory(prompt, sanitizeVisibleText(total.toString()))
698
+ emitStreamToken("", true)
699
+ stopInferenceService()
700
+ sendEvent(EVENT_STREAM_COMPLETE, Arguments.createMap())
261
701
  } catch (e: Exception) {
262
702
  session?.close()
703
+ stopInferenceService()
263
704
  sendEvent(EVENT_STREAM_ERROR, createErrorMap("STREAM_ERROR", e.message ?: "Error"))
264
705
  }
265
706
  }
266
- else -> sendEvent(EVENT_STREAM_ERROR, createErrorMap("NOT_INITIALIZED", "LLM no inicializado."))
707
+ else -> sendEvent(EVENT_STREAM_ERROR, createErrorMap("NOT_INITIALIZED", "LLM not initialized."))
267
708
  }
268
709
  }
269
710
 
@@ -317,6 +758,7 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
317
758
  override fun invalidate() {
318
759
  super.invalidate()
319
760
  try {
761
+ stopInferenceService()
320
762
  llmInference?.close()
321
763
  llmInference = null
322
764
  mlkitModel = null