react-native-ai-core 0.1.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
1
1
  package com.aicore
2
2
 
3
+ import android.content.Intent
3
4
  import android.os.Build
4
5
  import com.facebook.react.bridge.Arguments
5
6
  import com.facebook.react.bridge.Promise
@@ -11,17 +12,21 @@ import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
11
12
  import com.google.mediapipe.tasks.genai.llminference.ProgressListener
12
13
  import com.google.mlkit.genai.common.DownloadStatus
13
14
  import com.google.mlkit.genai.common.FeatureStatus
15
+ import com.google.mlkit.genai.common.GenAiException
14
16
  import com.google.mlkit.genai.prompt.Generation
15
17
  import com.google.mlkit.genai.prompt.GenerativeModel
16
18
  import com.google.mlkit.genai.prompt.TextPart
17
19
  import com.google.mlkit.genai.prompt.generateContentRequest
18
20
  import kotlinx.coroutines.CoroutineScope
19
21
  import kotlinx.coroutines.Dispatchers
22
+ import kotlinx.coroutines.Job
20
23
  import kotlinx.coroutines.SupervisorJob
21
24
  import kotlinx.coroutines.cancel
25
+ import kotlinx.coroutines.delay
22
26
  import kotlinx.coroutines.flow.catch
23
27
  import kotlinx.coroutines.launch
24
28
  import java.io.File
29
+ import java.util.concurrent.CountDownLatch
25
30
  import java.util.concurrent.ExecutorService
26
31
  import java.util.concurrent.Executors
27
32
 
@@ -34,8 +39,9 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
34
39
  private val executor: ExecutorService = Executors.newSingleThreadExecutor()
35
40
  private val coroutineScope = CoroutineScope(Dispatchers.IO + SupervisorJob())
36
41
 
37
- // Historial de conversación: lista de pares (mensaje usuario, respuesta asistente)
38
42
  private val conversationHistory = mutableListOf<Pair<String, String>>()
43
+ @Volatile private var cancelRequested = false
44
+ @Volatile private var activeGenerationJob: Job? = null
39
45
 
40
46
  companion object {
41
47
  const val NAME = NativeAiCoreSpec.NAME
@@ -43,34 +49,38 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
43
49
  const val EVENT_STREAM_COMPLETE = "AICore_streamComplete"
44
50
  const val EVENT_STREAM_ERROR = "AICore_streamError"
45
51
  private const val DEFAULT_TEMPERATURE = 0.7f
46
- private const val DEFAULT_MAX_TOKENS = 2048 // ventana MediaPipe (entrada+salida)
47
- private const val DEFAULT_MAX_OUTPUT_TOKENS = 256 // salida ML Kit/AICore (límite API: 1–256)
52
+ private const val DEFAULT_MAX_TOKENS = 4096 // MediaPipe context window (input+output)
53
+ private const val REQUESTED_MAX_OUTPUT_TOKENS = 256
54
+ private const val FALLBACK_MAX_OUTPUT_TOKENS = 256
48
55
  private const val DEFAULT_TOP_K = 40
49
- // Límite de caracteres del historial para no superar ~3000 tokens de entrada
56
+ private const val PROMPT_CHAR_BUDGET = 4000
50
57
  private const val HISTORY_MAX_CHARS = 9000
58
+ private const val MAX_CONTINUATIONS = 12
59
+ private const val MAX_STREAM_IDLE_RETRIES = 3
60
+ private const val QUOTA_ERROR_CODE = 9 // AICore NPU quota exceeded error code
61
+ private const val CONTINUATION_DELAY_MS = 1200L
62
+ private const val QUOTA_RETRY_DELAY_MS = 1800L
63
+ private const val MAX_QUOTA_RETRIES = 2
64
+ private const val MAX_NON_STREAM_QUOTA_RETRIES = 6
65
+ private const val BACKGROUND_RETRY_DELAY_MS = 2000L
66
+ private const val MAX_BACKGROUND_RETRIES = 1
67
+ private const val CONTINUATION_PROMPT = "Continue exactly from the last generated character. Do not repeat, restart, summarize, explain, or add headings. Output only the direct continuation."
68
+ private const val END_MARKER = "-E-"
69
+ private const val END_MARKER_INSTRUCTION = "[INTERNAL] Append $END_MARKER only once at the true end of the final answer. Never use it before the end."
51
70
  private val STANDARD_MODEL_PATHS = listOf(
52
71
  "/data/local/tmp/gemini-nano.bin",
53
72
  "/sdcard/Download/gemini-nano.bin"
54
73
  )
55
74
  }
56
75
 
57
- // ── Historial ──────────────────────────────────────────────────────────────
58
-
59
76
  @Synchronized
60
77
  private fun buildContextualPrompt(userPrompt: String): String {
61
- if (conversationHistory.isEmpty()) return userPrompt
62
- val sb = StringBuilder()
63
- for ((u, a) in conversationHistory) {
64
- sb.append("User: ").append(u).append("\nAssistant: ").append(a).append("\n")
65
- }
66
- sb.append("User: ").append(userPrompt).append("\nAssistant:")
67
- return sb.toString()
78
+ return buildPromptWithBudget(userPrompt, null, END_MARKER_INSTRUCTION)
68
79
  }
