solana-agent 20.1.2__py3-none-any.whl → 31.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. solana_agent/__init__.py +10 -5
  2. solana_agent/adapters/ffmpeg_transcoder.py +375 -0
  3. solana_agent/adapters/mongodb_adapter.py +15 -2
  4. solana_agent/adapters/openai_adapter.py +679 -0
  5. solana_agent/adapters/openai_realtime_ws.py +1813 -0
  6. solana_agent/adapters/pinecone_adapter.py +543 -0
  7. solana_agent/cli.py +128 -0
  8. solana_agent/client/solana_agent.py +180 -20
  9. solana_agent/domains/agent.py +13 -13
  10. solana_agent/domains/routing.py +18 -8
  11. solana_agent/factories/agent_factory.py +239 -38
  12. solana_agent/guardrails/pii.py +107 -0
  13. solana_agent/interfaces/client/client.py +95 -12
  14. solana_agent/interfaces/guardrails/guardrails.py +26 -0
  15. solana_agent/interfaces/plugins/plugins.py +2 -1
  16. solana_agent/interfaces/providers/__init__.py +0 -0
  17. solana_agent/interfaces/providers/audio.py +40 -0
  18. solana_agent/interfaces/providers/data_storage.py +9 -2
  19. solana_agent/interfaces/providers/llm.py +86 -9
  20. solana_agent/interfaces/providers/memory.py +13 -1
  21. solana_agent/interfaces/providers/realtime.py +212 -0
  22. solana_agent/interfaces/providers/vector_storage.py +53 -0
  23. solana_agent/interfaces/services/agent.py +27 -12
  24. solana_agent/interfaces/services/knowledge_base.py +59 -0
  25. solana_agent/interfaces/services/query.py +41 -8
  26. solana_agent/interfaces/services/routing.py +0 -1
  27. solana_agent/plugins/manager.py +37 -16
  28. solana_agent/plugins/registry.py +34 -19
  29. solana_agent/plugins/tools/__init__.py +0 -5
  30. solana_agent/plugins/tools/auto_tool.py +1 -0
  31. solana_agent/repositories/memory.py +332 -111
  32. solana_agent/services/__init__.py +1 -1
  33. solana_agent/services/agent.py +390 -241
  34. solana_agent/services/knowledge_base.py +768 -0
  35. solana_agent/services/query.py +1858 -153
  36. solana_agent/services/realtime.py +626 -0
  37. solana_agent/services/routing.py +104 -51
  38. solana_agent-31.4.0.dist-info/METADATA +1070 -0
  39. solana_agent-31.4.0.dist-info/RECORD +49 -0
  40. {solana_agent-20.1.2.dist-info → solana_agent-31.4.0.dist-info}/WHEEL +1 -1
  41. solana_agent-31.4.0.dist-info/entry_points.txt +3 -0
  42. solana_agent/adapters/llm_adapter.py +0 -160
  43. solana_agent-20.1.2.dist-info/METADATA +0 -464
  44. solana_agent-20.1.2.dist-info/RECORD +0 -35
  45. {solana_agent-20.1.2.dist-info → solana_agent-31.4.0.dist-info/licenses}/LICENSE +0 -0
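The bulk of the diff below is the new realtime path added to QueryService.process in solana_agent/services/query.py, driven entirely by keyword arguments visible in the new signature. As a reading aid, here is a minimal caller sketch (not shipped in the package; it assumes a `query_service` instance already wired by the factory, and `play` is a hypothetical audio sink):

    import asyncio

    async def demo(query_service) -> None:
        # Audio in, audio out over the per-user persistent realtime WebSocket.
        # vad=False means the service commits input and creates the response itself.
        audio_in = open("turn.m4a", "rb").read()
        async for chunk in query_service.process(
            user_id="user-123",
            query=audio_in,
            output_format="audio",
            realtime=True,
            vad=False,
            rt_voice="marin",
            audio_input_format="m4a",
            audio_output_format="aac",  # non-PCM output: FFmpeg encoding is enabled automatically
        ):
            if isinstance(chunk, (bytes, bytearray)):
                play(chunk)  # hypothetical sink for encoded audio bytes

        # Text-only realtime turn: restricting modalities disables output encoding.
        async for text in query_service.process(
            user_id="user-123",
            query="What is my order status?",
            realtime=True,
            rt_output_modalities=["text"],
        ):
            print(text, end="")

    # asyncio.run(demo(query_service))  # query_service comes from the app's factory wiring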
solana_agent/services/query.py
@@ -5,13 +5,44 @@ This service orchestrates the processing of user queries, coordinating
 other services to provide comprehensive responses while maintaining
 clean separation of concerns.
 """
-from typing import Any, AsyncGenerator, Dict, Literal, Optional, Union
 
+import logging
+import asyncio
+import re
+import time
+from typing import (
+    Any,
+    AsyncGenerator,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Type,
+    Union,
+    Tuple,
+)
+
+from pydantic import BaseModel
+
+# Interface imports
 from solana_agent.interfaces.services.query import QueryService as QueryServiceInterface
-from solana_agent.interfaces.services.routing import RoutingService as RoutingServiceInterface
+from solana_agent.interfaces.services.routing import (
+    RoutingService as RoutingServiceInterface,
+)
+from solana_agent.interfaces.providers.memory import (
+    MemoryProvider as MemoryProviderInterface,
+)
+from solana_agent.interfaces.services.knowledge_base import (
+    KnowledgeBaseService as KnowledgeBaseInterface,
+)
+from solana_agent.interfaces.guardrails.guardrails import InputGuardrail
+
+from solana_agent.interfaces.providers.realtime import RealtimeSessionOptions
+
 from solana_agent.services.agent import AgentService
 from solana_agent.services.routing import RoutingService
-from solana_agent.interfaces.providers.memory import MemoryProvider
+
+logger = logging.getLogger(__name__)
 
 
 class QueryService(QueryServiceInterface):
@@ -21,107 +52,1662 @@ class QueryService(QueryServiceInterface):
21
52
  self,
22
53
  agent_service: AgentService,
23
54
  routing_service: RoutingService,
24
- memory_provider: Optional[MemoryProvider] = None,
55
+ memory_provider: Optional[MemoryProviderInterface] = None,
56
+ knowledge_base: Optional[KnowledgeBaseInterface] = None,
57
+ kb_results_count: int = 3,
58
+ input_guardrails: List[InputGuardrail] = None,
25
59
  ):
26
- """Initialize the query service.
27
-
28
- Args:
29
- agent_service: Service for AI agent management
30
- routing_service: Service for routing queries to appropriate agents
31
- memory_provider: Optional provider for memory storage and retrieval
32
- """
60
+ """Initialize the query service."""
33
61
  self.agent_service = agent_service
34
62
  self.routing_service = routing_service
35
63
  self.memory_provider = memory_provider
64
+ self.knowledge_base = knowledge_base
65
+ self.kb_results_count = kb_results_count
66
+ self.input_guardrails = input_guardrails or []
67
+ # Per-user sticky sessions (in-memory)
68
+ # { user_id: { 'agent': str, 'started_at': float, 'last_updated': float, 'required_complete': bool } }
69
+ self._sticky_sessions: Dict[str, Dict[str, Any]] = {}
70
+ # Optional realtime service attached by factory (populated in factory)
71
+ self.realtime = None # type: ignore[attr-defined]
72
+ # Persistent realtime WS pool per user for reuse across turns/devices
73
+ # { user_id: [RealtimeService, ...] }
74
+ self._rt_services: Dict[str, List[Any]] = {}
75
+ # Global lock for creating/finding per-user sessions
76
+ self._rt_lock = asyncio.Lock()
77
+
78
+ async def _try_acquire_lock(self, lock: asyncio.Lock) -> bool:
79
+ try:
80
+ await asyncio.wait_for(lock.acquire(), timeout=0)
81
+ return True
82
+ except asyncio.TimeoutError:
83
+ return False
84
+ except Exception:
85
+ return False
86
+
87
+ async def _alloc_realtime_session(
88
+ self,
89
+ user_id: str,
90
+ *,
91
+ api_key: str,
92
+ rt_voice: str,
93
+ final_instructions: str,
94
+ initial_tools: Optional[List[Dict[str, Any]]],
95
+ encode_in: bool,
96
+ encode_out: bool,
97
+ audio_input_format: str,
98
+ audio_output_format: str,
99
+ rt_output_modalities: Optional[List[Literal["audio", "text"]]] = None,
100
+ ) -> Any:
101
+ """Get a free (or new) realtime session for this user. Marks it busy via an internal lock.
102
+
103
+ Returns the RealtimeService with an acquired _in_use_lock that MUST be released by caller.
104
+ """
105
+ from solana_agent.interfaces.providers.realtime import (
106
+ RealtimeSessionOptions,
107
+ )
108
+ from solana_agent.adapters.openai_realtime_ws import (
109
+ OpenAIRealtimeWebSocketSession,
110
+ )
111
+ from solana_agent.adapters.ffmpeg_transcoder import FFmpegTranscoder
112
+
113
+ def _mime_from(fmt: str) -> str:
114
+ f = (fmt or "").lower()
115
+ return {
116
+ "aac": "audio/aac",
117
+ "mp3": "audio/mpeg",
118
+ "mp4": "audio/mp4",
119
+ "m4a": "audio/mp4",
120
+ "mpeg": "audio/mpeg",
121
+ "mpga": "audio/mpeg",
122
+ "wav": "audio/wav",
123
+ "flac": "audio/flac",
124
+ "opus": "audio/opus",
125
+ "ogg": "audio/ogg",
126
+ "webm": "audio/webm",
127
+ "pcm": "audio/pcm",
128
+ }.get(f, "audio/pcm")
129
+
130
+ async with self._rt_lock:
131
+ pool = self._rt_services.get(user_id) or []
132
+ # Try to reuse an idle session strictly owned by this user
133
+ for rt in pool:
134
+ # Extra safety: never reuse a session from another user
135
+ owner = getattr(rt, "_owner_user_id", None)
136
+ if owner is not None and owner != user_id:
137
+ continue
138
+ lock = getattr(rt, "_in_use_lock", None)
139
+ if lock is None:
140
+ lock = asyncio.Lock()
141
+ setattr(rt, "_in_use_lock", lock)
142
+ if not lock.locked():
143
+ if await self._try_acquire_lock(lock):
144
+ return rt
145
+ # None free: create a new session
146
+ opts = RealtimeSessionOptions(
147
+ model="gpt-realtime",
148
+ voice=rt_voice,
149
+ vad_enabled=False,
150
+ input_rate_hz=24000,
151
+ output_rate_hz=24000,
152
+ input_mime="audio/pcm",
153
+ output_mime="audio/pcm",
154
+ output_modalities=rt_output_modalities,
155
+ tools=initial_tools or None,
156
+ tool_choice="auto",
157
+ )
158
+ try:
159
+ opts.instructions = final_instructions
160
+ opts.voice = rt_voice
161
+ except Exception:
162
+ pass
163
+ conv_session = OpenAIRealtimeWebSocketSession(api_key=api_key, options=opts)
164
+ transcoder = FFmpegTranscoder() if (encode_in or encode_out) else None
165
+ from solana_agent.services.realtime import RealtimeService
166
+
167
+ rt = RealtimeService(
168
+ session=conv_session,
169
+ options=opts,
170
+ transcoder=transcoder,
171
+ accept_compressed_input=encode_in,
172
+ client_input_mime=_mime_from(audio_input_format),
173
+ encode_output=encode_out,
174
+ client_output_mime=_mime_from(audio_output_format),
175
+ )
176
+ # Tag ownership to prevent any cross-user reuse
177
+ setattr(rt, "_owner_user_id", user_id)
178
+ setattr(rt, "_in_use_lock", asyncio.Lock())
179
+ # Mark busy
180
+ await getattr(rt, "_in_use_lock").acquire()
181
+ pool.append(rt)
182
+ self._rt_services[user_id] = pool
183
+ return rt
184
+
185
+ def _get_sticky_agent(self, user_id: str) -> Optional[str]:
186
+ sess = self._sticky_sessions.get(user_id)
187
+ return sess.get("agent") if isinstance(sess, dict) else None
188
+
189
+ def _set_sticky_agent(
190
+ self, user_id: str, agent_name: str, required_complete: bool = False
191
+ ) -> None:
192
+ self._sticky_sessions[user_id] = {
193
+ "agent": agent_name,
194
+ "started_at": time.time(),
195
+ "last_updated": time.time(),
196
+ "required_complete": required_complete,
197
+ }
198
+
199
+ def _update_sticky_required_complete(
200
+ self, user_id: str, required_complete: bool
201
+ ) -> None:
202
+ if user_id in self._sticky_sessions:
203
+ self._sticky_sessions[user_id]["required_complete"] = required_complete
204
+ self._sticky_sessions[user_id]["last_updated"] = time.time()
205
+
206
+ def _clear_sticky_agent(self, user_id: str) -> None:
207
+ if user_id in self._sticky_sessions:
208
+ try:
209
+ del self._sticky_sessions[user_id]
210
+ except Exception:
211
+ pass
212
+
213
+ async def _build_combined_context(
214
+ self,
215
+ user_id: str,
216
+ user_text: str,
217
+ agent_name: str,
218
+ capture_name: Optional[str] = None,
219
+ capture_schema: Optional[Dict[str, Any]] = None,
220
+ prev_assistant: str = "",
221
+ ) -> Tuple[str, bool]:
222
+ """Build combined context string and return required_complete flag."""
223
+ # Memory context
224
+ memory_context = ""
225
+ if self.memory_provider:
226
+ try:
227
+ memory_context = await self.memory_provider.retrieve(user_id)
228
+ except Exception:
229
+ memory_context = ""
230
+
231
+ # KB context
232
+ kb_context = ""
233
+ if self.knowledge_base:
234
+ try:
235
+ kb_results = await self.knowledge_base.query(
236
+ query_text=user_text,
237
+ top_k=self.kb_results_count,
238
+ include_content=True,
239
+ include_metadata=False,
240
+ )
241
+ if kb_results:
242
+ kb_lines = [
243
+ "**KNOWLEDGE BASE (CRITICAL: MAKE THIS INFORMATION THE TOP PRIORITY):**"
244
+ ]
245
+ for i, r in enumerate(kb_results, 1):
246
+ kb_lines.append(f"[{i}] {r.get('content', '').strip()}\n")
247
+ kb_context = "\n".join(kb_lines)
248
+ except Exception:
249
+ kb_context = ""
250
+
251
+ # Capture context
252
+ capture_context = ""
253
+ required_complete = False
254
+
255
+ active_capture_name = capture_name
256
+ active_capture_schema = capture_schema
257
+ if not active_capture_name or not active_capture_schema:
258
+ try:
259
+ cap_cfg = self.agent_service.get_agent_capture(agent_name)
260
+ if cap_cfg:
261
+ active_capture_name = active_capture_name or cap_cfg.get("name")
262
+ active_capture_schema = active_capture_schema or cap_cfg.get(
263
+ "schema"
264
+ )
265
+ except Exception:
266
+ pass
267
+
268
+ if active_capture_name and isinstance(active_capture_schema, dict):
269
+ latest_by_name: Dict[str, Dict[str, Any]] = {}
270
+ if self.memory_provider:
271
+ try:
272
+ docs = self.memory_provider.find(
273
+ collection="captures",
274
+ query={"user_id": user_id},
275
+ sort=[("timestamp", -1)],
276
+ limit=100,
277
+ )
278
+ for d in docs or []:
279
+ name = (d or {}).get("capture_name")
280
+ if not name or name in latest_by_name:
281
+ continue
282
+ latest_by_name[name] = {
283
+ "data": (d or {}).get("data", {}) or {},
284
+ "mode": (d or {}).get("mode", "once"),
285
+ "agent": (d or {}).get("agent_name"),
286
+ }
287
+ except Exception:
288
+ pass
289
+
290
+ active_data = (latest_by_name.get(active_capture_name, {}) or {}).get(
291
+ "data", {}
292
+ ) or {}
293
+
294
+ def _non_empty(v: Any) -> bool:
295
+ if v is None:
296
+ return False
297
+ if isinstance(v, str):
298
+ s = v.strip().lower()
299
+ return s not in {"", "null", "none", "n/a", "na", "undefined", "."}
300
+ if isinstance(v, (list, dict, tuple, set)):
301
+ return len(v) > 0
302
+ return True
303
+
304
+ props = (active_capture_schema or {}).get("properties", {})
305
+ required_fields = list(
306
+ (active_capture_schema or {}).get("required", []) or []
307
+ )
308
+ all_fields = list(props.keys())
309
+ optional_fields = [f for f in all_fields if f not in set(required_fields)]
310
+
311
+ def _missing_from(data: Dict[str, Any], fields: List[str]) -> List[str]:
312
+ return [f for f in fields if not _non_empty(data.get(f))]
313
+
314
+ missing_required = _missing_from(active_data, required_fields)
315
+ missing_optional = _missing_from(active_data, optional_fields)
316
+
317
+ required_complete = len(missing_required) == 0 and len(required_fields) > 0
318
+
319
+ lines: List[str] = []
320
+ lines.append(
321
+ "CAPTURED FORM STATE (Authoritative; do not re-ask filled values):"
322
+ )
323
+ lines.append(f"- form_name: {active_capture_name}")
324
+
325
+ if active_data:
326
+ pairs = [f"{k}: {v}" for k, v in active_data.items() if _non_empty(v)]
327
+ lines.append(
328
+ f"- filled_fields: {', '.join(pairs) if pairs else '(none)'}"
329
+ )
330
+ else:
331
+ lines.append("- filled_fields: (none)")
332
+
333
+ lines.append(
334
+ f"- missing_required_fields: {', '.join(missing_required) if missing_required else '(none)'}"
335
+ )
336
+ lines.append(
337
+ f"- missing_optional_fields: {', '.join(missing_optional) if missing_optional else '(none)'}"
338
+ )
339
+ lines.append("")
340
+
341
+ if latest_by_name:
342
+ lines.append("OTHER CAPTURED USER DATA (for reference):")
343
+ for cname, info in latest_by_name.items():
344
+ if cname == active_capture_name:
345
+ continue
346
+ data = info.get("data", {}) or {}
347
+ if data:
348
+ pairs = [f"{k}: {v}" for k, v in data.items() if _non_empty(v)]
349
+ lines.append(
350
+ f"- {cname}: {', '.join(pairs) if pairs else '(none)'}"
351
+ )
352
+ else:
353
+ lines.append(f"- {cname}: (none)")
354
+
355
+ if lines:
356
+ capture_context = "\n".join(lines) + "\n\n"
357
+
358
+ # Merge contexts
359
+ combined_context = ""
360
+ if capture_context:
361
+ combined_context += capture_context
362
+ if memory_context:
363
+ combined_context += f"CONVERSATION HISTORY (Use for continuity; not authoritative for facts):\n{memory_context}\n\n"
364
+ if kb_context:
365
+ combined_context += kb_context + "\n"
366
+
367
+ guide = (
368
+ "PRIORITIZATION GUIDE:\n"
369
+ "- Prefer Captured User Data for user-specific fields.\n"
370
+ "- Prefer KB/tools for facts.\n"
371
+ "- History is for tone and continuity.\n\n"
372
+ "FORM FLOW RULES:\n"
373
+ "- Ask exactly one field per turn.\n"
374
+ "- If any required fields are missing, ask the next missing required field.\n"
375
+ "- If all required fields are filled but optional fields are missing, ask the next missing optional field.\n"
376
+ "- Do NOT re-ask or verify values present in Captured User Data (auto-saved, authoritative).\n"
377
+ "- Do NOT provide summaries until no required or optional fields are missing.\n\n"
378
+ )
379
+
380
+ if combined_context:
381
+ combined_context += guide
382
+ else:
383
+ # Diagnostics for why the context is empty
384
+ try:
385
+ logger.debug(
386
+ "_build_combined_context: empty sources — memory_provider=%s, knowledge_base=%s, active_capture=%s",
387
+ bool(self.memory_provider),
388
+ bool(self.knowledge_base),
389
+ bool(
390
+ active_capture_name and isinstance(active_capture_schema, dict)
391
+ ),
392
+ )
393
+ except Exception:
394
+ pass
395
+ # Provide minimal guide so realtime instructions are not blank
396
+ combined_context = guide
397
+
398
+ return combined_context, required_complete
399
+
400
+ # LLM-backed switch intent detection (gpt-4.1-mini)
401
+ class _SwitchIntentModel(BaseModel):
402
+ switch: bool = False
403
+ target_agent: Optional[str] = None
404
+ start_new: bool = False
405
+
406
+ async def _detect_switch_intent(
407
+ self, text: str, available_agents: List[str]
408
+ ) -> Tuple[bool, Optional[str], bool]:
409
+ """Detect if the user is asking to switch agents or start a new conversation.
410
+
411
+ Returns: (switch_requested, target_agent_name_or_none, start_new_conversation)
412
+ Implemented as an LLM call to gpt-4.1-mini with structured output.
413
+ """
414
+ if not text:
415
+ return (False, None, False)
416
+
417
+ # Instruction and user prompt for the classifier
418
+ instruction = (
419
+ "You are a strict intent classifier for agent routing. "
420
+ "Decide if the user's message requests switching to another agent or starting a new conversation. "
421
+ "Only return JSON with keys: switch (bool), target_agent (string|null), start_new (bool). "
422
+ "If a target agent is mentioned, it MUST be one of the provided agent names (case-insensitive). "
423
+ "If none clearly applies, set switch=false and start_new=false and target_agent=null."
424
+ )
425
+ user_prompt = (
426
+ f"Available agents (choose only from these if a target is specified): {available_agents}\n\n"
427
+ f"User message:\n{text}\n\n"
428
+ 'Return JSON only, like: {"switch": true|false, "target_agent": "<one_of_available_or_null>", "start_new": true|false}'
429
+ )
430
+
431
+ # Primary: use llm_provider.parse_structured_output
432
+ try:
433
+ if hasattr(self.agent_service.llm_provider, "parse_structured_output"):
434
+ try:
435
+ result = (
436
+ await self.agent_service.llm_provider.parse_structured_output(
437
+ prompt=user_prompt,
438
+ system_prompt=instruction,
439
+ model_class=QueryService._SwitchIntentModel,
440
+ model="gpt-4.1-mini",
441
+ )
442
+ )
443
+ except TypeError:
444
+ # Provider may not accept 'model' kwarg
445
+ result = (
446
+ await self.agent_service.llm_provider.parse_structured_output(
447
+ prompt=user_prompt,
448
+ system_prompt=instruction,
449
+ model_class=QueryService._SwitchIntentModel,
450
+ )
451
+ )
452
+ switch = bool(getattr(result, "switch", False))
453
+ target = getattr(result, "target_agent", None)
454
+ start_new = bool(getattr(result, "start_new", False))
455
+ # Normalize target to available agent name
456
+ if target:
457
+ target_lower = target.lower()
458
+ norm = None
459
+ for a in available_agents:
460
+ if a.lower() == target_lower or target_lower in a.lower():
461
+ norm = a
462
+ break
463
+ target = norm
464
+ if not switch:
465
+ target = None
466
+ return (switch, target, start_new)
467
+ except Exception as e:
468
+ logger.debug(f"LLM switch intent parse_structured_output failed: {e}")
469
+
470
+ # Fallback: generate_response with output_model
471
+ try:
472
+ async for r in self.agent_service.generate_response(
473
+ agent_name="default",
474
+ user_id="router",
475
+ query="",
476
+ images=None,
477
+ memory_context="",
478
+ output_format="text",
479
+ prompt=f"{instruction}\n\n{user_prompt}",
480
+ output_model=QueryService._SwitchIntentModel,
481
+ ):
482
+ result = r
483
+ switch = False
484
+ target = None
485
+ start_new = False
486
+ try:
487
+ switch = bool(result.switch) # type: ignore[attr-defined]
488
+ target = result.target_agent # type: ignore[attr-defined]
489
+ start_new = bool(result.start_new) # type: ignore[attr-defined]
490
+ except Exception:
491
+ try:
492
+ d = result.model_dump()
493
+ switch = bool(d.get("switch", False))
494
+ target = d.get("target_agent")
495
+ start_new = bool(d.get("start_new", False))
496
+ except Exception:
497
+ pass
498
+ if target:
499
+ target_lower = str(target).lower()
500
+ norm = None
501
+ for a in available_agents:
502
+ if a.lower() == target_lower or target_lower in a.lower():
503
+ norm = a
504
+ break
505
+ target = norm
506
+ if not switch:
507
+ target = None
508
+ return (switch, target, start_new)
509
+ except Exception as e:
510
+ logger.debug(f"LLM switch intent generate_response failed: {e}")
511
+
512
+ # Last resort: no switch
513
+ return (False, None, False)
36
514
 
37
515
  async def process(
38
516
  self,
39
517
  user_id: str,
40
518
  query: Union[str, bytes],
519
+ images: Optional[List[Union[str, bytes]]] = None,
41
520
  output_format: Literal["text", "audio"] = "text",
42
- audio_voice: Literal["alloy", "ash", "ballad", "coral", "echo",
43
- "fable", "onyx", "nova", "sage", "shimmer"] = "nova",
44
- audio_instructions: Optional[str] = None,
45
- audio_output_format: Literal['mp3', 'opus',
46
- 'aac', 'flac', 'wav', 'pcm'] = "aac",
521
+ realtime: bool = False,
522
+ # Realtime minimal controls (voice/format come from audio_* args)
523
+ vad: Optional[bool] = None,
524
+ rt_encode_input: bool = False,
525
+ rt_encode_output: bool = False,
526
+ rt_output_modalities: Optional[List[Literal["audio", "text"]]] = None,
527
+ rt_voice: Literal[
528
+ "alloy",
529
+ "ash",
530
+ "ballad",
531
+ "cedar",
532
+ "coral",
533
+ "echo",
534
+ "marin",
535
+ "sage",
536
+ "shimmer",
537
+ "verse",
538
+ ] = "marin",
539
+ # Realtime transcription configuration (new)
540
+ rt_transcription_model: Optional[str] = None,
541
+ rt_transcription_language: Optional[str] = None,
542
+ rt_transcription_prompt: Optional[str] = None,
543
+ rt_transcription_noise_reduction: Optional[bool] = None,
544
+ rt_transcription_include_logprobs: bool = False,
545
+ audio_voice: Literal[
546
+ "alloy",
547
+ "ash",
548
+ "ballad",
549
+ "coral",
550
+ "echo",
551
+ "fable",
552
+ "onyx",
553
+ "nova",
554
+ "sage",
555
+ "shimmer",
556
+ ] = "nova",
557
+ audio_output_format: Literal[
558
+ "mp3", "opus", "aac", "flac", "wav", "pcm"
559
+ ] = "aac",
47
560
  audio_input_format: Literal[
48
561
  "flac", "mp3", "mp4", "mpeg", "mpga", "m4a", "ogg", "wav", "webm"
49
562
  ] = "mp4",
50
563
  prompt: Optional[str] = None,
51
564
  router: Optional[RoutingServiceInterface] = None,
52
- ) -> AsyncGenerator[Union[str, bytes], None]: # pragma: no cover
53
- """Process the user request with appropriate agent.
54
-
55
- Args:
56
- user_id: User ID
57
- query: Text query or audio bytes
58
- output_format: Response format ("text" or "audio")
59
- audio_voice: Voice for TTS (text-to-speech)
60
- audio_instructions: Optional instructions for TTS
61
- audio_output_format: Audio output format
62
- audio_input_format: Audio input format
63
- prompt: Optional prompt for the agent
64
- router: Optional routing service for processing
65
-
66
- Yields:
67
- Response chunks (text strings or audio bytes)
68
- """
565
+ output_model: Optional[Type[BaseModel]] = None,
566
+ capture_schema: Optional[Dict[str, Any]] = None,
567
+ capture_name: Optional[str] = None,
568
+ ) -> AsyncGenerator[Union[str, bytes, BaseModel], None]: # pragma: no cover
569
+ """Process the user request and generate a response."""
69
570
  try:
70
- # Handle audio input if provided
571
+ # Realtime request: HTTP STT for user + single WS for assistant audio
572
+ if realtime:
573
+ # 1) Determine if input is audio bytes. We now ALWAYS skip HTTP STT in realtime mode.
574
+ # The realtime websocket session (optionally with built-in transcription) is authoritative.
575
+ is_audio_bytes = isinstance(query, (bytes, bytearray))
576
+ user_text = "" if is_audio_bytes else str(query)
577
+ # Provide a sensible default realtime transcription model when audio supplied
578
+ if is_audio_bytes and not rt_transcription_model:
579
+ rt_transcription_model = "gpt-4o-mini-transcribe"
580
+
581
+ # 2) Single agent selection (no multi-agent routing in realtime path)
582
+ agent_name = self._get_sticky_agent(user_id)
583
+ if not agent_name:
584
+ try:
585
+ agents = self.agent_service.get_all_ai_agents() or {}
586
+ agent_name = next(iter(agents.keys())) if agents else "default"
587
+ except Exception:
588
+ agent_name = "default"
589
+ prev_assistant = ""
590
+ if self.memory_provider:
591
+ try:
592
+ prev_docs = self.memory_provider.find(
593
+ collection="conversations",
594
+ query={"user_id": user_id},
595
+ sort=[("timestamp", -1)],
596
+ limit=1,
597
+ )
598
+ if prev_docs:
599
+ prev_assistant = (prev_docs[0] or {}).get(
600
+ "assistant_message", ""
601
+ ) or ""
602
+ except Exception:
603
+ pass
604
+
605
+ # 3) Build context + tools
606
+ combined_ctx = ""
607
+ required_complete = False
608
+ try:
609
+ (
610
+ combined_ctx,
611
+ required_complete,
612
+ ) = await self._build_combined_context(
613
+ user_id=user_id,
614
+ user_text=(user_text if not is_audio_bytes else ""),
615
+ agent_name=agent_name,
616
+ capture_name=capture_name,
617
+ capture_schema=capture_schema,
618
+ prev_assistant=prev_assistant,
619
+ )
620
+ try:
621
+ self._update_sticky_required_complete(
622
+ user_id, required_complete
623
+ )
624
+ except Exception:
625
+ pass
626
+ except Exception:
627
+ combined_ctx = ""
628
+ try:
629
+ # GA Realtime expects flattened tool definitions (no nested "function" object)
630
+ initial_tools = [
631
+ {
632
+ "type": "function",
633
+ "name": t["name"],
634
+ "description": t.get("description", ""),
635
+ "parameters": t.get("parameters", {}),
636
+ "strict": True,
637
+ }
638
+ for t in self.agent_service.get_agent_tools(agent_name)
639
+ ]
640
+ except Exception:
641
+ initial_tools = []
642
+
643
+ # Build realtime instructions: include full agent system prompt, context, and optional prompt (no user_text)
644
+ system_prompt = ""
645
+ try:
646
+ system_prompt = self.agent_service.get_agent_system_prompt(
647
+ agent_name
648
+ )
649
+ except Exception:
650
+ system_prompt = ""
651
+
652
+ parts: List[str] = []
653
+ if system_prompt:
654
+ parts.append(system_prompt)
655
+ if combined_ctx:
656
+ parts.append(combined_ctx)
657
+ if prompt:
658
+ parts.append(str(prompt))
659
+ final_instructions = "\n\n".join([p for p in parts if p])
660
+
661
+ # 4) Open a single WS session for assistant audio
662
+ # Realtime imports handled inside allocator helper
663
+
664
+ api_key = None
665
+ try:
666
+ api_key = self.agent_service.llm_provider.get_api_key() # type: ignore[attr-defined]
667
+ except Exception:
668
+ pass
669
+ if not api_key:
670
+ raise ValueError("OpenAI API key is required for realtime")
671
+
672
+ # Per-user persistent WS (single session)
673
+ def _mime_from(fmt: str) -> str:
674
+ f = (fmt or "").lower()
675
+ return {
676
+ "aac": "audio/aac",
677
+ "mp3": "audio/mpeg",
678
+ "mp4": "audio/mp4",
679
+ "m4a": "audio/mp4",
680
+ "mpeg": "audio/mpeg",
681
+ "mpga": "audio/mpeg",
682
+ "wav": "audio/wav",
683
+ "flac": "audio/flac",
684
+ "opus": "audio/opus",
685
+ "ogg": "audio/ogg",
686
+ "webm": "audio/webm",
687
+ "pcm": "audio/pcm",
688
+ }.get(f, "audio/pcm")
689
+
690
+ # Choose output encoding automatically when non-PCM output is requested
691
+ encode_out = bool(
692
+ rt_encode_output or (audio_output_format.lower() != "pcm")
693
+ )
694
+ # If caller explicitly requests text-only realtime, disable output encoding entirely
695
+ if (
696
+ rt_output_modalities is not None
697
+ and "audio" not in rt_output_modalities
698
+ ):
699
+ if encode_out:
700
+ logger.debug(
701
+ "Realtime(QueryService): forcing encode_out False for text-only modalities=%s",
702
+ rt_output_modalities,
703
+ )
704
+ encode_out = False
705
+ # Choose input transcoding when compressed input is provided (or explicitly requested)
706
+ is_audio_bytes = isinstance(query, (bytes, bytearray))
707
+ encode_in = bool(
708
+ rt_encode_input
709
+ or (is_audio_bytes and audio_input_format.lower() != "pcm")
710
+ )
711
+
712
+ # Allocate or reuse a realtime session for this specific request/user.
713
+ # (Transcription options may be applied below; if they change after allocate we will reconfigure.)
714
+ rt = await self._alloc_realtime_session(
715
+ user_id,
716
+ api_key=api_key,
717
+ rt_voice=rt_voice,
718
+ final_instructions=final_instructions,
719
+ initial_tools=initial_tools,
720
+ encode_in=encode_in,
721
+ encode_out=encode_out,
722
+ audio_input_format=audio_input_format,
723
+ audio_output_format=audio_output_format,
724
+ rt_output_modalities=rt_output_modalities,
725
+ )
726
+ # Ensure lock is released no matter what
727
+ try:
728
+ # --- Apply realtime transcription config BEFORE connecting (new) ---
729
+ if rt_transcription_model and hasattr(rt, "_options"):
730
+ try:
731
+ setattr(
732
+ rt._options,
733
+ "transcription_model",
734
+ rt_transcription_model,
735
+ )
736
+ if rt_transcription_language is not None:
737
+ setattr(
738
+ rt._options,
739
+ "transcription_language",
740
+ rt_transcription_language,
741
+ )
742
+ if rt_transcription_prompt is not None:
743
+ setattr(
744
+ rt._options,
745
+ "transcription_prompt",
746
+ rt_transcription_prompt,
747
+ )
748
+ if rt_transcription_noise_reduction is not None:
749
+ setattr(
750
+ rt._options,
751
+ "transcription_noise_reduction",
752
+ rt_transcription_noise_reduction,
753
+ )
754
+ if rt_transcription_include_logprobs:
755
+ setattr(
756
+ rt._options, "transcription_include_logprobs", True
757
+ )
758
+ except Exception:
759
+ logger.debug(
760
+ "Failed pre-connect transcription option assignment",
761
+ exc_info=True,
762
+ )
763
+
764
+ # Tool executor
765
+ async def _exec(
766
+ tool_name: str, args: Dict[str, Any]
767
+ ) -> Dict[str, Any]:
768
+ try:
769
+ return await self.agent_service.execute_tool(
770
+ agent_name, tool_name, args or {}
771
+ )
772
+ except Exception as e:
773
+ return {"status": "error", "message": str(e)}
774
+
775
+ # If possible, set on underlying session
776
+ try:
777
+ if hasattr(rt, "_session"):
778
+ getattr(rt, "_session").set_tool_executor(_exec) # type: ignore[attr-defined]
779
+ except Exception:
780
+ pass
781
+
782
+ # Connect/configure
783
+ if not getattr(rt, "_connected", False):
784
+ await rt.start()
785
+ await rt.configure(
786
+ voice=rt_voice,
787
+ vad_enabled=bool(vad) if vad is not None else False,
788
+ instructions=final_instructions,
789
+ tools=initial_tools or None,
790
+ tool_choice="auto",
791
+ )
792
+
793
+ # Ensure clean input buffers for this turn
794
+ try:
795
+ await rt.clear_input()
796
+ except Exception:
797
+ pass
798
+ # Also reset any leftover output audio so new turn doesn't replay old chunks
799
+ try:
800
+ if hasattr(rt, "reset_output_stream"):
801
+ rt.reset_output_stream()
802
+ except Exception:
803
+ pass
804
+
805
+ # Begin streaming turn (defer user transcript persistence until final to avoid duplicates)
806
+ turn_id = await self.realtime_begin_turn(user_id)
807
+ # We'll buffer the full user transcript (text input or realtime audio transcription) and persist exactly once.
808
+ # Initialize empty; we'll build it strictly from realtime transcript segments to avoid
809
+ # accidental duplication with pre-supplied user_text or prior buffers.
810
+ final_user_tr: str = ""
811
+ user_persisted = False
812
+
813
+ # Feed audio into WS if audio bytes provided and audio modality requested; else treat as text
814
+ wants_audio = (
815
+ (
816
+ getattr(rt, "_options", None)
817
+ and getattr(rt, "_options").output_modalities
818
+ )
819
+ and "audio" in getattr(rt, "_options").output_modalities # type: ignore[attr-defined]
820
+ ) or (
821
+ rt_output_modalities is None
822
+ or (rt_output_modalities and "audio" in rt_output_modalities)
823
+ )
824
+ # Determine if realtime transcription should be enabled (always skip HTTP STT regardless)
825
+ # realtime_transcription_enabled now implicit (options set before connect)
826
+
827
+ if is_audio_bytes and not wants_audio:
828
+ # Feed audio solely for transcription (no audio output requested)
829
+ bq = bytes(query)
830
+ logger.info(
831
+ "Realtime: appending input audio for transcription only, len=%d, fmt=%s",
832
+ len(bq),
833
+ audio_input_format,
834
+ )
835
+ await rt.append_audio(bq)
836
+ vad_enabled_value = bool(vad) if vad is not None else False
837
+ if not vad_enabled_value:
838
+ await rt.commit_input()
839
+ # Request only text response
840
+ await rt.create_response({"modalities": ["text"]})
841
+ else:
842
+ logger.debug(
843
+ "Realtime: VAD enabled (text-only output) — skipping manual response.create"
844
+ )
845
+ if is_audio_bytes and wants_audio:
846
+ bq = bytes(query)
847
+ logger.info(
848
+ "Realtime: appending input audio to WS via FFmpeg, len=%d, fmt=%s",
849
+ len(bq),
850
+ audio_input_format,
851
+ )
852
+ await rt.append_audio(bq)
853
+ vad_enabled_value = bool(vad) if vad is not None else False
854
+ if not vad_enabled_value:
855
+ await rt.commit_input()
856
+ # Manually trigger response when VAD is disabled
857
+ await rt.create_response({})
858
+ else:
859
+ # With server VAD enabled, the model will auto-create a response at end of speech
860
+ logger.debug(
861
+ "Realtime: VAD enabled — skipping manual response.create"
862
+ )
863
+ else: # Text-only path OR caller excluded audio modality
864
+ # For text input, create conversation item first, then response
865
+ await rt.create_conversation_item(
866
+ {
867
+ "type": "message",
868
+ "role": "user",
869
+ "content": [
870
+ {"type": "input_text", "text": user_text or ""}
871
+ ],
872
+ }
873
+ )
874
+ # Determine effective modalities (fall back to provided override or text only)
875
+ if rt_output_modalities is not None:
876
+ modalities = rt_output_modalities or ["text"]
877
+ else:
878
+ mo = getattr(
879
+ rt, "_options", RealtimeSessionOptions()
880
+ ).output_modalities
881
+ modalities = mo if mo else ["audio"]
882
+ if "audio" not in modalities:
883
+ # Ensure we do not accidentally request audio generation
884
+ modalities = [m for m in modalities if m == "text"] or [
885
+ "text"
886
+ ]
887
+ await rt.create_response(
888
+ {
889
+ "modalities": modalities,
890
+ }
891
+ )
892
+
893
+ # Collect audio and transcripts
894
+ user_tr = "" # Accumulates realtime input transcript segments (audio path)
895
+ asst_tr = ""
896
+
897
+ input_segments: List[str] = []
898
+
899
+ async def _drain_in_tr():
900
+ """Accumulate realtime input transcript segments, de-duplicating cumulative repeats.
901
+
902
+ Some realtime providers emit growing cumulative transcripts (e.g. "Hel", "Hello") or
903
+ may occasionally resend the full final transcript. Previous logic naively concatenated
904
+ every segment which could yield duplicated text ("HelloHello") if cumulative or repeated
905
+ finals were received. This routine keeps a canonical buffer (user_tr) and only appends
906
+ the non-overlapping suffix of each new segment.
907
+ """
908
+ nonlocal user_tr
909
+ async for t in rt.iter_input_transcript():
910
+ if not t:
911
+ continue
912
+ # Track raw segment for optional debugging
913
+ input_segments.append(t)
914
+ if not user_tr:
915
+ user_tr = t
916
+ continue
917
+ if t == user_tr:
918
+ # Exact duplicate of current buffer; skip
919
+ continue
920
+ if t.startswith(user_tr):
921
+ # Cumulative growth; append only the new suffix
922
+ user_tr += t[len(user_tr) :]
923
+ continue
924
+ # General case: find largest overlap between end of user_tr and start of t
925
+ # to avoid duplicated middle content (e.g., user_tr="My name is", t="name is John")
926
+ overlap = 0
927
+ max_check = min(len(user_tr), len(t))
928
+ for k in range(max_check, 0, -1):
929
+ if user_tr.endswith(t[:k]):
930
+ overlap = k
931
+ break
932
+ user_tr += t[overlap:]
933
+
934
+ # Check if we need both audio and text modalities
935
+ modalities = getattr(
936
+ rt, "_options", RealtimeSessionOptions()
937
+ ).output_modalities or ["audio"]
938
+ use_combined_stream = "audio" in modalities and "text" in modalities
939
+
940
+ if use_combined_stream and wants_audio:
941
+ # Use combined stream for both modalities
942
+ async def _drain_out_tr():
943
+ nonlocal asst_tr
944
+ async for t in rt.iter_output_transcript():
945
+ if t:
946
+ asst_tr += t
947
+
948
+ in_task = asyncio.create_task(_drain_in_tr())
949
+ out_task = asyncio.create_task(_drain_out_tr())
950
+ try:
951
+ # Check if the service has iter_output_combined method
952
+ if hasattr(rt, "iter_output_combined"):
953
+ async for chunk in rt.iter_output_combined():
954
+ # Adapt output based on caller's requested output_format
955
+ if output_format == "text":
956
+ # Only yield text modalities as plain strings
957
+ if getattr(chunk, "modality", None) == "text":
958
+ yield chunk.data # type: ignore[attr-defined]
959
+ continue
960
+ # Audio streaming path
961
+ if getattr(chunk, "modality", None) == "audio":
962
+ # Yield raw bytes if data present
963
+ yield getattr(chunk, "data", b"")
964
+ elif (
965
+ getattr(chunk, "modality", None) == "text"
966
+ and output_format == "audio"
967
+ ):
968
+ # Optionally ignore or log text while audio requested
969
+ continue
970
+ else:
971
+ # Fallback: ignore unknown modalities for now
972
+ continue
973
+ else:
974
+ # Fallback: yield audio chunks as RealtimeChunk objects
975
+ async for audio_chunk in rt.iter_output_audio_encoded():
976
+ if output_format == "text":
977
+ # Ignore audio when text requested
978
+ continue
979
+ # output_format audio: provide raw bytes
980
+ if hasattr(audio_chunk, "modality"):
981
+ if (
982
+ getattr(audio_chunk, "modality", None)
983
+ == "audio"
984
+ ):
985
+ yield getattr(audio_chunk, "data", b"")
986
+ else:
987
+ yield audio_chunk
988
+ finally:
989
+ # Allow transcript drain tasks to finish to capture user/asst text before persistence
990
+ try:
991
+ await asyncio.wait_for(in_task, timeout=0.05)
992
+ except Exception:
993
+ in_task.cancel()
994
+ try:
995
+ await asyncio.wait_for(out_task, timeout=0.05)
996
+ except Exception:
997
+ out_task.cancel()
998
+ # HTTP STT path removed: realtime audio input transcript (if any) is authoritative
999
+ # Persist transcripts after combined streaming completes
1000
+ if turn_id:
1001
+ try:
1002
+ effective_user_tr = user_tr or ("".join(input_segments))
1003
+ try:
1004
+ setattr(
1005
+ self,
1006
+ "_last_realtime_user_transcript",
1007
+ effective_user_tr,
1008
+ )
1009
+ except Exception:
1010
+ pass
1011
+ if effective_user_tr:
1012
+ final_user_tr = effective_user_tr
1013
+ elif (
1014
+ isinstance(query, str)
1015
+ and query
1016
+ and not input_segments
1017
+ and not user_tr
1018
+ ):
1019
+ final_user_tr = query
1020
+ if asst_tr:
1021
+ await self.realtime_update_assistant(
1022
+ user_id, turn_id, asst_tr
1023
+ )
1024
+ except Exception:
1025
+ pass
1026
+ if final_user_tr and not user_persisted:
1027
+ try:
1028
+ await self.realtime_update_user(
1029
+ user_id, turn_id, final_user_tr
1030
+ )
1031
+ user_persisted = True
1032
+ except Exception:
1033
+ pass
1034
+ try:
1035
+ await self.realtime_finalize_turn(user_id, turn_id)
1036
+ except Exception:
1037
+ pass
1038
+ if final_user_tr and not user_persisted:
1039
+ try:
1040
+ await self.realtime_update_user(
1041
+ user_id, turn_id, final_user_tr
1042
+ )
1043
+ user_persisted = True
1044
+ except Exception:
1045
+ pass
1046
+ elif wants_audio:
1047
+ # Use separate streams (legacy behavior)
1048
+ async def _drain_out_tr():
1049
+ nonlocal asst_tr
1050
+ async for t in rt.iter_output_transcript():
1051
+ if t:
1052
+ asst_tr += t
1053
+
1054
+ in_task = asyncio.create_task(_drain_in_tr())
1055
+ out_task = asyncio.create_task(_drain_out_tr())
1056
+ try:
1057
+ async for audio_chunk in rt.iter_output_audio_encoded():
1058
+ if output_format == "text":
1059
+ # Skip audio when caller wants text only
1060
+ continue
1061
+ # output_format audio: yield raw bytes
1062
+ if hasattr(audio_chunk, "modality"):
1063
+ if (
1064
+ getattr(audio_chunk, "modality", None)
1065
+ == "audio"
1066
+ ):
1067
+ yield getattr(audio_chunk, "data", b"")
1068
+ else:
1069
+ yield audio_chunk
1070
+ finally:
1071
+ try:
1072
+ await asyncio.wait_for(in_task, timeout=0.05)
1073
+ except Exception:
1074
+ in_task.cancel()
1075
+ try:
1076
+ await asyncio.wait_for(out_task, timeout=0.05)
1077
+ except Exception:
1078
+ out_task.cancel()
1079
+ # HTTP STT path removed
1080
+ # Persist transcripts after audio-only streaming
1081
+ if turn_id:
1082
+ try:
1083
+ effective_user_tr = user_tr or ("".join(input_segments))
1084
+ try:
1085
+ setattr(
1086
+ self,
1087
+ "_last_realtime_user_transcript",
1088
+ effective_user_tr,
1089
+ )
1090
+ except Exception:
1091
+ pass
1092
+ # Buffer final transcript for single persistence
1093
+ if effective_user_tr:
1094
+ final_user_tr = effective_user_tr
1095
+ elif (
1096
+ isinstance(query, str)
1097
+ and query
1098
+ and not input_segments
1099
+ and not user_tr
1100
+ ):
1101
+ final_user_tr = query
1102
+ if asst_tr:
1103
+ await self.realtime_update_assistant(
1104
+ user_id, turn_id, asst_tr
1105
+ )
1106
+ except Exception:
1107
+ pass
1108
+ if final_user_tr and not user_persisted:
1109
+ try:
1110
+ await self.realtime_update_user(
1111
+ user_id, turn_id, final_user_tr
1112
+ )
1113
+ user_persisted = True
1114
+ except Exception:
1115
+ pass
1116
+ try:
1117
+ await self.realtime_finalize_turn(user_id, turn_id)
1118
+ except Exception:
1119
+ pass
1120
+ # If no WS input transcript was captured, fall back to HTTP STT result
1121
+ else:
1122
+ # Text-only: just stream assistant transcript if available (no audio iteration)
1123
+ # If original input was audio bytes but caller only wants text output (no audio modality),
1124
+ # we still need to drain the input transcript stream to build user_tr.
1125
+ in_task_audio_only = None
1126
+ if is_audio_bytes:
1127
+ in_task_audio_only = asyncio.create_task(_drain_in_tr())
1128
+
1129
+ async def _drain_out_tr_text():
1130
+ nonlocal asst_tr
1131
+ async for t in rt.iter_output_transcript():
1132
+ if t:
1133
+ asst_tr += t
1134
+ yield t # Yield incremental text chunks directly
1135
+
1136
+ async for t in _drain_out_tr_text():
1137
+ # Provide plain text to caller
1138
+ yield t
1139
+ # Wait for input transcript (if any) before persistence
1140
+ if "in_task_audio_only" in locals() and in_task_audio_only:
1141
+ try:
1142
+ await asyncio.wait_for(in_task_audio_only, timeout=0.1)
1143
+ except Exception:
1144
+ in_task_audio_only.cancel()
1145
+ # No HTTP STT fallback
1146
+ if turn_id:
1147
+ try:
1148
+ effective_user_tr = user_tr or ("".join(input_segments))
1149
+ try:
1150
+ setattr(
1151
+ self,
1152
+ "_last_realtime_user_transcript",
1153
+ effective_user_tr,
1154
+ )
1155
+ except Exception:
1156
+ pass
1157
+ # For text-only modality but audio-origin (cumulative segments captured), persist user transcript
1158
+ if effective_user_tr:
1159
+ final_user_tr = effective_user_tr
1160
+ elif (
1161
+ isinstance(query, str)
1162
+ and query
1163
+ and not input_segments
1164
+ and not user_tr
1165
+ ):
1166
+ final_user_tr = query
1167
+ if asst_tr:
1168
+ await self.realtime_update_assistant(
1169
+ user_id, turn_id, asst_tr
1170
+ )
1171
+ except Exception:
1172
+ pass
1173
+ if final_user_tr and not user_persisted:
1174
+ try:
1175
+ await self.realtime_update_user(
1176
+ user_id, turn_id, final_user_tr
1177
+ )
1178
+ user_persisted = True
1179
+ except Exception:
1180
+ pass
1181
+ try:
1182
+ await self.realtime_finalize_turn(user_id, turn_id)
1183
+ except Exception:
1184
+ pass
1185
+ # Input transcript task already awaited above
1186
+ # Clear input buffer for next turn reuse
1187
+ try:
1188
+ await rt.clear_input()
1189
+ except Exception:
1190
+ pass
1191
+ finally:
1192
+ # Always release the session for reuse by other concurrent requests/devices
1193
+ try:
1194
+ lock = getattr(rt, "_in_use_lock", None)
1195
+ if lock and lock.locked():
1196
+ lock.release()
1197
+ except Exception:
1198
+ pass
1199
+ return
1200
+
1201
+ # 1) Acquire user_text (transcribe audio or direct text) for non-realtime path
71
1202
  user_text = ""
72
1203
  if not isinstance(query, str):
73
- async for transcript in self.agent_service.llm_provider.transcribe_audio(query, audio_input_format):
74
- user_text += transcript
1204
+ try:
1205
+ logger.info(
1206
+ f"Received audio input, transcribing format: {audio_input_format}"
1207
+ )
1208
+ async for tpart in self.agent_service.llm_provider.transcribe_audio( # type: ignore[attr-defined]
1209
+ query, audio_input_format
1210
+ ):
1211
+ user_text += tpart
1212
+ except Exception:
1213
+ user_text = ""
75
1214
  else:
76
1215
  user_text = query
77
1216
 
78
- # Handle simple greetings
79
- if user_text.strip().lower() in ["test", "hello", "hi", "hey", "ping"]:
80
- response = "Hello! How can I help you today?"
81
- if output_format == "audio":
82
- async for chunk in self.agent_service.llm_provider.tts(
83
- text=response,
84
- voice=audio_voice,
85
- response_format=audio_output_format
86
- ):
87
- yield chunk
88
- else:
89
- yield response
90
-
91
- # Store simple interaction in memory
92
- if self.memory_provider:
93
- await self._store_conversation(user_id, user_text, response)
94
- return
1217
+ # 2) Input guardrails
1218
+ for guardrail in self.input_guardrails:
1219
+ try:
1220
+ user_text = await guardrail.process(user_text)
1221
+ except Exception as e:
1222
+ logger.debug(f"Guardrail error: {e}")
95
1223
 
96
- # Get memory context if available
1224
+ # 3) Memory context (conversation history)
97
1225
  memory_context = ""
98
1226
  if self.memory_provider:
99
- memory_context = await self.memory_provider.retrieve(user_id)
1227
+ try:
1228
+ memory_context = await self.memory_provider.retrieve(user_id)
1229
+ except Exception:
1230
+ memory_context = ""
1231
+
1232
+ # 4) Knowledge base context
1233
+ kb_context = ""
1234
+ if self.knowledge_base:
1235
+ try:
1236
+ kb_results = await self.knowledge_base.query(
1237
+ query_text=user_text,
1238
+ top_k=self.kb_results_count,
1239
+ include_content=True,
1240
+ include_metadata=False,
1241
+ )
1242
+ if kb_results:
1243
+ kb_lines = [
1244
+ "**KNOWLEDGE BASE (CRITICAL: MAKE THIS INFORMATION THE TOP PRIORITY):**"
1245
+ ]
1246
+ for i, r in enumerate(kb_results, 1):
1247
+ kb_lines.append(f"[{i}] {r.get('content', '').strip()}\n")
1248
+ kb_context = "\n".join(kb_lines)
1249
+ except Exception:
1250
+ kb_context = ""
100
1251
 
101
- # Route query to appropriate agent
102
- if router:
103
- agent_name = await router.route_query(user_text)
1252
+ # 5) Determine agent (sticky session aware; allow explicit switch/new conversation)
1253
+ agent_name = "default"
1254
+ prev_assistant = ""
1255
+ routing_input = user_text
1256
+ if self.memory_provider:
1257
+ try:
1258
+ prev_docs = self.memory_provider.find(
1259
+ collection="conversations",
1260
+ query={"user_id": user_id},
1261
+ sort=[("timestamp", -1)],
1262
+ limit=1,
1263
+ )
1264
+ if prev_docs:
1265
+ prev_user_msg = (prev_docs[0] or {}).get(
1266
+ "user_message", ""
1267
+ ) or ""
1268
+ prev_assistant = (prev_docs[0] or {}).get(
1269
+ "assistant_message", ""
1270
+ ) or ""
1271
+ if prev_user_msg:
1272
+ routing_input = f"previous_user_message: {prev_user_msg}\ncurrent_user_message: {user_text}"
1273
+ except Exception:
1274
+ pass
1275
+
1276
+ # Get available agents first so the LLM can select a valid target
1277
+ agents = self.agent_service.get_all_ai_agents() or {}
1278
+ available_agent_names = list(agents.keys())
1279
+
1280
+ # Fast path: if only one agent, skip all routing logic entirely
1281
+ if len(available_agent_names) == 1:
1282
+ agent_name = available_agent_names[0]
1283
+ self._set_sticky_agent(user_id, agent_name, required_complete=False)
104
1284
  else:
105
- agent_name = await self.routing_service.route_query(user_text)
1285
+ # LLM detects switch intent (only needed with multiple agents)
1286
+ (
1287
+ switch_requested,
1288
+ requested_agent_raw,
1289
+ start_new,
1290
+ ) = await self._detect_switch_intent(user_text, available_agent_names)
1291
+
1292
+ # Normalize requested agent to an exact available key
1293
+ requested_agent = None
1294
+ if requested_agent_raw:
1295
+ raw_lower = requested_agent_raw.lower()
1296
+ for a in available_agent_names:
1297
+ if a.lower() == raw_lower or raw_lower in a.lower():
1298
+ requested_agent = a
1299
+ break
1300
+
1301
+ sticky_agent = self._get_sticky_agent(user_id)
1302
+
1303
+ if sticky_agent and not switch_requested:
1304
+ agent_name = sticky_agent
1305
+ else:
1306
+ try:
1307
+ if start_new:
1308
+ # Start fresh
1309
+ self._clear_sticky_agent(user_id)
1310
+ if requested_agent:
1311
+ agent_name = requested_agent
1312
+ else:
1313
+ # Route if no explicit target
1314
+ if router:
1315
+ agent_name = await router.route_query(routing_input)
1316
+ else:
1317
+ agent_name = await self.routing_service.route_query(
1318
+ routing_input
1319
+ )
1320
+ except Exception:
1321
+ agent_name = next(iter(agents.keys())) if agents else "default"
1322
+ self._set_sticky_agent(user_id, agent_name, required_complete=False)
1323
+
1324
+ # 7) Captured data context + incremental save using previous assistant message
1325
+ capture_context = ""
1326
+ # Two completion flags:
1327
+ required_complete = False
1328
+ form_complete = False # required + optional
1329
+
1330
+ # Helpers
1331
+ def _non_empty(v: Any) -> bool:
1332
+ if v is None:
1333
+ return False
1334
+ if isinstance(v, str):
1335
+ s = v.strip().lower()
1336
+ return s not in {"", "null", "none", "n/a", "na", "undefined", "."}
1337
+ if isinstance(v, (list, dict, tuple, set)):
1338
+ return len(v) > 0
1339
+ return True
1340
+
1341
+ def _parse_numbers_list(s: str) -> List[str]:
1342
+ nums = re.findall(r"\b(\d+)\b", s)
1343
+ seen, out = set(), []
1344
+ for n in nums:
1345
+ if n not in seen:
1346
+ seen.add(n)
1347
+ out.append(n)
1348
+ return out
1349
+
1350
+ def _extract_numbered_options(text: str) -> Dict[str, str]:
1351
+ """Parse previous assistant message for lines like:
1352
+ '1) Foo', '1. Foo', '- 1) Foo', '* 1. Foo' -> {'1': 'Foo'}"""
1353
+ options: Dict[str, str] = {}
1354
+ if not text:
1355
+ return options
1356
+ for raw in text.splitlines():
1357
+ line = raw.strip()
1358
+ if not line:
1359
+ continue
1360
+ m = re.match(r"^(?:[-*]\s*)?(\d+)[\.)]?\s+(.*)$", line)
1361
+ if m:
1362
+ idx, label = m.group(1), m.group(2).strip().rstrip()
1363
+ if len(label) >= 1:
1364
+ options[idx] = label
1365
+ return options
1366
+
1367
+ # LLM-backed field detection (gpt-4.1-mini) with graceful fallbacks
1368
+ class _FieldDetect(BaseModel):
1369
+ field: Optional[str] = None
1370
+
1371
+ async def _detect_field_from_prev_question(
1372
+ prev_text: str, schema: Optional[Dict[str, Any]]
1373
+ ) -> Optional[str]:
1374
+ if not prev_text or not isinstance(schema, dict):
1375
+ return None
1376
+ props = list((schema.get("properties") or {}).keys())
1377
+ if not props:
1378
+ return None
1379
+
1380
+ question = prev_text.strip()
1381
+ instruction = (
1382
+ "You are a strict classifier. Given the assistant's last question and a list of "
1383
+ "permitted schema field keys, choose exactly one key that the question is asking the user to answer. "
1384
+ "If none apply, return null."
1385
+ )
1386
+ user_prompt = (
1387
+ f"Schema field keys (choose exactly one of these): {props}\n"
1388
+ f"Assistant question:\n{question}\n\n"
1389
+ 'Return strictly JSON like: {"field": "<one_of_the_keys_or_null>"}'
1390
+ )
1391
+
1392
+ # Try llm_provider.parse_structured_output with mini
1393
+ try:
1394
+ if hasattr(
1395
+ self.agent_service.llm_provider, "parse_structured_output"
1396
+ ):
1397
+ try:
1398
+ result = await self.agent_service.llm_provider.parse_structured_output(
1399
+ prompt=user_prompt,
1400
+ system_prompt=instruction,
1401
+ model_class=_FieldDetect,
1402
+ model="gpt-4.1-mini",
1403
+ )
1404
+ except TypeError:
1405
+ # Provider may not accept 'model' kwarg
1406
+ result = await self.agent_service.llm_provider.parse_structured_output(
1407
+ prompt=user_prompt,
1408
+ system_prompt=instruction,
1409
+ model_class=_FieldDetect,
1410
+ )
1411
+ sel = None
1412
+ try:
1413
+ sel = getattr(result, "field", None)
1414
+ except Exception:
1415
+ sel = None
1416
+ if sel is None:
1417
+ try:
1418
+ d = result.model_dump()
1419
+ sel = d.get("field")
1420
+ except Exception:
1421
+ sel = None
1422
+ if sel in props:
1423
+ return sel
1424
+ except Exception as e:
1425
+ logger.debug(
1426
+ f"LLM parse_structured_output field detection failed: {e}"
1427
+ )
1428
+
1429
+ # Fallback: use generate_response with output_model=_FieldDetect
1430
+ try:
1431
+ async for r in self.agent_service.generate_response(
1432
+ agent_name=agent_name,
1433
+ user_id=user_id,
1434
+ query=user_text,
1435
+ images=images,
1436
+ memory_context="",
1437
+ output_format="text",
1438
+ prompt=f"{instruction}\n\n{user_prompt}",
1439
+ output_model=_FieldDetect,
1440
+ ):
1441
+ fd = r
1442
+ sel = None
1443
+ try:
1444
+ sel = fd.field # type: ignore[attr-defined]
1445
+ except Exception:
1446
+ try:
1447
+ d = fd.model_dump()
1448
+ sel = d.get("field")
1449
+ except Exception:
1450
+ sel = None
1451
+ if sel in props:
1452
+ return sel
1453
+ break
1454
+ except Exception as e:
1455
+ logger.debug(f"LLM generate_response field detection failed: {e}")
1456
+
1457
+ # Final heuristic fallback (keeps system working if LLM unavailable)
1458
+ t = question.lower()
1459
+ for key in props:
1460
+ if key in t:
1461
+ return key
1462
+ return None
1463
+
1464
+ # Resolve active capture from args or agent config
1465
+ active_capture_name = capture_name
1466
+ active_capture_schema = capture_schema
1467
+ if not active_capture_name or not active_capture_schema:
1468
+ try:
1469
+ cap_cfg = self.agent_service.get_agent_capture(agent_name)
1470
+ if cap_cfg:
1471
+ active_capture_name = active_capture_name or cap_cfg.get("name")
1472
+ active_capture_schema = active_capture_schema or cap_cfg.get(
1473
+ "schema"
1474
+ )
1475
+ except Exception:
1476
+ pass
1477
+
1478
+ latest_by_name: Dict[str, Dict[str, Any]] = {}
1479
+ if self.memory_provider:
1480
+ try:
1481
+ docs = self.memory_provider.find(
1482
+ collection="captures",
1483
+ query={"user_id": user_id},
1484
+ sort=[("timestamp", -1)],
1485
+ limit=100,
1486
+ )
1487
+ for d in docs or []:
1488
+ name = (d or {}).get("capture_name")
1489
+ if not name or name in latest_by_name:
1490
+ continue
1491
+ latest_by_name[name] = {
1492
+ "data": (d or {}).get("data", {}) or {},
1493
+ "mode": (d or {}).get("mode", "once"),
1494
+ "agent": (d or {}).get("agent_name"),
1495
+ }
1496
+ except Exception:
1497
+ pass
1498
+
+            # Incremental save: use prev_assistant's numbered list to map numeric reply -> labels
+            incremental: Dict[str, Any] = {}
+            try:
+                if (
+                    self.memory_provider
+                    and active_capture_name
+                    and isinstance(active_capture_schema, dict)
+                ):
+                    props = (active_capture_schema or {}).get("properties", {})
+                    required_fields = list(
+                        (active_capture_schema or {}).get("required", []) or []
+                    )
+                    all_fields = list(props.keys())
+                    optional_fields = [
+                        f for f in all_fields if f not in set(required_fields)
+                    ]
+
+                    active_data_existing = (
+                        latest_by_name.get(active_capture_name, {}) or {}
+                    ).get("data", {}) or {}
+
+                    def _missing(fields: List[str]) -> List[str]:
+                        return [
+                            f
+                            for f in fields
+                            if not _non_empty(active_data_existing.get(f))
+                        ]
+
+                    missing_required = _missing(required_fields)
+                    missing_optional = _missing(optional_fields)
+
+                    target_field: Optional[str] = await _detect_field_from_prev_question(
+                        prev_assistant, active_capture_schema
+                    )
+                    if not target_field:
+                        # If exactly one required field is missing, target it; otherwise, if no
+                        # required fields are missing and exactly one optional field is, target that one.
+                        if len(missing_required) == 1:
+                            target_field = missing_required[0]
+                        elif len(missing_required) == 0 and len(missing_optional) == 1:
+                            target_field = missing_optional[0]
+
+                    if target_field and target_field in props:
+                        f_schema = props.get(target_field, {}) or {}
+                        f_type = f_schema.get("type")
+                        number_to_label = _extract_numbered_options(prev_assistant)
+
+                        if number_to_label:
+                            nums = _parse_numbers_list(user_text)
+                            labels = [
+                                number_to_label[n] for n in nums if n in number_to_label
+                            ]
+                            if labels:
+                                if f_type == "array":
+                                    incremental[target_field] = labels
+                                else:
+                                    incremental[target_field] = labels[0]
+
+                        if target_field not in incremental:
+                            if f_type == "number":
+                                m = re.search(r"\b([0-9]+(?:\.[0-9]+)?)\b", user_text)
+                                if m:
+                                    try:
+                                        incremental[target_field] = float(m.group(1))
+                                    except Exception:
+                                        pass
+                            elif f_type == "array":
+                                parts = [
+                                    p.strip()
+                                    for p in re.split(r"[,\n;]+", user_text)
+                                    if p.strip()
+                                ]
+                                if parts:
+                                    incremental[target_field] = parts
+                            else:
+                                if user_text.strip():
+                                    incremental[target_field] = user_text.strip()
+
+                    if incremental:
+                        cleaned = {
+                            k: v for k, v in incremental.items() if _non_empty(v)
+                        }
+                        if cleaned:
+                            try:
+                                await self.memory_provider.save_capture(
+                                    user_id=user_id,
+                                    capture_name=active_capture_name,
+                                    agent_name=agent_name,
+                                    data=cleaned,
+                                    schema=active_capture_schema,
+                                )
+                            except Exception as se:
+                                logger.error(f"Error saving incremental capture: {se}")
+
+            except Exception as e:
+                logger.debug(f"Incremental extraction skipped: {e}")
+
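`_extract_numbered_options` and `_parse_numbers_list` are helpers defined earlier in this module; the sketch below uses simplified stand-in regexes (assumptions, not the shipped implementations) to show the mapping this block performs when the assistant's previous turn offered a numbered list:

    import re

    prev_assistant = "Which colors?\n1) Red\n2) Blue\n3) Green"
    user_text = "1 and 3"
    # number -> label, recovered from the assistant's numbered list
    number_to_label = {
        int(n): label.strip()
        for n, label in re.findall(r"(\d+)\)\s*(.+)", prev_assistant)
    }
    nums = [int(n) for n in re.findall(r"\d+", user_text)]
    labels = [number_to_label[n] for n in nums if n in number_to_label]
    assert labels == ["Red", "Green"]  # stored as a list for an "array" field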
+            # Build capture context, merging in incremental immediately (avoid read lag)
+            def _get_active_data(name: Optional[str]) -> Dict[str, Any]:
+                if not name:
+                    return {}
+                base = (latest_by_name.get(name, {}) or {}).get("data", {}) or {}
+                if incremental:
+                    base = {**base, **incremental}
+                return base
+
+            lines: List[str] = []
+            if active_capture_name and isinstance(active_capture_schema, dict):
+                props = (active_capture_schema or {}).get("properties", {})
+                required_fields = list(
+                    (active_capture_schema or {}).get("required", []) or []
+                )
+                all_fields = list(props.keys())
+                optional_fields = [
+                    f for f in all_fields if f not in set(required_fields)
+                ]
+
+                active_data = _get_active_data(active_capture_name)
+
+                def _missing_from(data: Dict[str, Any], fields: List[str]) -> List[str]:
+                    return [f for f in fields if not _non_empty(data.get(f))]
+
+                missing_required = _missing_from(active_data, required_fields)
+                missing_optional = _missing_from(active_data, optional_fields)
+
+                required_complete = (
+                    len(missing_required) == 0 and len(required_fields) > 0
+                )
+                form_complete = required_complete and len(missing_optional) == 0
+
+                lines.append(
+                    "CAPTURED FORM STATE (Authoritative; do not re-ask filled values):"
+                )
+                lines.append(f"- form_name: {active_capture_name}")
+
+                if active_data:
+                    pairs = [
+                        f"{k}: {v}" for k, v in active_data.items() if _non_empty(v)
+                    ]
+                    lines.append(
+                        f"- filled_fields: {', '.join(pairs) if pairs else '(none)'}"
+                    )
+                else:
+                    lines.append("- filled_fields: (none)")
+
+                lines.append(
+                    f"- missing_required_fields: {', '.join(missing_required) if missing_required else '(none)'}"
+                )
+                lines.append(
+                    f"- missing_optional_fields: {', '.join(missing_optional) if missing_optional else '(none)'}"
+                )
+                lines.append("")
 
-            print(f"Routed to agent: {agent_name}")
+            if latest_by_name:
+                lines.append("OTHER CAPTURED USER DATA (for reference):")
+                for cname, info in latest_by_name.items():
+                    if cname == active_capture_name:
+                        continue
+                    data = info.get("data", {}) or {}
+                    if data:
+                        pairs = [f"{k}: {v}" for k, v in data.items() if _non_empty(v)]
+                        lines.append(
+                            f"- {cname}: {', '.join(pairs) if pairs else '(none)'}"
+                        )
+                    else:
+                        lines.append(f"- {cname}: (none)")
+
+            if lines:
+                capture_context = "\n".join(lines) + "\n\n"
+                # Update sticky session completion flag
+                try:
+                    self._update_sticky_required_complete(user_id, required_complete)
+                except Exception:
+                    pass
+
+            # Merge contexts + flow rules
+            combined_context = ""
+            if capture_context:
+                combined_context += capture_context
+            if memory_context:
+                combined_context += f"CONVERSATION HISTORY (Use for continuity; not authoritative for facts):\n{memory_context}\n\n"
+            if kb_context:
+                combined_context += kb_context + "\n"
+            if combined_context:
+                combined_context += (
+                    "PRIORITIZATION GUIDE:\n"
+                    "- Prefer Captured User Data for user-specific fields.\n"
+                    "- Prefer KB/tools for facts.\n"
+                    "- History is for tone and continuity.\n\n"
+                    "FORM FLOW RULES:\n"
+                    "- Ask exactly one field per turn.\n"
+                    "- If any required fields are missing, ask the next missing required field.\n"
+                    "- If all required fields are filled but optional fields are missing, ask the next missing optional field.\n"
+                    "- Do NOT re-ask or verify values present in Captured User Data (auto-saved, authoritative).\n"
+                    "- Do NOT provide summaries until no required or optional fields are missing.\n\n"
+                )
 
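With a filled name/email and a missing phone, the combined context assembled above would read roughly like this (illustrative values):

    CAPTURED FORM STATE (Authoritative; do not re-ask filled values):
    - form_name: lead
    - filled_fields: name: Ada, email: ada@example.com
    - missing_required_fields: phone
    - missing_optional_fields: (none)

    CONVERSATION HISTORY (Use for continuity; not authoritative for facts):
    ...

    PRIORITIZATION GUIDE:
    - Prefer Captured User Data for user-specific fields.
    ...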
-            # Generate response
+            # 8) Generate response
             if output_format == "audio":
                 async for audio_chunk in self.agent_service.generate_response(
                     agent_name=agent_name,
                     user_id=user_id,
-                    query=query,
-                    memory_context=memory_context,
+                    query=user_text,
+                    images=images,
+                    memory_context=combined_context,
                     output_format="audio",
                     audio_voice=audio_voice,
-                    audio_input_format=audio_input_format,
                     audio_output_format=audio_output_format,
-                    audio_instructions=audio_instructions,
                     prompt=prompt,
                 ):
                     yield audio_chunk
-
                 if self.memory_provider:
                     await self._store_conversation(
                         user_id=user_id,
@@ -130,78 +1716,130 @@ class QueryService(QueryServiceInterface):
                     )
             else:
                 full_text_response = ""
+                capture_data: Optional[BaseModel] = None
+
+                # Resolve agent capture if not provided
+                if not capture_schema or not capture_name:
+                    try:
+                        cap = self.agent_service.get_agent_capture(agent_name)
+                        if cap:
+                            capture_name = cap.get("name")
+                            capture_schema = cap.get("schema")
+                    except Exception:
+                        pass
+
+                # Only run final structured output when no required or optional fields are missing
+                if capture_schema and capture_name and form_complete:
+                    try:
+                        DynamicModel = self._build_model_from_json_schema(
+                            capture_name, capture_schema
+                        )
+                        async for result in self.agent_service.generate_response(
+                            agent_name=agent_name,
+                            user_id=user_id,
+                            query=user_text,
+                            images=images,
+                            memory_context=combined_context,
+                            output_format="text",
+                            prompt=(
+                                (
+                                    prompt
+                                    + "\n\nUsing the captured user data above, return only the JSON for the requested schema. Do not invent values."
+                                )
+                                if prompt
+                                else "Using the captured user data above, return only the JSON for the requested schema. Do not invent values."
+                            ),
+                            output_model=DynamicModel,
+                        ):
+                            capture_data = result  # type: ignore
+                            break
+                    except Exception as e:
+                        logger.error(f"Error during capture structured output: {e}")
+
                 async for chunk in self.agent_service.generate_response(
                     agent_name=agent_name,
                     user_id=user_id,
                     query=user_text,
-                    memory_context=memory_context,
+                    images=images,
+                    memory_context=combined_context,
                     output_format="text",
                     prompt=prompt,
+                    output_model=output_model,
                 ):
                     yield chunk
-                    full_text_response += chunk
+                    if output_model is None:
+                        full_text_response += chunk
 
                 if self.memory_provider and full_text_response:
                     await self._store_conversation(
                         user_id=user_id,
                         user_message=user_text,
-                        assistant_message=full_text_response
+                        assistant_message=full_text_response,
                     )
 
+                # Save final capture data if the model returned it
+                if (
+                    self.memory_provider
+                    and capture_schema
+                    and capture_name
+                    and capture_data is not None
+                ):
+                    try:
+                        data_dict = (
+                            capture_data.model_dump()
+                            if hasattr(capture_data, "model_dump")
+                            else capture_data.dict()
+                        )
+                        await self.memory_provider.save_capture(
+                            user_id=user_id,
+                            capture_name=capture_name,
+                            agent_name=agent_name,
+                            data=data_dict,
+                            schema=capture_schema,
+                        )
+                    except Exception as e:
+                        logger.error(f"Error saving capture: {e}")
+
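The `model_dump()`/`dict()` branch keeps the save path working whether the dynamically built model is Pydantic v2 or v1. A self-contained illustration of that compatibility shim (hypothetical model):

    from pydantic import BaseModel

    class Lead(BaseModel):
        name: str
        email: str

    obj = Lead(name="Ada", email="ada@example.com")
    # Pydantic v2 exposes model_dump(); v1 models only have dict()
    data_dict = obj.model_dump() if hasattr(obj, "model_dump") else obj.dict()
    assert data_dict == {"name": "Ada", "email": "ada@example.com"}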
             except Exception as e:
-                error_msg = f"I apologize for the technical difficulty. {str(e)}"
+                import traceback
+
+                error_msg = (
+                    "I apologize for the technical difficulty. Please try again later."
+                )
+                logger.error(f"Error in query processing: {e}\n{traceback.format_exc()}")
+
                 if output_format == "audio":
-                    async for chunk in self.agent_service.llm_provider.tts(
-                        text=error_msg,
-                        voice=audio_voice,
-                        response_format=audio_output_format
-                    ):
-                        yield chunk
+                    try:
+                        async for chunk in self.agent_service.llm_provider.tts(
+                            text=error_msg,
+                            voice=audio_voice,
+                            response_format=audio_output_format,
+                        ):
+                            yield chunk
+                    except Exception as tts_e:
+                        logger.error(f"Error during TTS for error message: {tts_e}")
+                        yield error_msg + f" (TTS Error: {tts_e})"
                 else:
                     yield error_msg
 
-                print(f"Error in query processing: {str(e)}")
-                import traceback
-                print(traceback.format_exc())
-
     async def delete_user_history(self, user_id: str) -> None:
-        """Delete all conversation history for a user.
-
-        Args:
-            user_id: User ID
-        """
+        """Delete all conversation history for a user."""
         if self.memory_provider:
             try:
                 await self.memory_provider.delete(user_id)
             except Exception as e:
-                print(f"Error deleting user history: {str(e)}")
+                logger.error(f"Error deleting user history for {user_id}: {e}")
+        else:
+            logger.debug("No memory provider; skip delete_user_history")
 
     async def get_user_history(
         self,
         user_id: str,
         page_num: int = 1,
         page_size: int = 20,
-        sort_order: str = "desc"  # "asc" for oldest-first, "desc" for newest-first
+        sort_order: str = "desc",
     ) -> Dict[str, Any]:
-        """Get paginated message history for a user.
-
-        Args:
-            user_id: User ID
-            page_num: Page number (starting from 1)
-            page_size: Number of messages per page
-            sort_order: Sort order ("asc" or "desc")
-
-        Returns:
-            Dictionary with paginated results and metadata:
-            {
-                "data": List of conversation entries,
-                "total": Total number of entries,
-                "page": Current page number,
-                "page_size": Number of items per page,
-                "total_pages": Total number of pages,
-                "error": Error message if any
-            }
-        """
+        """Get paginated message history for a user."""
         if not self.memory_provider:
             return {
                 "data": [],
@@ -209,85 +1847,152 @@ class QueryService(QueryServiceInterface):
                 "page": page_num,
                 "page_size": page_size,
                 "total_pages": 0,
-                "error": "Memory provider not available"
+                "error": "Memory provider not available",
             }
-
         try:
-            # Calculate skip and limit for pagination
             skip = (page_num - 1) * page_size
-
-            # Get total count of documents
             total = self.memory_provider.count_documents(
-                collection="conversations",
-                query={"user_id": user_id}
+                collection="conversations", query={"user_id": user_id}
             )
+            total_pages = (total + page_size - 1) // page_size if total > 0 else 0
 
-            # Calculate total pages
-            total_pages = (total + page_size - 1) // page_size
-
-            # Get paginated results
             conversations = self.memory_provider.find(
                 collection="conversations",
                 query={"user_id": user_id},
                 sort=[("timestamp", 1 if sort_order == "asc" else -1)],
                 skip=skip,
-                limit=page_size
+                limit=page_size,
            )
 
-            # Format the results
-            formatted_conversations = []
+            formatted: List[Dict[str, Any]] = []
             for conv in conversations:
-                # Convert datetime to Unix timestamp (seconds since epoch)
-                timestamp = int(conv.get("timestamp").timestamp()
-                                ) if conv.get("timestamp") else None
-
-                formatted_conversations.append({
-                    "id": str(conv.get("_id")),
-                    "user_message": conv.get("user_message"),
-                    "assistant_message": conv.get("assistant_message"),
-                    "timestamp": timestamp,
-                })
+                ts = conv.get("timestamp")
+                ts_epoch = int(ts.timestamp()) if ts else None
+                formatted.append(
+                    {
+                        "id": str(conv.get("_id")),
+                        "user_message": conv.get("user_message"),
+                        "assistant_message": conv.get("assistant_message"),
+                        "timestamp": ts_epoch,
+                    }
+                )
 
             return {
-                "data": formatted_conversations,
+                "data": formatted,
                 "total": total,
                 "page": page_num,
                 "page_size": page_size,
                 "total_pages": total_pages,
-                "error": None
+                "error": None,
             }
-
         except Exception as e:
-            print(f"Error retrieving user history: {str(e)}")
             import traceback
-            print(traceback.format_exc())
+
+            logger.error(
+                f"Error retrieving user history for {user_id}: {e}\n{traceback.format_exc()}"
+            )
             return {
                 "data": [],
                 "total": 0,
                 "page": page_num,
                 "page_size": page_size,
                 "total_pages": 0,
-                "error": f"Error retrieving history: {str(e)}"
+                "error": f"Error retrieving history: {str(e)}",
             }
 
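The ceil division keeps a partial last page in the count, e.g.:

    total, page_size = 45, 20
    total_pages = (total + page_size - 1) // page_size if total > 0 else 0
    assert total_pages == 3  # pages of 20, 20, and 5 messages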
     async def _store_conversation(
         self, user_id: str, user_message: str, assistant_message: str
     ) -> None:
-        """Store conversation history in memory provider.
+        """Store conversation history in memory provider."""
+        if not self.memory_provider:
+            return
+        try:
+            await self.memory_provider.store(
+                user_id,
+                [
+                    {"role": "user", "content": user_message},
+                    {"role": "assistant", "content": assistant_message},
+                ],
+            )
+        except Exception as e:
+            logger.error(f"Store conversation error for {user_id}: {e}")
 
-        Args:
-            user_id: User ID
-            user_message: User message
-            assistant_message: Assistant message
-        """
-        if self.memory_provider:
-            try:
-                await self.memory_provider.store(
-                    user_id,
-                    [
-                        {"role": "user", "content": user_message},
-                        {"role": "assistant", "content": assistant_message},
-                    ],
-                )
-            except Exception as e:
-                print(f"Error storing conversation: {e}")
+    # --- Realtime persistence helpers (used by client/server using realtime service) ---
+    async def realtime_begin_turn(
+        self, user_id: str
+    ) -> Optional[str]:  # pragma: no cover
+        if not self.memory_provider:
+            return None
+        if not hasattr(self.memory_provider, "begin_stream_turn"):
+            return None
+        return await self.memory_provider.begin_stream_turn(user_id)  # type: ignore[attr-defined]
+
+    async def realtime_update_user(
+        self, user_id: str, turn_id: str, delta: str
+    ) -> None:  # pragma: no cover
+        if not self.memory_provider:
+            return
+        if not hasattr(self.memory_provider, "update_stream_user"):
+            return
+        await self.memory_provider.update_stream_user(user_id, turn_id, delta)  # type: ignore[attr-defined]
+
+    async def realtime_update_assistant(
+        self, user_id: str, turn_id: str, delta: str
+    ) -> None:  # pragma: no cover
+        if not self.memory_provider:
+            return
+        if not hasattr(self.memory_provider, "update_stream_assistant"):
+            return
+        await self.memory_provider.update_stream_assistant(user_id, turn_id, delta)  # type: ignore[attr-defined]
+
+    async def realtime_finalize_turn(
+        self, user_id: str, turn_id: str
+    ) -> None:  # pragma: no cover
+        if not self.memory_provider:
+            return
+        if not hasattr(self.memory_provider, "finalize_stream_turn"):
+            return
+        await self.memory_provider.finalize_stream_turn(user_id, turn_id)  # type: ignore[attr-defined]
+
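Each realtime helper probes the memory provider with `hasattr` before delegating, so providers without streaming-turn support degrade to no-ops. A sketch of the intended call sequence during one spoken turn (`query_service` is a hypothetical instance of this class):

    turn_id = await query_service.realtime_begin_turn("user-1")
    if turn_id:  # None when the provider lacks streaming support
        await query_service.realtime_update_user("user-1", turn_id, "What's my ")
        await query_service.realtime_update_user("user-1", turn_id, "balance?")
        await query_service.realtime_update_assistant("user-1", turn_id, "Your balance is...")
        await query_service.realtime_finalize_turn("user-1", turn_id)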
+    def _build_model_from_json_schema(
+        self, name: str, schema: Dict[str, Any]
+    ) -> Type[BaseModel]:
+        """Create a Pydantic model dynamically from a JSON Schema subset."""
+        from pydantic import create_model
+
+        def py_type(js: Dict[str, Any]):
+            # Map a JSON Schema "type" (possibly a list including "null") to a Python type
+            t = js.get("type")
+            if isinstance(t, list):
+                non_null = [x for x in t if x != "null"]
+                if not non_null:
+                    return Optional[Any]
+                base = py_type({"type": non_null[0]})
+                return Optional[base]
+            if t == "string":
+                return str
+            if t == "integer":
+                return int
+            if t == "number":
+                return float
+            if t == "boolean":
+                return bool
+            if t == "array":
+                items = js.get("items", {"type": "string"})
+                return List[py_type(items)]
+            if t == "object":
+                return Dict[str, Any]
+            return Any
+
+        properties: Dict[str, Any] = schema.get("properties", {})
+        required = set(schema.get("required", []))
+        fields: Dict[str, Any] = {}
+        for field_name, field_schema in properties.items():
+            typ = py_type(field_schema)
+            default = field_schema.get("default")
+            if field_name in required and default is None:
+                fields[field_name] = (typ, ...)
+            else:
+                fields[field_name] = (typ, default)
+
+        Model = create_model(name, **fields)  # type: ignore
+        return Model
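For a typical capture schema, the builder yields a model whose required fields use `...` (Ellipsis) and whose optional fields default to the schema default or `None` (illustrative schema; `service` stands in for a QueryService instance):

    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "tags": {"type": "array", "items": {"type": "string"}},
            "budget": {"type": ["number", "null"]},
        },
        "required": ["name"],
    }
    Lead = service._build_model_from_json_schema("Lead", schema)
    lead = Lead(name="Ada", tags=["vip"])  # budget is Optional[float], defaults to None
    print(lead.model_dump())  # {'name': 'Ada', 'tags': ['vip'], 'budget': None}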