69
80
 
70
81
  @Synchronized
71
82
  private fun saveToHistory(userPrompt: String, assistantResponse: String) {
72
83
  conversationHistory.add(Pair(userPrompt, assistantResponse))
73
- // Eliminar turnos más antiguos si el historial supera el límite de caracteres
74
84
  var total = conversationHistory.sumOf { it.first.length + it.second.length }
75
85
  while (total > HISTORY_MAX_CHARS && conversationHistory.size > 1) {
76
86
  val removed = conversationHistory.removeAt(0)
@@ -83,12 +93,250 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
83
93
  conversationHistory.clear()
84
94
  }
85
95
 
96
/**
 * Returns the last [maxChars] characters of [text], or [text] itself when it is already short
 * enough. The beginning is dropped, so the most recent text is kept.
 */
private fun trimFromStart(text: String, maxChars: Int): String =
    if (text.length > maxChars) text.takeLast(maxChars) else text
100
+
101
/**
 * Returns the first [maxChars] characters of [text], or [text] itself when it is already short
 * enough. The end is dropped.
 */
private fun trimFromEnd(text: String, maxChars: Int): String =
    if (text.length > maxChars) text.take(maxChars) else text
105
+
106
/**
 * Builds a "User:/Assistant:" transcript prompt intended to fit in PROMPT_CHAR_BUDGET characters.
 *
 * The current user prompt (cut to the budget) is always included. Whole history turns are
 * dropped oldest-first until the transcript fits. If it still does not fit once no history is
 * left, [assistantPrefix] is cut from its start so that its most recent text survives.
 *
 * @param userPrompt the user's current message.
 * @param assistantPrefix partial assistant answer to resume from (continuation passes), or null.
 * @param hiddenUserPrompt extra instruction appended after the assistant turn and followed by a
 *   fresh "Assistant:" cue, or null to omit it.
 * @return the prompt text. It is at most PROMPT_CHAR_BUDGET characters long unless the user
 *   prompt and the hidden instruction by themselves already exceed the budget.
 */
@Synchronized
private fun buildPromptWithBudget(
    userPrompt: String,
    assistantPrefix: String?,
    hiddenUserPrompt: String?
): String {
    val hiddenInstruction = hiddenUserPrompt?.let { "\n$it\nAssistant:" } ?: ""
    val assistantBase = "\nAssistant:"
    val normalizedUserPrompt = trimFromEnd(userPrompt, PROMPT_CHAR_BUDGET)
    // Copied while holding this object's lock (saveToHistory is @Synchronized on the same lock).
    val historySnapshot = conversationHistory.toList()

    // Renders the transcript using history turns from index [firstTurn] onwards.
    fun render(firstTurn: Int): String = buildString {
        for (index in firstTurn until historySnapshot.size) {
            val (userTurn, assistantTurn) = historySnapshot[index]
            append("User: ").append(userTurn).append("\nAssistant: ").append(assistantTurn).append("\n")
        }
        append("User: ").append(normalizedUserPrompt).append(assistantBase)
        if (assistantPrefix != null) {
            append(' ').append(assistantPrefix)
        }
        append(hiddenInstruction)
    }

    // Drop the oldest turn on each attempt; the last attempt (firstTurn == size) has no history.
    for (firstTurn in 0..historySnapshot.size) {
        val candidate = render(firstTurn)
        // A candidate exactly at the budget fits, so no turn is discarded unnecessarily.
        if (candidate.length <= PROMPT_CHAR_BUDGET) return candidate
    }

    // Still too long without any history: keep only the tail of the partial answer.
    val fixedPrefix = "User: $normalizedUserPrompt$assistantBase"
    val availableForAssistant =
        (PROMPT_CHAR_BUDGET - fixedPrefix.length - hiddenInstruction.length - 1).coerceAtLeast(0)
    val trimmedAssistantPrefix = assistantPrefix?.let { trimFromStart(it, availableForAssistant) }
    return buildString {
        append(fixedPrefix)
        if (!trimmedAssistantPrefix.isNullOrEmpty()) {
            append(' ').append(trimmedAssistantPrefix)
        }
        append(hiddenInstruction)
    }
}
149
+
150
/**
 * Heuristic used to decide whether the answer looks cut off. It returns true for non-blank text
 * whose last non-whitespace character is not '.', '!' or '?', and false for blank text.
 */
private fun shouldContinueResponse(text: String): Boolean =
    text.isNotBlank() && text.trimEnd().last() !in ".!?"
155
+
156
/**
 * Prompt for a continuation pass. It contains the original question, with [partialResponse]
 * placed as the start of the assistant turn, followed by the "continue exactly" and end-marker
 * instructions.
 */
private fun buildContinuationPrompt(originalUserPrompt: String, partialResponse: String): String {
    val hiddenInstructions = "$CONTINUATION_PROMPT\n$END_MARKER_INSTRUCTION"
    return buildPromptWithBudget(originalUserPrompt, partialResponse, hiddenInstructions)
}
163
+
164
/** True when [text] contains END_MARKER anywhere (case-sensitive). */
private fun containsEndMarker(text: String): Boolean = END_MARKER in text
167
+
168
/** Returns [text] with every occurrence of END_MARKER removed. */
private fun stripEndMarker(text: String): String = text.replace(END_MARKER, "")
171
+
172
/** Matches a leaked "[INTERNAL] ..." line (case-insensitive) through the end of that line. */
private val internalInstructionPattern = Regex("(?i)\\[\\s*internal\\s*\\][^\\n\\r]*")

/** Matches an echoed "Append <END_MARKER> ..." instruction through the end of its line. */
private val appendMarkerInstructionPattern = Regex("(?i)append\\s+\\Q$END_MARKER\\E[^\\n\\r]*")

/** Matches an echoed "never use it before the end" fragment, with an optional trailing period. */
private val neverUseBeforeEndPattern = Regex("(?i)never use it before the end\\.?")

/**
 * Removes echoes of the hidden end-marker instruction from model output before it is shown.
 *
 * Streaming calls this on the entire accumulated text for every token. The patterns are therefore
 * compiled once, as the properties declared just above, instead of on every call.
 */
private fun sanitizeVisibleText(text: String): String =
    text.replace(internalInstructionPattern, "")
        .replace(appendMarkerInstructionPattern, "")
        .replace(neverUseBeforeEndPattern, "")
179
+
180
/**
 * Fixes the join between already-shown text [existing] and the first chunk [incoming] of a
 * continuation pass.
 *
 * - A letter/digit meeting a letter/digit gets a separating space.
 * - A space meeting leading spaces has those leading spaces removed.
 * - Otherwise [incoming] is returned unchanged, including when either side is empty.
 */
private fun adjustChunkBoundary(existing: String, incoming: String): String {
    if (existing.isEmpty() || incoming.isEmpty()) return incoming
    val tail = existing.last()
    val head = incoming.first()
    return when {
        tail.isLetterOrDigit() && head.isLetterOrDigit() -> " $incoming"
        tail == ' ' && head == ' ' -> incoming.trimStart(' ')
        else -> incoming
    }
}
197
+
198
/**
 * Sends one EVENT_STREAM_TOKEN event with the payload { token, done }. A true [done] marks the
 * final event of a stream.
 */
private fun emitStreamToken(token: String, done: Boolean) {
    val payload = Arguments.createMap()
    payload.putString("token", token)
    payload.putBoolean("done", done)
    sendEvent(EVENT_STREAM_TOKEN, payload)
}
204
+
205
/** True when the error message mentions "out of range" (any case); false when there is no message. */
private fun isOutOfRangeError(error: Throwable): Boolean =
    error.message.orEmpty().contains("out of range", ignoreCase = true)
208
+
209
/** True only for a GenAiException whose errorCode equals QUOTA_ERROR_CODE. */
private fun isQuotaError(error: Throwable): Boolean =
    (error as? GenAiException)?.errorCode == QUOTA_ERROR_CODE
212
+
213
/**
 * True when the lower-cased error message contains one of the background-usage phrases matched
 * below; false when there is no message.
 */
private fun isBackgroundError(error: Throwable): Boolean {
    val normalized = error.message?.lowercase() ?: return false
    return sequenceOf(
        "background usage is blocked",
        "use the api when your app is in the foreground"
    ).any { phrase -> phrase in normalized }
}
218
+
219
/**
 * Runs one non-streaming ML Kit request for [prompt] and returns the first candidate's text, or
 * "" when there is no candidate.
 *
 * Retry policy:
 * - Quota errors are retried up to MAX_QUOTA_RETRIES times in total, each after
 *   QUOTA_RETRY_DELAY_MS.
 * - Background-usage errors are retried up to MAX_BACKGROUND_RETRIES times, each after
 *   BACKGROUND_RETRY_DELAY_MS.
 * - An "out of range" error switches once, without delay, to FALLBACK_MAX_OUTPUT_TOKENS. After
 *   that switch only quota errors are retried.
 *
 * @throws Exception the error that ended the attempts when it is not retried.
 */
private suspend fun generateMlKitChunk(model: GenerativeModel, prompt: String): String {
    suspend fun requestOnce(tokenLimit: Int): String {
        val request = generateContentRequest(TextPart(prompt)) {
            maxOutputTokens = tokenLimit
        }
        return model.generateContent(request).candidates.firstOrNull()?.text ?: ""
    }

    var quotaAttempts = 0
    var backgroundAttempts = 0
    var usingFallback = false
    while (true) {
        try {
            val tokenLimit = if (usingFallback) FALLBACK_MAX_OUTPUT_TOKENS else REQUESTED_MAX_OUTPUT_TOKENS
            return requestOnce(tokenLimit)
        } catch (error: Exception) {
            when {
                isQuotaError(error) && quotaAttempts < MAX_QUOTA_RETRIES -> {
                    quotaAttempts++
                    delay(QUOTA_RETRY_DELAY_MS)
                }
                !usingFallback && isBackgroundError(error) && backgroundAttempts < MAX_BACKGROUND_RETRIES -> {
                    backgroundAttempts++
                    delay(BACKGROUND_RETRY_DELAY_MS)
                }
                !usingFallback && isOutOfRangeError(error) -> {
                    usingFallback = true
                }
                else -> throw error
            }
        }
    }
}
258
+
259
/**
 * Streams one ML Kit generation pass for [prompt] and forwards each chunk's text to [onToken].
 *
 * Quota errors (isQuotaError) and background-usage errors (isBackgroundError) are retried after a
 * delay, but only while nothing has been forwarded in the current attempt. Re-running the same
 * prompt after partial output would replay the answer from its beginning and duplicate text the
 * caller already received. In that case the pass ends instead, and the caller can continue from
 * the partial answer.
 *
 * If the first request fails with an "out of range" error, the pass is retried with
 * FALLBACK_MAX_OUTPUT_TOKENS.
 *
 * @return true when the pass ended on a quota error (retries exhausted, or partial output already
 *   forwarded); false when it completed normally or ended on a background-usage error.
 * @throws Exception any other generation error.
 */
private suspend fun streamMlKitChunk(
    model: GenerativeModel,
    prompt: String,
    onToken: (String) -> Unit
): Boolean {
    suspend fun collectWithLimit(limit: Int): Boolean {
        var quotaRetries = 0
        var backgroundRetries = 0
        while (true) {
            var quotaHit = false
            var backgroundHit = false
            var forwardedText = false
            val request = generateContentRequest(TextPart(prompt)) {
                maxOutputTokens = limit
            }
            try {
                model.generateContentStream(request)
                    .catch { error ->
                        if (isQuotaError(error)) {
                            quotaHit = true
                        } else if (isBackgroundError(error)) {
                            backgroundHit = true
                        } else {
                            throw error
                        }
                    }
                    .collect { chunk ->
                        val token = chunk.candidates.firstOrNull()?.text ?: ""
                        if (token.isNotEmpty()) forwardedText = true
                        onToken(token)
                    }
            } catch (error: Exception) {
                if (isQuotaError(error)) {
                    quotaHit = true
                } else if (isBackgroundError(error)) {
                    backgroundHit = true
                } else {
                    throw error
                }
            }
            if (!quotaHit && !backgroundHit) return false
            // A retry would regenerate from the start and duplicate text already forwarded.
            if (forwardedText) return quotaHit
            if (quotaHit) {
                if (quotaRetries >= MAX_QUOTA_RETRIES) return true
                quotaRetries++
                delay(QUOTA_RETRY_DELAY_MS)
            } else {
                if (backgroundRetries >= MAX_BACKGROUND_RETRIES) return false
                backgroundRetries++
                delay(BACKGROUND_RETRY_DELAY_MS)
            }
        }
    }

    return try {
        collectWithLimit(REQUESTED_MAX_OUTPUT_TOKENS)
    } catch (error: Exception) {
        if (!isOutOfRangeError(error)) throw error
        collectWithLimit(FALLBACK_MAX_OUTPUT_TOKENS)
    }
}
317
+
86
318
  private fun sendEvent(name: String, params: WritableMap?) {
87
319
  reactApplicationContext
88
320
  .getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter::class.java)
89
321
  .emit(name, params)
90
322
  }
91
323
 
324
/**
 * Starts InferenceService for the duration of a generation. It uses startForegroundService on
 * API 26+ and startService below that. The intent is presumably to keep AICore usable when the
 * app leaves the foreground (compare the phrases matched by isBackgroundError); TODO confirm.
 *
 * This is best effort. Android may refuse the start:
 * - an IllegalStateException, e.g. ForegroundServiceStartNotAllowedException when the app is in
 *   the background on API 31+;
 * - a SecurityException.
 * Such a refusal is swallowed. Throwing here would escape the React method before its promise is
 * settled or its events are sent. Generation then proceeds without the service.
 */
private fun startInferenceService() {
    val intent = Intent(reactApplicationContext, InferenceService::class.java)
    try {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
            reactApplicationContext.startForegroundService(intent)
        } else {
            reactApplicationContext.startService(intent)
        }
    } catch (e: IllegalStateException) {
        // Background-start restriction: continue without the service.
    } catch (e: SecurityException) {
        // Service not startable by this caller: continue without the service.
    }
}
332
+
333
/**
 * Asks InferenceService to stop by delivering an intent with action InferenceService.ACTION_STOP.
 *
 * On API 26+, startService() throws IllegalStateException when the app is in the background and
 * the service is not running, for example because startInferenceService was refused. This method
 * is called on completion and error paths, often immediately before a promise is settled, so it
 * must not throw. In that case the service is stopped with stopService() instead. stopService()
 * has no background restriction and does nothing when the service is not running.
 */
private fun stopInferenceService() {
    val stopIntent = Intent(reactApplicationContext, InferenceService::class.java).apply {
        action = InferenceService.ACTION_STOP
    }
    try {
        reactApplicationContext.startService(stopIntent)
    } catch (e: IllegalStateException) {
        reactApplicationContext.stopService(Intent(reactApplicationContext, InferenceService::class.java))
    }
}
339
+
92
340
  private fun createErrorMap(code: String, message: String): WritableMap =
93
341
  Arguments.createMap().apply {
94
342
  putString("code", code)
@@ -96,7 +344,7 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
96
344
  }
97
345
 
98
346
  private fun createMediaPipeSession(): LlmInferenceSession {
99
- val inference = llmInference ?: throw IllegalStateException("LLM no inicializado.")
347
+ val inference = llmInference ?: throw IllegalStateException("LLM not initialized.")
100
348
  val opts = LlmInferenceSession.LlmInferenceSessionOptions.builder()
101
349
  .setTemperature(DEFAULT_TEMPERATURE)
102
350
  .setTopK(DEFAULT_TOP_K)
@@ -104,6 +352,20 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
104
352
  return LlmInferenceSession.createFromOptions(inference, opts)
105
353
  }
106
354
 
355
/**
 * Prompt actually sent to the model. When [useConversationHistory] is set it is the
 * history-aware transcript; otherwise it is [userPrompt] unchanged.
 */
private fun buildPrompt(userPrompt: String, useConversationHistory: Boolean): String =
    if (!useConversationHistory) userPrompt else buildContextualPrompt(userPrompt)
358
+
359
/** Records the turn via saveToHistory only when [useConversationHistory] is set. */
private fun maybeSaveToHistory(
    userPrompt: String,
    assistantResponse: String,
    useConversationHistory: Boolean
) {
    if (!useConversationHistory) return
    saveToHistory(userPrompt, assistantResponse)
}
368
+
107
369
  override fun initialize(modelPath: String, promise: Promise) {
108
370
  mlkitModel = null
109
371
  llmInference?.close()
@@ -131,9 +393,9 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
131
393
  }
132
394
  }
133
395
  }
134
- FeatureStatus.DOWNLOADING -> promise.reject("ALREADY_DOWNLOADING", "Ya se está descargando.")
135
- FeatureStatus.UNAVAILABLE -> promise.reject("AICORE_UNAVAILABLE", "Gemini Nano no disponible en este dispositivo.")
136
- else -> promise.reject("AICORE_UNKNOWN", "Estado desconocido.")
396
+ FeatureStatus.DOWNLOADING -> promise.reject("ALREADY_DOWNLOADING", "Model download already in progress.")
397
+ FeatureStatus.UNAVAILABLE -> promise.reject("AICORE_UNAVAILABLE", "Gemini Nano is not available on this device.")
398
+ else -> promise.reject("AICORE_UNKNOWN", "Unknown AICore status.")
137
399
  }
138
400
  } catch (e: Exception) {
139
401
  promise.reject("AICORE_ERROR", e.message, e)
@@ -143,7 +405,7 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
143
405
  executor.execute {
144
406
  try {
145
407
  if (!File(modelPath).exists()) {
146
- promise.reject("MODEL_NOT_FOUND", "Modelo no encontrado en: $modelPath")
408
+ promise.reject("MODEL_NOT_FOUND", "Model file not found at: $modelPath")
147
409
  return@execute
148
410
  }
149
411
  val options = LlmInference.LlmInferenceOptions.builder()
@@ -165,105 +427,322 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
165
427
  }
166
428
 
167
429
  override fun generateResponse(prompt: String, promise: Promise) {
430
+ generateResponseInternal(prompt, true, promise)
431
+ }
432
+
433
/** Generates a reply to [prompt] alone; the conversation history is neither read nor updated. */
override fun generateResponseStateless(prompt: String, promise: Promise) =
    generateResponseInternal(prompt, useConversationHistory = false, promise = promise)
436
+
437
/**
 * Shared implementation of generateResponse and generateResponseStateless.
 *
 * ML Kit is used when it is initialized, otherwise MediaPipe. When neither is initialized the
 * promise is rejected with NOT_INITIALIZED.
 *
 * With [useConversationHistory]:
 * - the prompt includes the saved history;
 * - answers that look cut off (shouldContinueResponse) are extended with up to MAX_CONTINUATIONS
 *   continuation passes;
 * - the final answer is saved to history.
 *
 * The promise resolves with the cleaned answer. It rejects with CANCELLED after
 * cancelGeneration, or with GENERATION_ERROR on failure.
 */
private fun generateResponseInternal(
    prompt: String,
    useConversationHistory: Boolean,
    promise: Promise
) {
    val mlkit = mlkitModel
    if (mlkit == null && llmInference == null) {
        // Checked before startInferenceService(): nothing would run, so nothing would stop it.
        promise.reject("NOT_INITIALIZED", "LLM not initialized.")
        return
    }
    startInferenceService()
    cancelRequested = false
    if (mlkit != null) {
        runMlKitGeneration(mlkit, prompt, useConversationHistory, promise)
    } else {
        runMediaPipeGeneration(prompt, useConversationHistory, promise)
    }
}

/** ML Kit branch of generateResponseInternal: runs on coroutineScope and records the job for cancellation. */
private fun runMlKitGeneration(
    mlkit: GenerativeModel,
    prompt: String,
    useConversationHistory: Boolean,
    promise: Promise
) {
    activeGenerationJob = coroutineScope.launch {
        try {
            val rawTotal = StringBuilder()
            val visibleTotal = StringBuilder()
            var currentPrompt = buildPrompt(prompt, useConversationHistory)
            var continuations = 0
            var continuationJoinPending = false
            var quotaRetries = 0
            while (true) {
                if (cancelRequested) {
                    stopInferenceService()
                    promise.reject("CANCELLED", "Generation cancelled.")
                    return@launch
                }
                val part = try {
                    generateMlKitChunk(mlkit, currentPrompt)
                } catch (e: GenAiException) {
                    if (e.errorCode == QUOTA_ERROR_CODE) {
                        if (quotaRetries < MAX_NON_STREAM_QUOTA_RETRIES) {
                            quotaRetries++
                            delay(QUOTA_RETRY_DELAY_MS)
                            continue
                        }
                        // Quota still exhausted: return what was produced so far, if anything.
                        if (rawTotal.isNotEmpty()) break
                    }
                    throw e
                }
                quotaRetries = 0
                rawTotal.append(part)
                appendVisiblePart(visibleTotal, part, continuationJoinPending)
                continuationJoinPending = false
                // Checked only after the part is kept, so the pass that carries the marker is not lost.
                if (containsEndMarker(rawTotal.toString())) break

                val visible = visibleTotal.toString()
                if (useConversationHistory && shouldContinueResponse(visible) && continuations < MAX_CONTINUATIONS) {
                    currentPrompt = buildContinuationPrompt(prompt, visible)
                    continuations++
                    continuationJoinPending = true
                    delay(CONTINUATION_DELAY_MS)
                } else {
                    break
                }
            }
            val full = finalResponseText(visibleTotal, rawTotal)
            maybeSaveToHistory(prompt, full, useConversationHistory)
            stopInferenceService()
            promise.resolve(full)
        } catch (e: Exception) {
            stopInferenceService()
            if (cancelRequested) {
                promise.reject("CANCELLED", "Generation cancelled.")
            } else {
                promise.reject("GENERATION_ERROR", e.message, e)
            }
        }
    }
}

/**
 * MediaPipe branch of generateResponseInternal. It runs on executor and creates a new session
 * for every pass; the prompt itself carries the history and any partial answer.
 */
private fun runMediaPipeGeneration(prompt: String, useConversationHistory: Boolean, promise: Promise) {
    executor.execute {
        var session: LlmInferenceSession? = null
        try {
            val rawTotal = StringBuilder()
            val visibleTotal = StringBuilder()
            var currentPrompt = buildPrompt(prompt, useConversationHistory)
            var continuations = 0
            var continuationJoinPending = false
            while (true) {
                if (cancelRequested) {
                    stopInferenceService()
                    promise.reject("CANCELLED", "Generation cancelled.")
                    return@execute
                }
                session = createMediaPipeSession()
                session.addQueryChunk(currentPrompt)
                val part = session.generateResponse()
                session.close()
                session = null
                rawTotal.append(part)
                appendVisiblePart(visibleTotal, part, continuationJoinPending)
                continuationJoinPending = false
                // Checked only after the part is kept, so the pass that carries the marker is not lost.
                if (containsEndMarker(rawTotal.toString())) break

                val visible = visibleTotal.toString()
                if (useConversationHistory && shouldContinueResponse(visible) && continuations < MAX_CONTINUATIONS) {
                    currentPrompt = buildContinuationPrompt(prompt, visible)
                    continuations++
                    continuationJoinPending = true
                } else {
                    break
                }
            }
            val full = finalResponseText(visibleTotal, rawTotal)
            maybeSaveToHistory(prompt, full, useConversationHistory)
            stopInferenceService()
            promise.resolve(full)
        } catch (e: Exception) {
            stopInferenceService()
            // Same outcome as the ML Kit branch when a cancellation was requested.
            if (cancelRequested) {
                promise.reject("CANCELLED", "Generation cancelled.")
            } else {
                promise.reject("GENERATION_ERROR", e.message, e)
            }
        } finally {
            session?.close()
        }
    }
}

/**
 * Appends the display-safe form of [part] to [visibleTotal]. The end marker and leaked
 * instructions are removed, and the join is fixed when [atContinuationBoundary] is set.
 */
private fun appendVisiblePart(visibleTotal: StringBuilder, part: String, atContinuationBoundary: Boolean) {
    val cleanPart = sanitizeVisibleText(stripEndMarker(part))
    val partForUi = if (atContinuationBoundary) {
        adjustChunkBoundary(visibleTotal.toString(), cleanPart)
    } else {
        cleanPart
    }
    visibleTotal.append(partForUi)
}

/**
 * Final answer text. It is the sanitized visible text, or the cleaned raw text when nothing
 * visible was collected.
 */
private fun finalResponseText(visibleTotal: StringBuilder, rawTotal: StringBuilder): String =
    if (visibleTotal.isNotEmpty()) {
        sanitizeVisibleText(visibleTotal.toString())
    } else {
        sanitizeVisibleText(stripEndMarker(rawTotal.toString()))
    }
201
575
 
202
576
  override fun generateResponseStream(prompt: String) {
203
577
  val mlkit = mlkitModel
204
578
  val mediapipe = llmInference
205
- val contextualPrompt = buildContextualPrompt(prompt)
579
+ startInferenceService()
580
+ cancelRequested = false
206
581
  when {
207
582
  mlkit != null -> coroutineScope.launch {
583
+ val total = StringBuilder()
584
+ val rawTotal = StringBuilder()
585
+ var currentPrompt = buildContextualPrompt(prompt)
586
+ var continuations = 0
587
+ var continuationJoinPending = false
588
+ var firstDeltaInPass = false
589
+ var idleRetries = 0
208
590
  var streamError = false
209
- val fullResponse = StringBuilder()
591
+ var markerReached = false
210
592
  try {
211
- val request = generateContentRequest(TextPart(contextualPrompt)) {
212
- maxOutputTokens = DEFAULT_MAX_OUTPUT_TOKENS
213
- }
214
- mlkit.generateContentStream(request)
215
- .catch { e ->
593
+ while (true) {
594
+ if (cancelRequested) break
595
+ val beforeLength = total.length
596
+ firstDeltaInPass = true
597
+ var quotaHit = false
598
+ try {
599
+ quotaHit = streamMlKitChunk(mlkit, currentPrompt) { token ->
600
+ rawTotal.append(token)
601
+ if (containsEndMarker(rawTotal.toString())) {
602
+ markerReached = true
603
+ }
604
+
605
+ val visibleNow = sanitizeVisibleText(stripEndMarker(rawTotal.toString()))
606
+ if (visibleNow.length > total.length) {
607
+ val delta = visibleNow.substring(total.length)
608
+ val adjustedDelta = if (continuationJoinPending && firstDeltaInPass) {
609
+ continuationJoinPending = false
610
+ firstDeltaInPass = false
611
+ adjustChunkBoundary(total.toString(), delta)
612
+ } else {
613
+ firstDeltaInPass = false
614
+ delta
615
+ }
616
+ if (adjustedDelta.isNotEmpty()) {
617
+ total.append(adjustedDelta)
618
+ emitStreamToken(adjustedDelta, false)
619
+ }
620
+ }
621
+ }
622
+ } catch (e: GenAiException) {
623
+ if (e.errorCode == QUOTA_ERROR_CODE && total.isNotEmpty()) {
624
+ quotaHit = true
625
+ } else {
626
+ throw e
627
+ }
628
+ } catch (e: Exception) {
629
+ if (cancelRequested) break
216
630
  streamError = true
631
+ stopInferenceService()
217
632
  sendEvent(EVENT_STREAM_ERROR, createErrorMap("STREAM_ERROR", e.message ?: "Error"))
633
+ break
218
634
  }
219
- .collect { chunk ->
220
- val token = chunk.candidates.firstOrNull()?.text ?: ""
221
- fullResponse.append(token)
222
- sendEvent(EVENT_STREAM_TOKEN, Arguments.createMap().apply {
223
- putString("token", token)
224
- putBoolean("done", false)
225
- })
635
+ if (streamError) return@launch
636
+ if (markerReached) break
637
+
638
+ val appendedNewText = total.length > beforeLength
639
+ if (!appendedNewText && shouldContinueResponse(total.toString()) && idleRetries < MAX_STREAM_IDLE_RETRIES) {
640
+ idleRetries++
641
+ currentPrompt = buildContinuationPrompt(prompt, total.toString())
642
+ continuationJoinPending = true
643
+ delay(if (quotaHit) QUOTA_RETRY_DELAY_MS else CONTINUATION_DELAY_MS)
644
+ continue
226
645
  }
227
- if (!streamError) {
228
- saveToHistory(prompt, fullResponse.toString())
229
- sendEvent(EVENT_STREAM_TOKEN, Arguments.createMap().apply {
230
- putString("token", "")
231
- putBoolean("done", true)
232
- })
233
- sendEvent(EVENT_STREAM_COMPLETE, Arguments.createMap())
646
+
647
+ idleRetries = 0
648
+ if (shouldContinueResponse(total.toString()) && continuations < MAX_CONTINUATIONS) {
649
+ currentPrompt = buildContinuationPrompt(prompt, total.toString())
650
+ continuations++
651
+ continuationJoinPending = true
652
+ delay(if (quotaHit) QUOTA_RETRY_DELAY_MS else CONTINUATION_DELAY_MS)
653
+ } else break
234
654
  }
655
+ if (!cancelRequested) saveToHistory(prompt, sanitizeVisibleText(total.toString()))
656
+ emitStreamToken("", true)
657
+ stopInferenceService()
658
+ sendEvent(EVENT_STREAM_COMPLETE, Arguments.createMap())
235
659
  } catch (e: Exception) {
236
- if (!streamError) {
660
+ if (cancelRequested) {
661
+ emitStreamToken("", true)
662
+ stopInferenceService()
663
+ sendEvent(EVENT_STREAM_COMPLETE, Arguments.createMap())
664
+ } else if (!streamError) {
665
+ stopInferenceService()
237
666
  sendEvent(EVENT_STREAM_ERROR, createErrorMap("STREAM_ERROR", e.message ?: "Error"))
238
667
  }
239
668
  }
240
- }
669
+ }.also { activeGenerationJob = it }
241
670
  mediapipe != null -> executor.execute {
242
- val fullResponse = StringBuilder()
671
+ val total = StringBuilder()
672
+ val rawTotal = StringBuilder()
673
+ var currentPrompt = buildContextualPrompt(prompt)
674
+ var continuations = 0
675
+ var continuationJoinPending = false
676
+ var firstDeltaInPass = false
243
677
  var session: LlmInferenceSession? = null
678
+ var markerReached = false
244
679
  try {
245
- session = createMediaPipeSession()
246
- val capturedSession = session
247
- session.addQueryChunk(contextualPrompt)
248
- session.generateResponseAsync(ProgressListener<String> { partial, done ->
249
- val token = partial ?: ""
250
- fullResponse.append(token)
251
- sendEvent(EVENT_STREAM_TOKEN, Arguments.createMap().apply {
252
- putString("token", token)
253
- putBoolean("done", done)
680
+ while (true) {
681
+ if (cancelRequested) break
682
+ val latch = CountDownLatch(1)
683
+ firstDeltaInPass = true
684
+ session = createMediaPipeSession()
685
+ val capturedSession = session
686
+ session.addQueryChunk(currentPrompt)
687
+ session.generateResponseAsync(ProgressListener<String> { partial, done ->
688
+ if (cancelRequested) {
689
+ capturedSession.close()
690
+ latch.countDown()
691
+ return@ProgressListener
692
+ }
693
+ val token = partial ?: ""
694
+ rawTotal.append(token)
695
+ if (containsEndMarker(rawTotal.toString())) {
696
+ markerReached = true
697
+ }
698
+
699
+ val visibleNow = sanitizeVisibleText(stripEndMarker(rawTotal.toString()))
700
+ if (visibleNow.length > total.length) {
701
+ val delta = visibleNow.substring(total.length)
702
+ val adjustedDelta = if (continuationJoinPending && firstDeltaInPass) {
703
+ continuationJoinPending = false
704
+ firstDeltaInPass = false
705
+ adjustChunkBoundary(total.toString(), delta)
706
+ } else {
707
+ firstDeltaInPass = false
708
+ delta
709
+ }
710
+ if (adjustedDelta.isNotEmpty()) {
711
+ total.append(adjustedDelta)
712
+ emitStreamToken(adjustedDelta, false)
713
+ }
714
+ }
715
+ if (done) {
716
+ capturedSession.close()
717
+ latch.countDown()
718
+ }
254
719
  })
255
- if (done) {
256
- saveToHistory(prompt, fullResponse.toString())
257
- sendEvent(EVENT_STREAM_COMPLETE, Arguments.createMap())
258
- capturedSession.close()
259
- }
260
- })
720
+ latch.await()
721
+ session = null
722
+ if (markerReached) break
723
+ if (shouldContinueResponse(total.toString()) && continuations < MAX_CONTINUATIONS) {
724
+ currentPrompt = buildContinuationPrompt(prompt, total.toString())
725
+ continuations++
726
+ continuationJoinPending = true
727
+ } else break
728
+ }
729
+ if (!cancelRequested) saveToHistory(prompt, sanitizeVisibleText(total.toString()))
730
+ emitStreamToken("", true)
731
+ stopInferenceService()
732
+ sendEvent(EVENT_STREAM_COMPLETE, Arguments.createMap())
261
733
  } catch (e: Exception) {
262
734
  session?.close()
263
- sendEvent(EVENT_STREAM_ERROR, createErrorMap("STREAM_ERROR", e.message ?: "Error"))
735
+ if (cancelRequested) {
736
+ emitStreamToken("", true)
737
+ stopInferenceService()
738
+ sendEvent(EVENT_STREAM_COMPLETE, Arguments.createMap())
739
+ } else {
740
+ stopInferenceService()
741
+ sendEvent(EVENT_STREAM_ERROR, createErrorMap("STREAM_ERROR", e.message ?: "Error"))
742
+ }
264
743
  }
265
744
  }
266
- else -> sendEvent(EVENT_STREAM_ERROR, createErrorMap("NOT_INITIALIZED", "LLM no inicializado."))
745
+ else -> sendEvent(EVENT_STREAM_ERROR, createErrorMap("NOT_INITIALIZED", "LLM not initialized."))
267
746
  }
268
747
  }
269
748
 
@@ -311,12 +790,19 @@ class AiCoreModule(reactContext: ReactApplicationContext) :
311
790
  promise.resolve(null)
312
791
  }
313
792
 
793
/**
 * Requests cancellation of the current generation and resolves immediately with null.
 *
 * The flag is set first so that loops and MediaPipe listeners, which cannot be interrupted,
 * observe it. The recorded ML Kit job, if any, is then cancelled.
 */
override fun cancelGeneration(promise: Promise) {
    cancelRequested = true
    val job = activeGenerationJob
    job?.cancel()
    promise.resolve(null)
}
798
+
314
799
  override fun addListener(eventName: String) {}
315
800
  override fun removeListeners(count: Double) {}
316
801
 
317
802
  override fun invalidate() {
318
803
  super.invalidate()
319
804
  try {
805
+ stopInferenceService()
320
806
  llmInference?.close()
321
807
  llmInference = null
322
808
  mlkitModel = null