solana-agent 20.1.2__py3-none-any.whl → 31.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- solana_agent/__init__.py +10 -5
- solana_agent/adapters/ffmpeg_transcoder.py +375 -0
- solana_agent/adapters/mongodb_adapter.py +15 -2
- solana_agent/adapters/openai_adapter.py +679 -0
- solana_agent/adapters/openai_realtime_ws.py +1813 -0
- solana_agent/adapters/pinecone_adapter.py +543 -0
- solana_agent/cli.py +128 -0
- solana_agent/client/solana_agent.py +180 -20
- solana_agent/domains/agent.py +13 -13
- solana_agent/domains/routing.py +18 -8
- solana_agent/factories/agent_factory.py +239 -38
- solana_agent/guardrails/pii.py +107 -0
- solana_agent/interfaces/client/client.py +95 -12
- solana_agent/interfaces/guardrails/guardrails.py +26 -0
- solana_agent/interfaces/plugins/plugins.py +2 -1
- solana_agent/interfaces/providers/__init__.py +0 -0
- solana_agent/interfaces/providers/audio.py +40 -0
- solana_agent/interfaces/providers/data_storage.py +9 -2
- solana_agent/interfaces/providers/llm.py +86 -9
- solana_agent/interfaces/providers/memory.py +13 -1
- solana_agent/interfaces/providers/realtime.py +212 -0
- solana_agent/interfaces/providers/vector_storage.py +53 -0
- solana_agent/interfaces/services/agent.py +27 -12
- solana_agent/interfaces/services/knowledge_base.py +59 -0
- solana_agent/interfaces/services/query.py +41 -8
- solana_agent/interfaces/services/routing.py +0 -1
- solana_agent/plugins/manager.py +37 -16
- solana_agent/plugins/registry.py +34 -19
- solana_agent/plugins/tools/__init__.py +0 -5
- solana_agent/plugins/tools/auto_tool.py +1 -0
- solana_agent/repositories/memory.py +332 -111
- solana_agent/services/__init__.py +1 -1
- solana_agent/services/agent.py +390 -241
- solana_agent/services/knowledge_base.py +768 -0
- solana_agent/services/query.py +1858 -153
- solana_agent/services/realtime.py +626 -0
- solana_agent/services/routing.py +104 -51
- solana_agent-31.4.0.dist-info/METADATA +1070 -0
- solana_agent-31.4.0.dist-info/RECORD +49 -0
- {solana_agent-20.1.2.dist-info → solana_agent-31.4.0.dist-info}/WHEEL +1 -1
- solana_agent-31.4.0.dist-info/entry_points.txt +3 -0
- solana_agent/adapters/llm_adapter.py +0 -160
- solana_agent-20.1.2.dist-info/METADATA +0 -464
- solana_agent-20.1.2.dist-info/RECORD +0 -35
- {solana_agent-20.1.2.dist-info → solana_agent-31.4.0.dist-info/licenses}/LICENSE +0 -0
solana_agent/services/query.py
CHANGED
|
@@ -5,13 +5,44 @@ This service orchestrates the processing of user queries, coordinating
|
|
|
5
5
|
other services to provide comprehensive responses while maintaining
|
|
6
6
|
clean separation of concerns.
|
|
7
7
|
"""
|
|
8
|
-
from typing import Any, AsyncGenerator, Dict, Literal, Optional, Union
|
|
9
8
|
|
|
9
|
+
import logging
|
|
10
|
+
import asyncio
|
|
11
|
+
import re
|
|
12
|
+
import time
|
|
13
|
+
from typing import (
|
|
14
|
+
Any,
|
|
15
|
+
AsyncGenerator,
|
|
16
|
+
Dict,
|
|
17
|
+
List,
|
|
18
|
+
Literal,
|
|
19
|
+
Optional,
|
|
20
|
+
Type,
|
|
21
|
+
Union,
|
|
22
|
+
Tuple,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from pydantic import BaseModel
|
|
26
|
+
|
|
27
|
+
# Interface imports
|
|
10
28
|
from solana_agent.interfaces.services.query import QueryService as QueryServiceInterface
|
|
11
|
-
from solana_agent.interfaces.services.routing import
|
|
29
|
+
from solana_agent.interfaces.services.routing import (
|
|
30
|
+
RoutingService as RoutingServiceInterface,
|
|
31
|
+
)
|
|
32
|
+
from solana_agent.interfaces.providers.memory import (
|
|
33
|
+
MemoryProvider as MemoryProviderInterface,
|
|
34
|
+
)
|
|
35
|
+
from solana_agent.interfaces.services.knowledge_base import (
|
|
36
|
+
KnowledgeBaseService as KnowledgeBaseInterface,
|
|
37
|
+
)
|
|
38
|
+
from solana_agent.interfaces.guardrails.guardrails import InputGuardrail
|
|
39
|
+
|
|
40
|
+
from solana_agent.interfaces.providers.realtime import RealtimeSessionOptions
|
|
41
|
+
|
|
12
42
|
from solana_agent.services.agent import AgentService
|
|
13
43
|
from solana_agent.services.routing import RoutingService
|
|
14
|
-
|
|
44
|
+
|
|
45
|
+
logger = logging.getLogger(__name__)
|
|
15
46
|
|
|
16
47
|
|
|
17
48
|
class QueryService(QueryServiceInterface):
|
|
@@ -21,107 +52,1662 @@ class QueryService(QueryServiceInterface):
|
|
|
21
52
|
self,
|
|
22
53
|
agent_service: AgentService,
|
|
23
54
|
routing_service: RoutingService,
|
|
24
|
-
memory_provider: Optional[
|
|
55
|
+
memory_provider: Optional[MemoryProviderInterface] = None,
|
|
56
|
+
knowledge_base: Optional[KnowledgeBaseInterface] = None,
|
|
57
|
+
kb_results_count: int = 3,
|
|
58
|
+
input_guardrails: List[InputGuardrail] = None,
|
|
25
59
|
):
|
|
26
|
-
"""Initialize the query service.
|
|
27
|
-
|
|
28
|
-
Args:
|
|
29
|
-
agent_service: Service for AI agent management
|
|
30
|
-
routing_service: Service for routing queries to appropriate agents
|
|
31
|
-
memory_provider: Optional provider for memory storage and retrieval
|
|
32
|
-
"""
|
|
60
|
+
"""Initialize the query service."""
|
|
33
61
|
self.agent_service = agent_service
|
|
34
62
|
self.routing_service = routing_service
|
|
35
63
|
self.memory_provider = memory_provider
|
|
64
|
+
self.knowledge_base = knowledge_base
|
|
65
|
+
self.kb_results_count = kb_results_count
|
|
66
|
+
self.input_guardrails = input_guardrails or []
|
|
67
|
+
# Per-user sticky sessions (in-memory)
|
|
68
|
+
# { user_id: { 'agent': str, 'started_at': float, 'last_updated': float, 'required_complete': bool } }
|
|
69
|
+
self._sticky_sessions: Dict[str, Dict[str, Any]] = {}
|
|
70
|
+
# Optional realtime service attached by factory (populated in factory)
|
|
71
|
+
self.realtime = None # type: ignore[attr-defined]
|
|
72
|
+
# Persistent realtime WS pool per user for reuse across turns/devices
|
|
73
|
+
# { user_id: [RealtimeService, ...] }
|
|
74
|
+
self._rt_services: Dict[str, List[Any]] = {}
|
|
75
|
+
# Global lock for creating/finding per-user sessions
|
|
76
|
+
self._rt_lock = asyncio.Lock()
|
|
77
|
+
|
|
78
|
+
async def _try_acquire_lock(self, lock: asyncio.Lock) -> bool:
|
|
79
|
+
try:
|
|
80
|
+
await asyncio.wait_for(lock.acquire(), timeout=0)
|
|
81
|
+
return True
|
|
82
|
+
except asyncio.TimeoutError:
|
|
83
|
+
return False
|
|
84
|
+
except Exception:
|
|
85
|
+
return False
|
|
86
|
+
|
|
87
|
+
async def _alloc_realtime_session(
|
|
88
|
+
self,
|
|
89
|
+
user_id: str,
|
|
90
|
+
*,
|
|
91
|
+
api_key: str,
|
|
92
|
+
rt_voice: str,
|
|
93
|
+
final_instructions: str,
|
|
94
|
+
initial_tools: Optional[List[Dict[str, Any]]],
|
|
95
|
+
encode_in: bool,
|
|
96
|
+
encode_out: bool,
|
|
97
|
+
audio_input_format: str,
|
|
98
|
+
audio_output_format: str,
|
|
99
|
+
rt_output_modalities: Optional[List[Literal["audio", "text"]]] = None,
|
|
100
|
+
) -> Any:
|
|
101
|
+
"""Get a free (or new) realtime session for this user. Marks it busy via an internal lock.
|
|
102
|
+
|
|
103
|
+
Returns the RealtimeService with an acquired _in_use_lock that MUST be released by caller.
|
|
104
|
+
"""
|
|
105
|
+
from solana_agent.interfaces.providers.realtime import (
|
|
106
|
+
RealtimeSessionOptions,
|
|
107
|
+
)
|
|
108
|
+
from solana_agent.adapters.openai_realtime_ws import (
|
|
109
|
+
OpenAIRealtimeWebSocketSession,
|
|
110
|
+
)
|
|
111
|
+
from solana_agent.adapters.ffmpeg_transcoder import FFmpegTranscoder
|
|
112
|
+
|
|
113
|
+
def _mime_from(fmt: str) -> str:
|
|
114
|
+
f = (fmt or "").lower()
|
|
115
|
+
return {
|
|
116
|
+
"aac": "audio/aac",
|
|
117
|
+
"mp3": "audio/mpeg",
|
|
118
|
+
"mp4": "audio/mp4",
|
|
119
|
+
"m4a": "audio/mp4",
|
|
120
|
+
"mpeg": "audio/mpeg",
|
|
121
|
+
"mpga": "audio/mpeg",
|
|
122
|
+
"wav": "audio/wav",
|
|
123
|
+
"flac": "audio/flac",
|
|
124
|
+
"opus": "audio/opus",
|
|
125
|
+
"ogg": "audio/ogg",
|
|
126
|
+
"webm": "audio/webm",
|
|
127
|
+
"pcm": "audio/pcm",
|
|
128
|
+
}.get(f, "audio/pcm")
|
|
129
|
+
|
|
130
|
+
async with self._rt_lock:
|
|
131
|
+
pool = self._rt_services.get(user_id) or []
|
|
132
|
+
# Try to reuse an idle session strictly owned by this user
|
|
133
|
+
for rt in pool:
|
|
134
|
+
# Extra safety: never reuse a session from another user
|
|
135
|
+
owner = getattr(rt, "_owner_user_id", None)
|
|
136
|
+
if owner is not None and owner != user_id:
|
|
137
|
+
continue
|
|
138
|
+
lock = getattr(rt, "_in_use_lock", None)
|
|
139
|
+
if lock is None:
|
|
140
|
+
lock = asyncio.Lock()
|
|
141
|
+
setattr(rt, "_in_use_lock", lock)
|
|
142
|
+
if not lock.locked():
|
|
143
|
+
if await self._try_acquire_lock(lock):
|
|
144
|
+
return rt
|
|
145
|
+
# None free: create a new session
|
|
146
|
+
opts = RealtimeSessionOptions(
|
|
147
|
+
model="gpt-realtime",
|
|
148
|
+
voice=rt_voice,
|
|
149
|
+
vad_enabled=False,
|
|
150
|
+
input_rate_hz=24000,
|
|
151
|
+
output_rate_hz=24000,
|
|
152
|
+
input_mime="audio/pcm",
|
|
153
|
+
output_mime="audio/pcm",
|
|
154
|
+
output_modalities=rt_output_modalities,
|
|
155
|
+
tools=initial_tools or None,
|
|
156
|
+
tool_choice="auto",
|
|
157
|
+
)
|
|
158
|
+
try:
|
|
159
|
+
opts.instructions = final_instructions
|
|
160
|
+
opts.voice = rt_voice
|
|
161
|
+
except Exception:
|
|
162
|
+
pass
|
|
163
|
+
conv_session = OpenAIRealtimeWebSocketSession(api_key=api_key, options=opts)
|
|
164
|
+
transcoder = FFmpegTranscoder() if (encode_in or encode_out) else None
|
|
165
|
+
from solana_agent.services.realtime import RealtimeService
|
|
166
|
+
|
|
167
|
+
rt = RealtimeService(
|
|
168
|
+
session=conv_session,
|
|
169
|
+
options=opts,
|
|
170
|
+
transcoder=transcoder,
|
|
171
|
+
accept_compressed_input=encode_in,
|
|
172
|
+
client_input_mime=_mime_from(audio_input_format),
|
|
173
|
+
encode_output=encode_out,
|
|
174
|
+
client_output_mime=_mime_from(audio_output_format),
|
|
175
|
+
)
|
|
176
|
+
# Tag ownership to prevent any cross-user reuse
|
|
177
|
+
setattr(rt, "_owner_user_id", user_id)
|
|
178
|
+
setattr(rt, "_in_use_lock", asyncio.Lock())
|
|
179
|
+
# Mark busy
|
|
180
|
+
await getattr(rt, "_in_use_lock").acquire()
|
|
181
|
+
pool.append(rt)
|
|
182
|
+
self._rt_services[user_id] = pool
|
|
183
|
+
return rt
|
|
184
|
+
|
|
185
|
+
def _get_sticky_agent(self, user_id: str) -> Optional[str]:
|
|
186
|
+
sess = self._sticky_sessions.get(user_id)
|
|
187
|
+
return sess.get("agent") if isinstance(sess, dict) else None
|
|
188
|
+
|
|
189
|
+
def _set_sticky_agent(
|
|
190
|
+
self, user_id: str, agent_name: str, required_complete: bool = False
|
|
191
|
+
) -> None:
|
|
192
|
+
self._sticky_sessions[user_id] = {
|
|
193
|
+
"agent": agent_name,
|
|
194
|
+
"started_at": time.time(),
|
|
195
|
+
"last_updated": time.time(),
|
|
196
|
+
"required_complete": required_complete,
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
def _update_sticky_required_complete(
|
|
200
|
+
self, user_id: str, required_complete: bool
|
|
201
|
+
) -> None:
|
|
202
|
+
if user_id in self._sticky_sessions:
|
|
203
|
+
self._sticky_sessions[user_id]["required_complete"] = required_complete
|
|
204
|
+
self._sticky_sessions[user_id]["last_updated"] = time.time()
|
|
205
|
+
|
|
206
|
+
def _clear_sticky_agent(self, user_id: str) -> None:
|
|
207
|
+
if user_id in self._sticky_sessions:
|
|
208
|
+
try:
|
|
209
|
+
del self._sticky_sessions[user_id]
|
|
210
|
+
except Exception:
|
|
211
|
+
pass
|
|
212
|
+
|
|
213
|
+
async def _build_combined_context(
|
|
214
|
+
self,
|
|
215
|
+
user_id: str,
|
|
216
|
+
user_text: str,
|
|
217
|
+
agent_name: str,
|
|
218
|
+
capture_name: Optional[str] = None,
|
|
219
|
+
capture_schema: Optional[Dict[str, Any]] = None,
|
|
220
|
+
prev_assistant: str = "",
|
|
221
|
+
) -> Tuple[str, bool]:
|
|
222
|
+
"""Build combined context string and return required_complete flag."""
|
|
223
|
+
# Memory context
|
|
224
|
+
memory_context = ""
|
|
225
|
+
if self.memory_provider:
|
|
226
|
+
try:
|
|
227
|
+
memory_context = await self.memory_provider.retrieve(user_id)
|
|
228
|
+
except Exception:
|
|
229
|
+
memory_context = ""
|
|
230
|
+
|
|
231
|
+
# KB context
|
|
232
|
+
kb_context = ""
|
|
233
|
+
if self.knowledge_base:
|
|
234
|
+
try:
|
|
235
|
+
kb_results = await self.knowledge_base.query(
|
|
236
|
+
query_text=user_text,
|
|
237
|
+
top_k=self.kb_results_count,
|
|
238
|
+
include_content=True,
|
|
239
|
+
include_metadata=False,
|
|
240
|
+
)
|
|
241
|
+
if kb_results:
|
|
242
|
+
kb_lines = [
|
|
243
|
+
"**KNOWLEDGE BASE (CRITICAL: MAKE THIS INFORMATION THE TOP PRIORITY):**"
|
|
244
|
+
]
|
|
245
|
+
for i, r in enumerate(kb_results, 1):
|
|
246
|
+
kb_lines.append(f"[{i}] {r.get('content', '').strip()}\n")
|
|
247
|
+
kb_context = "\n".join(kb_lines)
|
|
248
|
+
except Exception:
|
|
249
|
+
kb_context = ""
|
|
250
|
+
|
|
251
|
+
# Capture context
|
|
252
|
+
capture_context = ""
|
|
253
|
+
required_complete = False
|
|
254
|
+
|
|
255
|
+
active_capture_name = capture_name
|
|
256
|
+
active_capture_schema = capture_schema
|
|
257
|
+
if not active_capture_name or not active_capture_schema:
|
|
258
|
+
try:
|
|
259
|
+
cap_cfg = self.agent_service.get_agent_capture(agent_name)
|
|
260
|
+
if cap_cfg:
|
|
261
|
+
active_capture_name = active_capture_name or cap_cfg.get("name")
|
|
262
|
+
active_capture_schema = active_capture_schema or cap_cfg.get(
|
|
263
|
+
"schema"
|
|
264
|
+
)
|
|
265
|
+
except Exception:
|
|
266
|
+
pass
|
|
267
|
+
|
|
268
|
+
if active_capture_name and isinstance(active_capture_schema, dict):
|
|
269
|
+
latest_by_name: Dict[str, Dict[str, Any]] = {}
|
|
270
|
+
if self.memory_provider:
|
|
271
|
+
try:
|
|
272
|
+
docs = self.memory_provider.find(
|
|
273
|
+
collection="captures",
|
|
274
|
+
query={"user_id": user_id},
|
|
275
|
+
sort=[("timestamp", -1)],
|
|
276
|
+
limit=100,
|
|
277
|
+
)
|
|
278
|
+
for d in docs or []:
|
|
279
|
+
name = (d or {}).get("capture_name")
|
|
280
|
+
if not name or name in latest_by_name:
|
|
281
|
+
continue
|
|
282
|
+
latest_by_name[name] = {
|
|
283
|
+
"data": (d or {}).get("data", {}) or {},
|
|
284
|
+
"mode": (d or {}).get("mode", "once"),
|
|
285
|
+
"agent": (d or {}).get("agent_name"),
|
|
286
|
+
}
|
|
287
|
+
except Exception:
|
|
288
|
+
pass
|
|
289
|
+
|
|
290
|
+
active_data = (latest_by_name.get(active_capture_name, {}) or {}).get(
|
|
291
|
+
"data", {}
|
|
292
|
+
) or {}
|
|
293
|
+
|
|
294
|
+
def _non_empty(v: Any) -> bool:
|
|
295
|
+
if v is None:
|
|
296
|
+
return False
|
|
297
|
+
if isinstance(v, str):
|
|
298
|
+
s = v.strip().lower()
|
|
299
|
+
return s not in {"", "null", "none", "n/a", "na", "undefined", "."}
|
|
300
|
+
if isinstance(v, (list, dict, tuple, set)):
|
|
301
|
+
return len(v) > 0
|
|
302
|
+
return True
|
|
303
|
+
|
|
304
|
+
props = (active_capture_schema or {}).get("properties", {})
|
|
305
|
+
required_fields = list(
|
|
306
|
+
(active_capture_schema or {}).get("required", []) or []
|
|
307
|
+
)
|
|
308
|
+
all_fields = list(props.keys())
|
|
309
|
+
optional_fields = [f for f in all_fields if f not in set(required_fields)]
|
|
310
|
+
|
|
311
|
+
def _missing_from(data: Dict[str, Any], fields: List[str]) -> List[str]:
|
|
312
|
+
return [f for f in fields if not _non_empty(data.get(f))]
|
|
313
|
+
|
|
314
|
+
missing_required = _missing_from(active_data, required_fields)
|
|
315
|
+
missing_optional = _missing_from(active_data, optional_fields)
|
|
316
|
+
|
|
317
|
+
required_complete = len(missing_required) == 0 and len(required_fields) > 0
|
|
318
|
+
|
|
319
|
+
lines: List[str] = []
|
|
320
|
+
lines.append(
|
|
321
|
+
"CAPTURED FORM STATE (Authoritative; do not re-ask filled values):"
|
|
322
|
+
)
|
|
323
|
+
lines.append(f"- form_name: {active_capture_name}")
|
|
324
|
+
|
|
325
|
+
if active_data:
|
|
326
|
+
pairs = [f"{k}: {v}" for k, v in active_data.items() if _non_empty(v)]
|
|
327
|
+
lines.append(
|
|
328
|
+
f"- filled_fields: {', '.join(pairs) if pairs else '(none)'}"
|
|
329
|
+
)
|
|
330
|
+
else:
|
|
331
|
+
lines.append("- filled_fields: (none)")
|
|
332
|
+
|
|
333
|
+
lines.append(
|
|
334
|
+
f"- missing_required_fields: {', '.join(missing_required) if missing_required else '(none)'}"
|
|
335
|
+
)
|
|
336
|
+
lines.append(
|
|
337
|
+
f"- missing_optional_fields: {', '.join(missing_optional) if missing_optional else '(none)'}"
|
|
338
|
+
)
|
|
339
|
+
lines.append("")
|
|
340
|
+
|
|
341
|
+
if latest_by_name:
|
|
342
|
+
lines.append("OTHER CAPTURED USER DATA (for reference):")
|
|
343
|
+
for cname, info in latest_by_name.items():
|
|
344
|
+
if cname == active_capture_name:
|
|
345
|
+
continue
|
|
346
|
+
data = info.get("data", {}) or {}
|
|
347
|
+
if data:
|
|
348
|
+
pairs = [f"{k}: {v}" for k, v in data.items() if _non_empty(v)]
|
|
349
|
+
lines.append(
|
|
350
|
+
f"- {cname}: {', '.join(pairs) if pairs else '(none)'}"
|
|
351
|
+
)
|
|
352
|
+
else:
|
|
353
|
+
lines.append(f"- {cname}: (none)")
|
|
354
|
+
|
|
355
|
+
if lines:
|
|
356
|
+
capture_context = "\n".join(lines) + "\n\n"
|
|
357
|
+
|
|
358
|
+
# Merge contexts
|
|
359
|
+
combined_context = ""
|
|
360
|
+
if capture_context:
|
|
361
|
+
combined_context += capture_context
|
|
362
|
+
if memory_context:
|
|
363
|
+
combined_context += f"CONVERSATION HISTORY (Use for continuity; not authoritative for facts):\n{memory_context}\n\n"
|
|
364
|
+
if kb_context:
|
|
365
|
+
combined_context += kb_context + "\n"
|
|
366
|
+
|
|
367
|
+
guide = (
|
|
368
|
+
"PRIORITIZATION GUIDE:\n"
|
|
369
|
+
"- Prefer Captured User Data for user-specific fields.\n"
|
|
370
|
+
"- Prefer KB/tools for facts.\n"
|
|
371
|
+
"- History is for tone and continuity.\n\n"
|
|
372
|
+
"FORM FLOW RULES:\n"
|
|
373
|
+
"- Ask exactly one field per turn.\n"
|
|
374
|
+
"- If any required fields are missing, ask the next missing required field.\n"
|
|
375
|
+
"- If all required fields are filled but optional fields are missing, ask the next missing optional field.\n"
|
|
376
|
+
"- Do NOT re-ask or verify values present in Captured User Data (auto-saved, authoritative).\n"
|
|
377
|
+
"- Do NOT provide summaries until no required or optional fields are missing.\n\n"
|
|
378
|
+
)
|
|
379
|
+
|
|
380
|
+
if combined_context:
|
|
381
|
+
combined_context += guide
|
|
382
|
+
else:
|
|
383
|
+
# Diagnostics for why the context is empty
|
|
384
|
+
try:
|
|
385
|
+
logger.debug(
|
|
386
|
+
"_build_combined_context: empty sources — memory_provider=%s, knowledge_base=%s, active_capture=%s",
|
|
387
|
+
bool(self.memory_provider),
|
|
388
|
+
bool(self.knowledge_base),
|
|
389
|
+
bool(
|
|
390
|
+
active_capture_name and isinstance(active_capture_schema, dict)
|
|
391
|
+
),
|
|
392
|
+
)
|
|
393
|
+
except Exception:
|
|
394
|
+
pass
|
|
395
|
+
# Provide minimal guide so realtime instructions are not blank
|
|
396
|
+
combined_context = guide
|
|
397
|
+
|
|
398
|
+
return combined_context, required_complete
|
|
399
|
+
|
|
400
|
+
# LLM-backed switch intent detection (gpt-4.1-mini)
|
|
401
|
+
class _SwitchIntentModel(BaseModel):
|
|
402
|
+
switch: bool = False
|
|
403
|
+
target_agent: Optional[str] = None
|
|
404
|
+
start_new: bool = False
|
|
405
|
+
|
|
406
|
+
async def _detect_switch_intent(
|
|
407
|
+
self, text: str, available_agents: List[str]
|
|
408
|
+
) -> Tuple[bool, Optional[str], bool]:
|
|
409
|
+
"""Detect if the user is asking to switch agents or start a new conversation.
|
|
410
|
+
|
|
411
|
+
Returns: (switch_requested, target_agent_name_or_none, start_new_conversation)
|
|
412
|
+
Implemented as an LLM call to gpt-4.1-mini with structured output.
|
|
413
|
+
"""
|
|
414
|
+
if not text:
|
|
415
|
+
return (False, None, False)
|
|
416
|
+
|
|
417
|
+
# Instruction and user prompt for the classifier
|
|
418
|
+
instruction = (
|
|
419
|
+
"You are a strict intent classifier for agent routing. "
|
|
420
|
+
"Decide if the user's message requests switching to another agent or starting a new conversation. "
|
|
421
|
+
"Only return JSON with keys: switch (bool), target_agent (string|null), start_new (bool). "
|
|
422
|
+
"If a target agent is mentioned, it MUST be one of the provided agent names (case-insensitive). "
|
|
423
|
+
"If none clearly applies, set switch=false and start_new=false and target_agent=null."
|
|
424
|
+
)
|
|
425
|
+
user_prompt = (
|
|
426
|
+
f"Available agents (choose only from these if a target is specified): {available_agents}\n\n"
|
|
427
|
+
f"User message:\n{text}\n\n"
|
|
428
|
+
'Return JSON only, like: {"switch": true|false, "target_agent": "<one_of_available_or_null>", "start_new": true|false}'
|
|
429
|
+
)
|
|
430
|
+
|
|
431
|
+
# Primary: use llm_provider.parse_structured_output
|
|
432
|
+
try:
|
|
433
|
+
if hasattr(self.agent_service.llm_provider, "parse_structured_output"):
|
|
434
|
+
try:
|
|
435
|
+
result = (
|
|
436
|
+
await self.agent_service.llm_provider.parse_structured_output(
|
|
437
|
+
prompt=user_prompt,
|
|
438
|
+
system_prompt=instruction,
|
|
439
|
+
model_class=QueryService._SwitchIntentModel,
|
|
440
|
+
model="gpt-4.1-mini",
|
|
441
|
+
)
|
|
442
|
+
)
|
|
443
|
+
except TypeError:
|
|
444
|
+
# Provider may not accept 'model' kwarg
|
|
445
|
+
result = (
|
|
446
|
+
await self.agent_service.llm_provider.parse_structured_output(
|
|
447
|
+
prompt=user_prompt,
|
|
448
|
+
system_prompt=instruction,
|
|
449
|
+
model_class=QueryService._SwitchIntentModel,
|
|
450
|
+
)
|
|
451
|
+
)
|
|
452
|
+
switch = bool(getattr(result, "switch", False))
|
|
453
|
+
target = getattr(result, "target_agent", None)
|
|
454
|
+
start_new = bool(getattr(result, "start_new", False))
|
|
455
|
+
# Normalize target to available agent name
|
|
456
|
+
if target:
|
|
457
|
+
target_lower = target.lower()
|
|
458
|
+
norm = None
|
|
459
|
+
for a in available_agents:
|
|
460
|
+
if a.lower() == target_lower or target_lower in a.lower():
|
|
461
|
+
norm = a
|
|
462
|
+
break
|
|
463
|
+
target = norm
|
|
464
|
+
if not switch:
|
|
465
|
+
target = None
|
|
466
|
+
return (switch, target, start_new)
|
|
467
|
+
except Exception as e:
|
|
468
|
+
logger.debug(f"LLM switch intent parse_structured_output failed: {e}")
|
|
469
|
+
|
|
470
|
+
# Fallback: generate_response with output_model
|
|
471
|
+
try:
|
|
472
|
+
async for r in self.agent_service.generate_response(
|
|
473
|
+
agent_name="default",
|
|
474
|
+
user_id="router",
|
|
475
|
+
query="",
|
|
476
|
+
images=None,
|
|
477
|
+
memory_context="",
|
|
478
|
+
output_format="text",
|
|
479
|
+
prompt=f"{instruction}\n\n{user_prompt}",
|
|
480
|
+
output_model=QueryService._SwitchIntentModel,
|
|
481
|
+
):
|
|
482
|
+
result = r
|
|
483
|
+
switch = False
|
|
484
|
+
target = None
|
|
485
|
+
start_new = False
|
|
486
|
+
try:
|
|
487
|
+
switch = bool(result.switch) # type: ignore[attr-defined]
|
|
488
|
+
target = result.target_agent # type: ignore[attr-defined]
|
|
489
|
+
start_new = bool(result.start_new) # type: ignore[attr-defined]
|
|
490
|
+
except Exception:
|
|
491
|
+
try:
|
|
492
|
+
d = result.model_dump()
|
|
493
|
+
switch = bool(d.get("switch", False))
|
|
494
|
+
target = d.get("target_agent")
|
|
495
|
+
start_new = bool(d.get("start_new", False))
|
|
496
|
+
except Exception:
|
|
497
|
+
pass
|
|
498
|
+
if target:
|
|
499
|
+
target_lower = str(target).lower()
|
|
500
|
+
norm = None
|
|
501
|
+
for a in available_agents:
|
|
502
|
+
if a.lower() == target_lower or target_lower in a.lower():
|
|
503
|
+
norm = a
|
|
504
|
+
break
|
|
505
|
+
target = norm
|
|
506
|
+
if not switch:
|
|
507
|
+
target = None
|
|
508
|
+
return (switch, target, start_new)
|
|
509
|
+
except Exception as e:
|
|
510
|
+
logger.debug(f"LLM switch intent generate_response failed: {e}")
|
|
511
|
+
|
|
512
|
+
# Last resort: no switch
|
|
513
|
+
return (False, None, False)
|
|
36
514
|
|
|
37
515
|
async def process(
|
|
38
516
|
self,
|
|
39
517
|
user_id: str,
|
|
40
518
|
query: Union[str, bytes],
|
|
519
|
+
images: Optional[List[Union[str, bytes]]] = None,
|
|
41
520
|
output_format: Literal["text", "audio"] = "text",
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
521
|
+
realtime: bool = False,
|
|
522
|
+
# Realtime minimal controls (voice/format come from audio_* args)
|
|
523
|
+
vad: Optional[bool] = None,
|
|
524
|
+
rt_encode_input: bool = False,
|
|
525
|
+
rt_encode_output: bool = False,
|
|
526
|
+
rt_output_modalities: Optional[List[Literal["audio", "text"]]] = None,
|
|
527
|
+
rt_voice: Literal[
|
|
528
|
+
"alloy",
|
|
529
|
+
"ash",
|
|
530
|
+
"ballad",
|
|
531
|
+
"cedar",
|
|
532
|
+
"coral",
|
|
533
|
+
"echo",
|
|
534
|
+
"marin",
|
|
535
|
+
"sage",
|
|
536
|
+
"shimmer",
|
|
537
|
+
"verse",
|
|
538
|
+
] = "marin",
|
|
539
|
+
# Realtime transcription configuration (new)
|
|
540
|
+
rt_transcription_model: Optional[str] = None,
|
|
541
|
+
rt_transcription_language: Optional[str] = None,
|
|
542
|
+
rt_transcription_prompt: Optional[str] = None,
|
|
543
|
+
rt_transcription_noise_reduction: Optional[bool] = None,
|
|
544
|
+
rt_transcription_include_logprobs: bool = False,
|
|
545
|
+
audio_voice: Literal[
|
|
546
|
+
"alloy",
|
|
547
|
+
"ash",
|
|
548
|
+
"ballad",
|
|
549
|
+
"coral",
|
|
550
|
+
"echo",
|
|
551
|
+
"fable",
|
|
552
|
+
"onyx",
|
|
553
|
+
"nova",
|
|
554
|
+
"sage",
|
|
555
|
+
"shimmer",
|
|
556
|
+
] = "nova",
|
|
557
|
+
audio_output_format: Literal[
|
|
558
|
+
"mp3", "opus", "aac", "flac", "wav", "pcm"
|
|
559
|
+
] = "aac",
|
|
47
560
|
audio_input_format: Literal[
|
|
48
561
|
"flac", "mp3", "mp4", "mpeg", "mpga", "m4a", "ogg", "wav", "webm"
|
|
49
562
|
] = "mp4",
|
|
50
563
|
prompt: Optional[str] = None,
|
|
51
564
|
router: Optional[RoutingServiceInterface] = None,
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
query: Text query or audio bytes
|
|
58
|
-
output_format: Response format ("text" or "audio")
|
|
59
|
-
audio_voice: Voice for TTS (text-to-speech)
|
|
60
|
-
audio_instructions: Optional instructions for TTS
|
|
61
|
-
audio_output_format: Audio output format
|
|
62
|
-
audio_input_format: Audio input format
|
|
63
|
-
prompt: Optional prompt for the agent
|
|
64
|
-
router: Optional routing service for processing
|
|
65
|
-
|
|
66
|
-
Yields:
|
|
67
|
-
Response chunks (text strings or audio bytes)
|
|
68
|
-
"""
|
|
565
|
+
output_model: Optional[Type[BaseModel]] = None,
|
|
566
|
+
capture_schema: Optional[Dict[str, Any]] = None,
|
|
567
|
+
capture_name: Optional[str] = None,
|
|
568
|
+
) -> AsyncGenerator[Union[str, bytes, BaseModel], None]: # pragma: no cover
|
|
569
|
+
"""Process the user request and generate a response."""
|
|
69
570
|
try:
|
|
70
|
-
#
|
|
571
|
+
# Realtime request: HTTP STT for user + single WS for assistant audio
|
|
572
|
+
if realtime:
|
|
573
|
+
# 1) Determine if input is audio bytes. We now ALWAYS skip HTTP STT in realtime mode.
|
|
574
|
+
# The realtime websocket session (optionally with built-in transcription) is authoritative.
|
|
575
|
+
is_audio_bytes = isinstance(query, (bytes, bytearray))
|
|
576
|
+
user_text = "" if is_audio_bytes else str(query)
|
|
577
|
+
# Provide a sensible default realtime transcription model when audio supplied
|
|
578
|
+
if is_audio_bytes and not rt_transcription_model:
|
|
579
|
+
rt_transcription_model = "gpt-4o-mini-transcribe"
|
|
580
|
+
|
|
581
|
+
# 2) Single agent selection (no multi-agent routing in realtime path)
|
|
582
|
+
agent_name = self._get_sticky_agent(user_id)
|
|
583
|
+
if not agent_name:
|
|
584
|
+
try:
|
|
585
|
+
agents = self.agent_service.get_all_ai_agents() or {}
|
|
586
|
+
agent_name = next(iter(agents.keys())) if agents else "default"
|
|
587
|
+
except Exception:
|
|
588
|
+
agent_name = "default"
|
|
589
|
+
prev_assistant = ""
|
|
590
|
+
if self.memory_provider:
|
|
591
|
+
try:
|
|
592
|
+
prev_docs = self.memory_provider.find(
|
|
593
|
+
collection="conversations",
|
|
594
|
+
query={"user_id": user_id},
|
|
595
|
+
sort=[("timestamp", -1)],
|
|
596
|
+
limit=1,
|
|
597
|
+
)
|
|
598
|
+
if prev_docs:
|
|
599
|
+
prev_assistant = (prev_docs[0] or {}).get(
|
|
600
|
+
"assistant_message", ""
|
|
601
|
+
) or ""
|
|
602
|
+
except Exception:
|
|
603
|
+
pass
|
|
604
|
+
|
|
605
|
+
# 3) Build context + tools
|
|
606
|
+
combined_ctx = ""
|
|
607
|
+
required_complete = False
|
|
608
|
+
try:
|
|
609
|
+
(
|
|
610
|
+
combined_ctx,
|
|
611
|
+
required_complete,
|
|
612
|
+
) = await self._build_combined_context(
|
|
613
|
+
user_id=user_id,
|
|
614
|
+
user_text=(user_text if not is_audio_bytes else ""),
|
|
615
|
+
agent_name=agent_name,
|
|
616
|
+
capture_name=capture_name,
|
|
617
|
+
capture_schema=capture_schema,
|
|
618
|
+
prev_assistant=prev_assistant,
|
|
619
|
+
)
|
|
620
|
+
try:
|
|
621
|
+
self._update_sticky_required_complete(
|
|
622
|
+
user_id, required_complete
|
|
623
|
+
)
|
|
624
|
+
except Exception:
|
|
625
|
+
pass
|
|
626
|
+
except Exception:
|
|
627
|
+
combined_ctx = ""
|
|
628
|
+
try:
|
|
629
|
+
# GA Realtime expects flattened tool definitions (no nested "function" object)
|
|
630
|
+
initial_tools = [
|
|
631
|
+
{
|
|
632
|
+
"type": "function",
|
|
633
|
+
"name": t["name"],
|
|
634
|
+
"description": t.get("description", ""),
|
|
635
|
+
"parameters": t.get("parameters", {}),
|
|
636
|
+
"strict": True,
|
|
637
|
+
}
|
|
638
|
+
for t in self.agent_service.get_agent_tools(agent_name)
|
|
639
|
+
]
|
|
640
|
+
except Exception:
|
|
641
|
+
initial_tools = []
|
|
642
|
+
|
|
643
|
+
# Build realtime instructions: include full agent system prompt, context, and optional prompt (no user_text)
|
|
644
|
+
system_prompt = ""
|
|
645
|
+
try:
|
|
646
|
+
system_prompt = self.agent_service.get_agent_system_prompt(
|
|
647
|
+
agent_name
|
|
648
|
+
)
|
|
649
|
+
except Exception:
|
|
650
|
+
system_prompt = ""
|
|
651
|
+
|
|
652
|
+
parts: List[str] = []
|
|
653
|
+
if system_prompt:
|
|
654
|
+
parts.append(system_prompt)
|
|
655
|
+
if combined_ctx:
|
|
656
|
+
parts.append(combined_ctx)
|
|
657
|
+
if prompt:
|
|
658
|
+
parts.append(str(prompt))
|
|
659
|
+
final_instructions = "\n\n".join([p for p in parts if p])
|
|
660
|
+
|
|
661
|
+
# 4) Open a single WS session for assistant audio
|
|
662
|
+
# Realtime imports handled inside allocator helper
|
|
663
|
+
|
|
664
|
+
api_key = None
|
|
665
|
+
try:
|
|
666
|
+
api_key = self.agent_service.llm_provider.get_api_key() # type: ignore[attr-defined]
|
|
667
|
+
except Exception:
|
|
668
|
+
pass
|
|
669
|
+
if not api_key:
|
|
670
|
+
raise ValueError("OpenAI API key is required for realtime")
|
|
671
|
+
|
|
672
|
+
# Per-user persistent WS (single session)
|
|
673
|
+
def _mime_from(fmt: str) -> str:
|
|
674
|
+
f = (fmt or "").lower()
|
|
675
|
+
return {
|
|
676
|
+
"aac": "audio/aac",
|
|
677
|
+
"mp3": "audio/mpeg",
|
|
678
|
+
"mp4": "audio/mp4",
|
|
679
|
+
"m4a": "audio/mp4",
|
|
680
|
+
"mpeg": "audio/mpeg",
|
|
681
|
+
"mpga": "audio/mpeg",
|
|
682
|
+
"wav": "audio/wav",
|
|
683
|
+
"flac": "audio/flac",
|
|
684
|
+
"opus": "audio/opus",
|
|
685
|
+
"ogg": "audio/ogg",
|
|
686
|
+
"webm": "audio/webm",
|
|
687
|
+
"pcm": "audio/pcm",
|
|
688
|
+
}.get(f, "audio/pcm")
|
|
689
|
+
|
|
690
|
+
# Choose output encoding automatically when non-PCM output is requested
|
|
691
|
+
encode_out = bool(
|
|
692
|
+
rt_encode_output or (audio_output_format.lower() != "pcm")
|
|
693
|
+
)
|
|
694
|
+
# If caller explicitly requests text-only realtime, disable output encoding entirely
|
|
695
|
+
if (
|
|
696
|
+
rt_output_modalities is not None
|
|
697
|
+
and "audio" not in rt_output_modalities
|
|
698
|
+
):
|
|
699
|
+
if encode_out:
|
|
700
|
+
logger.debug(
|
|
701
|
+
"Realtime(QueryService): forcing encode_out False for text-only modalities=%s",
|
|
702
|
+
rt_output_modalities,
|
|
703
|
+
)
|
|
704
|
+
encode_out = False
|
|
705
|
+
# Choose input transcoding when compressed input is provided (or explicitly requested)
|
|
706
|
+
is_audio_bytes = isinstance(query, (bytes, bytearray))
|
|
707
|
+
encode_in = bool(
|
|
708
|
+
rt_encode_input
|
|
709
|
+
or (is_audio_bytes and audio_input_format.lower() != "pcm")
|
|
710
|
+
)
|
|
711
|
+
|
|
712
|
+
# Allocate or reuse a realtime session for this specific request/user.
|
|
713
|
+
# (Transcription options may be applied below; if they change after allocate we will reconfigure.)
|
|
714
|
+
rt = await self._alloc_realtime_session(
|
|
715
|
+
user_id,
|
|
716
|
+
api_key=api_key,
|
|
717
|
+
rt_voice=rt_voice,
|
|
718
|
+
final_instructions=final_instructions,
|
|
719
|
+
initial_tools=initial_tools,
|
|
720
|
+
encode_in=encode_in,
|
|
721
|
+
encode_out=encode_out,
|
|
722
|
+
audio_input_format=audio_input_format,
|
|
723
|
+
audio_output_format=audio_output_format,
|
|
724
|
+
rt_output_modalities=rt_output_modalities,
|
|
725
|
+
)
|
|
726
|
+
# Ensure lock is released no matter what
|
|
727
|
+
try:
|
|
728
|
+
# --- Apply realtime transcription config BEFORE connecting (new) ---
|
|
729
|
+
if rt_transcription_model and hasattr(rt, "_options"):
|
|
730
|
+
try:
|
|
731
|
+
setattr(
|
|
732
|
+
rt._options,
|
|
733
|
+
"transcription_model",
|
|
734
|
+
rt_transcription_model,
|
|
735
|
+
)
|
|
736
|
+
if rt_transcription_language is not None:
|
|
737
|
+
setattr(
|
|
738
|
+
rt._options,
|
|
739
|
+
"transcription_language",
|
|
740
|
+
rt_transcription_language,
|
|
741
|
+
)
|
|
742
|
+
if rt_transcription_prompt is not None:
|
|
743
|
+
setattr(
|
|
744
|
+
rt._options,
|
|
745
|
+
"transcription_prompt",
|
|
746
|
+
rt_transcription_prompt,
|
|
747
|
+
)
|
|
748
|
+
if rt_transcription_noise_reduction is not None:
|
|
749
|
+
setattr(
|
|
750
|
+
rt._options,
|
|
751
|
+
"transcription_noise_reduction",
|
|
752
|
+
rt_transcription_noise_reduction,
|
|
753
|
+
)
|
|
754
|
+
if rt_transcription_include_logprobs:
|
|
755
|
+
setattr(
|
|
756
|
+
rt._options, "transcription_include_logprobs", True
|
|
757
|
+
)
|
|
758
|
+
except Exception:
|
|
759
|
+
logger.debug(
|
|
760
|
+
"Failed pre-connect transcription option assignment",
|
|
761
|
+
exc_info=True,
|
|
762
|
+
)
|
|
763
|
+
|
|
764
|
+
# Tool executor
|
|
765
|
+
async def _exec(
|
|
766
|
+
tool_name: str, args: Dict[str, Any]
|
|
767
|
+
) -> Dict[str, Any]:
|
|
768
|
+
try:
|
|
769
|
+
return await self.agent_service.execute_tool(
|
|
770
|
+
agent_name, tool_name, args or {}
|
|
771
|
+
)
|
|
772
|
+
except Exception as e:
|
|
773
|
+
return {"status": "error", "message": str(e)}
|
|
774
|
+
|
|
775
|
+
# If possible, set on underlying session
|
|
776
|
+
try:
|
|
777
|
+
if hasattr(rt, "_session"):
|
|
778
|
+
getattr(rt, "_session").set_tool_executor(_exec) # type: ignore[attr-defined]
|
|
779
|
+
except Exception:
|
|
780
|
+
pass
|
|
781
|
+
|
|
782
|
+
# Connect/configure
|
|
783
|
+
if not getattr(rt, "_connected", False):
|
|
784
|
+
await rt.start()
|
|
785
|
+
await rt.configure(
|
|
786
|
+
voice=rt_voice,
|
|
787
|
+
vad_enabled=bool(vad) if vad is not None else False,
|
|
788
|
+
instructions=final_instructions,
|
|
789
|
+
tools=initial_tools or None,
|
|
790
|
+
tool_choice="auto",
|
|
791
|
+
)
|
|
792
|
+
|
|
793
|
+
# Ensure clean input buffers for this turn
|
|
794
|
+
try:
|
|
795
|
+
await rt.clear_input()
|
|
796
|
+
except Exception:
|
|
797
|
+
pass
|
|
798
|
+
# Also reset any leftover output audio so new turn doesn't replay old chunks
|
|
799
|
+
try:
|
|
800
|
+
if hasattr(rt, "reset_output_stream"):
|
|
801
|
+
rt.reset_output_stream()
|
|
802
|
+
except Exception:
|
|
803
|
+
pass
|
|
804
|
+
|
|
805
|
+
# Begin streaming turn (defer user transcript persistence until final to avoid duplicates)
|
|
806
|
+
turn_id = await self.realtime_begin_turn(user_id)
|
|
807
|
+
# We'll buffer the full user transcript (text input or realtime audio transcription) and persist exactly once.
|
|
808
|
+
# Initialize empty; we'll build it strictly from realtime transcript segments to avoid
|
|
809
|
+
# accidental duplication with pre-supplied user_text or prior buffers.
|
|
810
|
+
final_user_tr: str = ""
|
|
811
|
+
user_persisted = False
|
|
812
|
+
|
|
813
|
+
# Feed audio into WS if audio bytes provided and audio modality requested; else treat as text
|
|
814
|
+
wants_audio = (
|
|
815
|
+
(
|
|
816
|
+
getattr(rt, "_options", None)
|
|
817
|
+
and getattr(rt, "_options").output_modalities
|
|
818
|
+
)
|
|
819
|
+
and "audio" in getattr(rt, "_options").output_modalities # type: ignore[attr-defined]
|
|
820
|
+
) or (
|
|
821
|
+
rt_output_modalities is None
|
|
822
|
+
or (rt_output_modalities and "audio" in rt_output_modalities)
|
|
823
|
+
)
|
|
824
|
+
# Determine if realtime transcription should be enabled (always skip HTTP STT regardless)
|
|
825
|
+
# realtime_transcription_enabled now implicit (options set before connect)
|
|
826
|
+
|
|
827
|
+
if is_audio_bytes and not wants_audio:
|
|
828
|
+
# Feed audio solely for transcription (no audio output requested)
|
|
829
|
+
bq = bytes(query)
|
|
830
|
+
logger.info(
|
|
831
|
+
"Realtime: appending input audio for transcription only, len=%d, fmt=%s",
|
|
832
|
+
len(bq),
|
|
833
|
+
audio_input_format,
|
|
834
|
+
)
|
|
835
|
+
await rt.append_audio(bq)
|
|
836
|
+
vad_enabled_value = bool(vad) if vad is not None else False
|
|
837
|
+
if not vad_enabled_value:
|
|
838
|
+
await rt.commit_input()
|
|
839
|
+
# Request only text response
|
|
840
|
+
await rt.create_response({"modalities": ["text"]})
|
|
841
|
+
else:
|
|
842
|
+
logger.debug(
|
|
843
|
+
"Realtime: VAD enabled (text-only output) — skipping manual response.create"
|
|
844
|
+
)
|
|
845
|
+
if is_audio_bytes and wants_audio:
|
|
846
|
+
bq = bytes(query)
|
|
847
|
+
logger.info(
|
|
848
|
+
"Realtime: appending input audio to WS via FFmpeg, len=%d, fmt=%s",
|
|
849
|
+
len(bq),
|
|
850
|
+
audio_input_format,
|
|
851
|
+
)
|
|
852
|
+
await rt.append_audio(bq)
|
|
853
|
+
vad_enabled_value = bool(vad) if vad is not None else False
|
|
854
|
+
if not vad_enabled_value:
|
|
855
|
+
await rt.commit_input()
|
|
856
|
+
# Manually trigger response when VAD is disabled
|
|
857
|
+
await rt.create_response({})
|
|
858
|
+
else:
|
|
859
|
+
# With server VAD enabled, the model will auto-create a response at end of speech
|
|
860
|
+
logger.debug(
|
|
861
|
+
"Realtime: VAD enabled — skipping manual response.create"
|
|
862
|
+
)
|
|
863
|
+
else: # Text-only path OR caller excluded audio modality
|
|
864
|
+
# For text input, create conversation item first, then response
|
|
865
|
+
await rt.create_conversation_item(
|
|
866
|
+
{
|
|
867
|
+
"type": "message",
|
|
868
|
+
"role": "user",
|
|
869
|
+
"content": [
|
|
870
|
+
{"type": "input_text", "text": user_text or ""}
|
|
871
|
+
],
|
|
872
|
+
}
|
|
873
|
+
)
|
|
874
|
+
# Determine effective modalities (fall back to provided override or text only)
|
|
875
|
+
if rt_output_modalities is not None:
|
|
876
|
+
modalities = rt_output_modalities or ["text"]
|
|
877
|
+
else:
|
|
878
|
+
mo = getattr(
|
|
879
|
+
rt, "_options", RealtimeSessionOptions()
|
|
880
|
+
).output_modalities
|
|
881
|
+
modalities = mo if mo else ["audio"]
|
|
882
|
+
if "audio" not in modalities:
|
|
883
|
+
# Ensure we do not accidentally request audio generation
|
|
884
|
+
modalities = [m for m in modalities if m == "text"] or [
|
|
885
|
+
"text"
|
|
886
|
+
]
|
|
887
|
+
await rt.create_response(
|
|
888
|
+
{
|
|
889
|
+
"modalities": modalities,
|
|
890
|
+
}
|
|
891
|
+
)
|
|
892
|
+
|
|
893
|
+
# Collect audio and transcripts
|
|
894
|
+
user_tr = "" # Accumulates realtime input transcript segments (audio path)
|
|
895
|
+
asst_tr = ""
|
|
896
|
+
|
|
897
|
+
input_segments: List[str] = []
|
|
898
|
+
|
|
899
|
+
async def _drain_in_tr():
|
|
900
|
+
"""Accumulate realtime input transcript segments, de-duplicating cumulative repeats.
|
|
901
|
+
|
|
902
|
+
Some realtime providers emit growing cumulative transcripts (e.g. "Hel", "Hello") or
|
|
903
|
+
may occasionally resend the full final transcript. Previous logic naively concatenated
|
|
904
|
+
every segment which could yield duplicated text ("HelloHello") if cumulative or repeated
|
|
905
|
+
finals were received. This routine keeps a canonical buffer (user_tr) and only appends
|
|
906
|
+
the non-overlapping suffix of each new segment.
|
|
907
|
+
"""
|
|
908
|
+
nonlocal user_tr
|
|
909
|
+
async for t in rt.iter_input_transcript():
|
|
910
|
+
if not t:
|
|
911
|
+
continue
|
|
912
|
+
# Track raw segment for optional debugging
|
|
913
|
+
input_segments.append(t)
|
|
914
|
+
if not user_tr:
|
|
915
|
+
user_tr = t
|
|
916
|
+
continue
|
|
917
|
+
if t == user_tr:
|
|
918
|
+
# Exact duplicate of current buffer; skip
|
|
919
|
+
continue
|
|
920
|
+
if t.startswith(user_tr):
|
|
921
|
+
# Cumulative growth; append only the new suffix
|
|
922
|
+
user_tr += t[len(user_tr) :]
|
|
923
|
+
continue
|
|
924
|
+
# General case: find largest overlap between end of user_tr and start of t
|
|
925
|
+
# to avoid duplicated middle content (e.g., user_tr="My name is", t="name is John")
|
|
926
|
+
overlap = 0
|
|
927
|
+
max_check = min(len(user_tr), len(t))
|
|
928
|
+
for k in range(max_check, 0, -1):
|
|
929
|
+
if user_tr.endswith(t[:k]):
|
|
930
|
+
overlap = k
|
|
931
|
+
break
|
|
932
|
+
user_tr += t[overlap:]
|
|
933
|
+
|
|
934
|
+
# Check if we need both audio and text modalities
|
|
935
|
+
modalities = getattr(
|
|
936
|
+
rt, "_options", RealtimeSessionOptions()
|
|
937
|
+
).output_modalities or ["audio"]
|
|
938
|
+
use_combined_stream = "audio" in modalities and "text" in modalities
|
|
939
|
+
|
|
940
|
+
if use_combined_stream and wants_audio:
|
|
941
|
+
# Use combined stream for both modalities
|
|
942
|
+
async def _drain_out_tr():
|
|
943
|
+
nonlocal asst_tr
|
|
944
|
+
async for t in rt.iter_output_transcript():
|
|
945
|
+
if t:
|
|
946
|
+
asst_tr += t
|
|
947
|
+
|
|
948
|
+
in_task = asyncio.create_task(_drain_in_tr())
|
|
949
|
+
out_task = asyncio.create_task(_drain_out_tr())
|
|
950
|
+
try:
|
|
951
|
+
# Check if the service has iter_output_combined method
|
|
952
|
+
if hasattr(rt, "iter_output_combined"):
|
|
953
|
+
async for chunk in rt.iter_output_combined():
|
|
954
|
+
# Adapt output based on caller's requested output_format
|
|
955
|
+
if output_format == "text":
|
|
956
|
+
# Only yield text modalities as plain strings
|
|
957
|
+
if getattr(chunk, "modality", None) == "text":
|
|
958
|
+
yield chunk.data # type: ignore[attr-defined]
|
|
959
|
+
continue
|
|
960
|
+
# Audio streaming path
|
|
961
|
+
if getattr(chunk, "modality", None) == "audio":
|
|
962
|
+
# Yield raw bytes if data present
|
|
963
|
+
yield getattr(chunk, "data", b"")
|
|
964
|
+
elif (
|
|
965
|
+
getattr(chunk, "modality", None) == "text"
|
|
966
|
+
and output_format == "audio"
|
|
967
|
+
):
|
|
968
|
+
# Optionally ignore or log text while audio requested
|
|
969
|
+
continue
|
|
970
|
+
else:
|
|
971
|
+
# Fallback: ignore unknown modalities for now
|
|
972
|
+
continue
|
|
973
|
+
else:
|
|
974
|
+
# Fallback: yield audio chunks as RealtimeChunk objects
|
|
975
|
+
async for audio_chunk in rt.iter_output_audio_encoded():
|
|
976
|
+
if output_format == "text":
|
|
977
|
+
# Ignore audio when text requested
|
|
978
|
+
continue
|
|
979
|
+
# output_format audio: provide raw bytes
|
|
980
|
+
if hasattr(audio_chunk, "modality"):
|
|
981
|
+
if (
|
|
982
|
+
getattr(audio_chunk, "modality", None)
|
|
983
|
+
== "audio"
|
|
984
|
+
):
|
|
985
|
+
yield getattr(audio_chunk, "data", b"")
|
|
986
|
+
else:
|
|
987
|
+
yield audio_chunk
|
|
988
|
+
finally:
|
|
989
|
+
# Allow transcript drain tasks to finish to capture user/asst text before persistence
|
|
990
|
+
try:
|
|
991
|
+
await asyncio.wait_for(in_task, timeout=0.05)
|
|
992
|
+
except Exception:
|
|
993
|
+
in_task.cancel()
|
|
994
|
+
try:
|
|
995
|
+
await asyncio.wait_for(out_task, timeout=0.05)
|
|
996
|
+
except Exception:
|
|
997
|
+
out_task.cancel()
|
|
998
|
+
# HTTP STT path removed: realtime audio input transcript (if any) is authoritative
|
|
999
|
+
# Persist transcripts after combined streaming completes
|
|
1000
|
+
if turn_id:
|
|
1001
|
+
try:
|
|
1002
|
+
effective_user_tr = user_tr or ("".join(input_segments))
|
|
1003
|
+
try:
|
|
1004
|
+
setattr(
|
|
1005
|
+
self,
|
|
1006
|
+
"_last_realtime_user_transcript",
|
|
1007
|
+
effective_user_tr,
|
|
1008
|
+
)
|
|
1009
|
+
except Exception:
|
|
1010
|
+
pass
|
|
1011
|
+
if effective_user_tr:
|
|
1012
|
+
final_user_tr = effective_user_tr
|
|
1013
|
+
elif (
|
|
1014
|
+
isinstance(query, str)
|
|
1015
|
+
and query
|
|
1016
|
+
and not input_segments
|
|
1017
|
+
and not user_tr
|
|
1018
|
+
):
|
|
1019
|
+
final_user_tr = query
|
|
1020
|
+
if asst_tr:
|
|
1021
|
+
await self.realtime_update_assistant(
|
|
1022
|
+
user_id, turn_id, asst_tr
|
|
1023
|
+
)
|
|
1024
|
+
except Exception:
|
|
1025
|
+
pass
|
|
1026
|
+
if final_user_tr and not user_persisted:
|
|
1027
|
+
try:
|
|
1028
|
+
await self.realtime_update_user(
|
|
1029
|
+
user_id, turn_id, final_user_tr
|
|
1030
|
+
)
|
|
1031
|
+
user_persisted = True
|
|
1032
|
+
except Exception:
|
|
1033
|
+
pass
|
|
1034
|
+
try:
|
|
1035
|
+
await self.realtime_finalize_turn(user_id, turn_id)
|
|
1036
|
+
except Exception:
|
|
1037
|
+
pass
|
|
1038
|
+
if final_user_tr and not user_persisted:
|
|
1039
|
+
try:
|
|
1040
|
+
await self.realtime_update_user(
|
|
1041
|
+
user_id, turn_id, final_user_tr
|
|
1042
|
+
)
|
|
1043
|
+
user_persisted = True
|
|
1044
|
+
except Exception:
|
|
1045
|
+
pass
|
|
1046
|
+
elif wants_audio:
|
|
1047
|
+
# Use separate streams (legacy behavior)
|
|
1048
|
+
async def _drain_out_tr():
|
|
1049
|
+
nonlocal asst_tr
|
|
1050
|
+
async for t in rt.iter_output_transcript():
|
|
1051
|
+
if t:
|
|
1052
|
+
asst_tr += t
|
|
1053
|
+
|
|
1054
|
+
in_task = asyncio.create_task(_drain_in_tr())
|
|
1055
|
+
out_task = asyncio.create_task(_drain_out_tr())
|
|
1056
|
+
try:
|
|
1057
|
+
async for audio_chunk in rt.iter_output_audio_encoded():
|
|
1058
|
+
if output_format == "text":
|
|
1059
|
+
# Skip audio when caller wants text only
|
|
1060
|
+
continue
|
|
1061
|
+
# output_format audio: yield raw bytes
|
|
1062
|
+
if hasattr(audio_chunk, "modality"):
|
|
1063
|
+
if (
|
|
1064
|
+
getattr(audio_chunk, "modality", None)
|
|
1065
|
+
== "audio"
|
|
1066
|
+
):
|
|
1067
|
+
yield getattr(audio_chunk, "data", b"")
|
|
1068
|
+
else:
|
|
1069
|
+
yield audio_chunk
|
|
1070
|
+
finally:
|
|
1071
|
+
try:
|
|
1072
|
+
await asyncio.wait_for(in_task, timeout=0.05)
|
|
1073
|
+
except Exception:
|
|
1074
|
+
in_task.cancel()
|
|
1075
|
+
try:
|
|
1076
|
+
await asyncio.wait_for(out_task, timeout=0.05)
|
|
1077
|
+
except Exception:
|
|
1078
|
+
out_task.cancel()
|
|
1079
|
+
# HTTP STT path removed
|
|
1080
|
+
# Persist transcripts after audio-only streaming
|
|
1081
|
+
if turn_id:
|
|
1082
|
+
try:
|
|
1083
|
+
effective_user_tr = user_tr or ("".join(input_segments))
|
|
1084
|
+
try:
|
|
1085
|
+
setattr(
|
|
1086
|
+
self,
|
|
1087
|
+
"_last_realtime_user_transcript",
|
|
1088
|
+
effective_user_tr,
|
|
1089
|
+
)
|
|
1090
|
+
except Exception:
|
|
1091
|
+
pass
|
|
1092
|
+
# Buffer final transcript for single persistence
|
|
1093
|
+
if effective_user_tr:
|
|
1094
|
+
final_user_tr = effective_user_tr
|
|
1095
|
+
elif (
|
|
1096
|
+
isinstance(query, str)
|
|
1097
|
+
and query
|
|
1098
|
+
and not input_segments
|
|
1099
|
+
and not user_tr
|
|
1100
|
+
):
|
|
1101
|
+
final_user_tr = query
|
|
1102
|
+
if asst_tr:
|
|
1103
|
+
await self.realtime_update_assistant(
|
|
1104
|
+
user_id, turn_id, asst_tr
|
|
1105
|
+
)
|
|
1106
|
+
except Exception:
|
|
1107
|
+
pass
|
|
1108
|
+
if final_user_tr and not user_persisted:
|
|
1109
|
+
try:
|
|
1110
|
+
await self.realtime_update_user(
|
|
1111
|
+
user_id, turn_id, final_user_tr
|
|
1112
|
+
)
|
|
1113
|
+
user_persisted = True
|
|
1114
|
+
except Exception:
|
|
1115
|
+
pass
|
|
1116
|
+
try:
|
|
1117
|
+
await self.realtime_finalize_turn(user_id, turn_id)
|
|
1118
|
+
except Exception:
|
|
1119
|
+
pass
|
|
1120
|
+
# If no WS input transcript was captured, fall back to HTTP STT result
|
|
1121
|
+
else:
|
|
1122
|
+
# Text-only: just stream assistant transcript if available (no audio iteration)
|
|
1123
|
+
# If original input was audio bytes but caller only wants text output (no audio modality),
|
|
1124
|
+
# we still need to drain the input transcript stream to build user_tr.
|
|
1125
|
+
in_task_audio_only = None
|
|
1126
|
+
if is_audio_bytes:
|
|
1127
|
+
in_task_audio_only = asyncio.create_task(_drain_in_tr())
|
|
1128
|
+
|
|
1129
|
+
async def _drain_out_tr_text():
|
|
1130
|
+
nonlocal asst_tr
|
|
1131
|
+
async for t in rt.iter_output_transcript():
|
|
1132
|
+
if t:
|
|
1133
|
+
asst_tr += t
|
|
1134
|
+
yield t # Yield incremental text chunks directly
|
|
1135
|
+
|
|
1136
|
+
async for t in _drain_out_tr_text():
|
|
1137
|
+
# Provide plain text to caller
|
|
1138
|
+
yield t
|
|
1139
|
+
# Wait for input transcript (if any) before persistence
|
|
1140
|
+
if "in_task_audio_only" in locals() and in_task_audio_only:
|
|
1141
|
+
try:
|
|
1142
|
+
await asyncio.wait_for(in_task_audio_only, timeout=0.1)
|
|
1143
|
+
except Exception:
|
|
1144
|
+
in_task_audio_only.cancel()
|
|
1145
|
+
# No HTTP STT fallback
|
|
1146
|
+
if turn_id:
|
|
1147
|
+
try:
|
|
1148
|
+
effective_user_tr = user_tr or ("".join(input_segments))
|
|
1149
|
+
try:
|
|
1150
|
+
setattr(
|
|
1151
|
+
self,
|
|
1152
|
+
"_last_realtime_user_transcript",
|
|
1153
|
+
effective_user_tr,
|
|
1154
|
+
)
|
|
1155
|
+
except Exception:
|
|
1156
|
+
pass
|
|
1157
|
+
# For text-only modality but audio-origin (cumulative segments captured), persist user transcript
|
|
1158
|
+
if effective_user_tr:
|
|
1159
|
+
final_user_tr = effective_user_tr
|
|
1160
|
+
elif (
|
|
1161
|
+
isinstance(query, str)
|
|
1162
|
+
and query
|
|
1163
|
+
and not input_segments
|
|
1164
|
+
and not user_tr
|
|
1165
|
+
):
|
|
1166
|
+
final_user_tr = query
|
|
1167
|
+
if asst_tr:
|
|
1168
|
+
await self.realtime_update_assistant(
|
|
1169
|
+
user_id, turn_id, asst_tr
|
|
1170
|
+
)
|
|
1171
|
+
except Exception:
|
|
1172
|
+
pass
|
|
1173
|
+
if final_user_tr and not user_persisted:
|
|
1174
|
+
try:
|
|
1175
|
+
await self.realtime_update_user(
|
|
1176
|
+
user_id, turn_id, final_user_tr
|
|
1177
|
+
)
|
|
1178
|
+
user_persisted = True
|
|
1179
|
+
except Exception:
|
|
1180
|
+
pass
|
|
1181
|
+
try:
|
|
1182
|
+
await self.realtime_finalize_turn(user_id, turn_id)
|
|
1183
|
+
except Exception:
|
|
1184
|
+
pass
|
|
1185
|
+
# Input transcript task already awaited above
|
|
1186
|
+
# Clear input buffer for next turn reuse
|
|
1187
|
+
try:
|
|
1188
|
+
await rt.clear_input()
|
|
1189
|
+
except Exception:
|
|
1190
|
+
pass
|
|
1191
|
+
finally:
|
|
1192
|
+
# Always release the session for reuse by other concurrent requests/devices
|
|
1193
|
+
try:
|
|
1194
|
+
lock = getattr(rt, "_in_use_lock", None)
|
|
1195
|
+
if lock and lock.locked():
|
|
1196
|
+
lock.release()
|
|
1197
|
+
except Exception:
|
|
1198
|
+
pass
|
|
1199
|
+
return
|
|
1200
|
+
|
|
1201
|
+
# 1) Acquire user_text (transcribe audio or direct text) for non-realtime path
|
|
71
1202
|
user_text = ""
|
|
72
1203
|
if not isinstance(query, str):
|
|
73
|
-
|
|
74
|
-
|
|
1204
|
+
try:
|
|
1205
|
+
logger.info(
|
|
1206
|
+
f"Received audio input, transcribing format: {audio_input_format}"
|
|
1207
|
+
)
|
|
1208
|
+
async for tpart in self.agent_service.llm_provider.transcribe_audio( # type: ignore[attr-defined]
|
|
1209
|
+
query, audio_input_format
|
|
1210
|
+
):
|
|
1211
|
+
user_text += tpart
|
|
1212
|
+
except Exception:
|
|
1213
|
+
user_text = ""
|
|
75
1214
|
else:
|
|
76
1215
|
user_text = query
|
|
77
1216
|
|
|
78
|
-
#
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
voice=audio_voice,
|
|
85
|
-
response_format=audio_output_format
|
|
86
|
-
):
|
|
87
|
-
yield chunk
|
|
88
|
-
else:
|
|
89
|
-
yield response
|
|
90
|
-
|
|
91
|
-
# Store simple interaction in memory
|
|
92
|
-
if self.memory_provider:
|
|
93
|
-
await self._store_conversation(user_id, user_text, response)
|
|
94
|
-
return
|
|
1217
|
+
# 2) Input guardrails
|
|
1218
|
+
for guardrail in self.input_guardrails:
|
|
1219
|
+
try:
|
|
1220
|
+
user_text = await guardrail.process(user_text)
|
|
1221
|
+
except Exception as e:
|
|
1222
|
+
logger.debug(f"Guardrail error: {e}")
|
|
95
1223
|
|
|
96
|
-
#
|
|
1224
|
+
# 3) Memory context (conversation history)
|
|
97
1225
|
memory_context = ""
|
|
98
1226
|
if self.memory_provider:
|
|
99
|
-
|
|
1227
|
+
try:
|
|
1228
|
+
memory_context = await self.memory_provider.retrieve(user_id)
|
|
1229
|
+
except Exception:
|
|
1230
|
+
memory_context = ""
|
|
1231
|
+
|
|
1232
|
+
# 4) Knowledge base context
|
|
1233
|
+
kb_context = ""
|
|
1234
|
+
if self.knowledge_base:
|
|
1235
|
+
try:
|
|
1236
|
+
kb_results = await self.knowledge_base.query(
|
|
1237
|
+
query_text=user_text,
|
|
1238
|
+
top_k=self.kb_results_count,
|
|
1239
|
+
include_content=True,
|
|
1240
|
+
include_metadata=False,
|
|
1241
|
+
)
|
|
1242
|
+
if kb_results:
|
|
1243
|
+
kb_lines = [
|
|
1244
|
+
"**KNOWLEDGE BASE (CRITICAL: MAKE THIS INFORMATION THE TOP PRIORITY):**"
|
|
1245
|
+
]
|
|
1246
|
+
for i, r in enumerate(kb_results, 1):
|
|
1247
|
+
kb_lines.append(f"[{i}] {r.get('content', '').strip()}\n")
|
|
1248
|
+
kb_context = "\n".join(kb_lines)
|
|
1249
|
+
except Exception:
|
|
1250
|
+
kb_context = ""
|
|
100
1251
|
|
|
101
|
-
#
|
|
102
|
-
|
|
103
|
-
|
|
1252
|
+
# 5) Determine agent (sticky session aware; allow explicit switch/new conversation)
|
|
1253
|
+
agent_name = "default"
|
|
1254
|
+
prev_assistant = ""
|
|
1255
|
+
routing_input = user_text
|
|
1256
|
+
if self.memory_provider:
|
|
1257
|
+
try:
|
|
1258
|
+
prev_docs = self.memory_provider.find(
|
|
1259
|
+
collection="conversations",
|
|
1260
|
+
query={"user_id": user_id},
|
|
1261
|
+
sort=[("timestamp", -1)],
|
|
1262
|
+
limit=1,
|
|
1263
|
+
)
|
|
1264
|
+
if prev_docs:
|
|
1265
|
+
prev_user_msg = (prev_docs[0] or {}).get(
|
|
1266
|
+
"user_message", ""
|
|
1267
|
+
) or ""
|
|
1268
|
+
prev_assistant = (prev_docs[0] or {}).get(
|
|
1269
|
+
"assistant_message", ""
|
|
1270
|
+
) or ""
|
|
1271
|
+
if prev_user_msg:
|
|
1272
|
+
routing_input = f"previous_user_message: {prev_user_msg}\ncurrent_user_message: {user_text}"
|
|
1273
|
+
except Exception:
|
|
1274
|
+
pass
|
|
1275
|
+
|
|
1276
|
+
# Get available agents first so the LLM can select a valid target
|
|
1277
|
+
agents = self.agent_service.get_all_ai_agents() or {}
|
|
1278
|
+
available_agent_names = list(agents.keys())
|
|
1279
|
+
|
|
1280
|
+
# Fast path: if only one agent, skip all routing logic entirely
|
|
1281
|
+
if len(available_agent_names) == 1:
|
|
1282
|
+
agent_name = available_agent_names[0]
|
|
1283
|
+
self._set_sticky_agent(user_id, agent_name, required_complete=False)
|
|
104
1284
|
else:
|
|
105
|
-
|
|
1285
|
+
# LLM detects switch intent (only needed with multiple agents)
|
|
1286
|
+
(
|
|
1287
|
+
switch_requested,
|
|
1288
|
+
requested_agent_raw,
|
|
1289
|
+
start_new,
|
|
1290
|
+
) = await self._detect_switch_intent(user_text, available_agent_names)
|
|
1291
|
+
|
|
1292
|
+
# Normalize requested agent to an exact available key
|
|
1293
|
+
requested_agent = None
|
|
1294
|
+
if requested_agent_raw:
|
|
1295
|
+
raw_lower = requested_agent_raw.lower()
|
|
1296
|
+
for a in available_agent_names:
|
|
1297
|
+
if a.lower() == raw_lower or raw_lower in a.lower():
|
|
1298
|
+
requested_agent = a
|
|
1299
|
+
break
|
|
1300
|
+
|
|
1301
|
+
sticky_agent = self._get_sticky_agent(user_id)
|
|
1302
|
+
|
|
1303
|
+
if sticky_agent and not switch_requested:
|
|
1304
|
+
agent_name = sticky_agent
|
|
1305
|
+
else:
|
|
1306
|
+
try:
|
|
1307
|
+
if start_new:
|
|
1308
|
+
# Start fresh
|
|
1309
|
+
self._clear_sticky_agent(user_id)
|
|
1310
|
+
if requested_agent:
|
|
1311
|
+
agent_name = requested_agent
|
|
1312
|
+
else:
|
|
1313
|
+
# Route if no explicit target
|
|
1314
|
+
if router:
|
|
1315
|
+
agent_name = await router.route_query(routing_input)
|
|
1316
|
+
else:
|
|
1317
|
+
agent_name = await self.routing_service.route_query(
|
|
1318
|
+
routing_input
|
|
1319
|
+
)
|
|
1320
|
+
except Exception:
|
|
1321
|
+
agent_name = next(iter(agents.keys())) if agents else "default"
|
|
1322
|
+
self._set_sticky_agent(user_id, agent_name, required_complete=False)
|
|
1323
|
+
|
|
1324
|
+
# 7) Captured data context + incremental save using previous assistant message
|
|
1325
|
+
capture_context = ""
|
|
1326
|
+
# Two completion flags:
|
|
1327
|
+
required_complete = False
|
|
1328
|
+
form_complete = False # required + optional
|
|
1329
|
+
|
|
1330
|
+
# Helpers
|
|
1331
|
+
def _non_empty(v: Any) -> bool:
|
|
1332
|
+
if v is None:
|
|
1333
|
+
return False
|
|
1334
|
+
if isinstance(v, str):
|
|
1335
|
+
s = v.strip().lower()
|
|
1336
|
+
return s not in {"", "null", "none", "n/a", "na", "undefined", "."}
|
|
1337
|
+
if isinstance(v, (list, dict, tuple, set)):
|
|
1338
|
+
return len(v) > 0
|
|
1339
|
+
return True
|
|
1340
|
+
|
|
1341
|
+
def _parse_numbers_list(s: str) -> List[str]:
|
|
1342
|
+
nums = re.findall(r"\b(\d+)\b", s)
|
|
1343
|
+
seen, out = set(), []
|
|
1344
|
+
for n in nums:
|
|
1345
|
+
if n not in seen:
|
|
1346
|
+
seen.add(n)
|
|
1347
|
+
out.append(n)
|
|
1348
|
+
return out
|
|
1349
|
+
|
|
1350
|
+
def _extract_numbered_options(text: str) -> Dict[str, str]:
|
|
1351
|
+
"""Parse previous assistant message for lines like:
|
|
1352
|
+
'1) Foo', '1. Foo', '- 1) Foo', '* 1. Foo' -> {'1': 'Foo'}"""
|
|
1353
|
+
options: Dict[str, str] = {}
|
|
1354
|
+
if not text:
|
|
1355
|
+
return options
|
|
1356
|
+
for raw in text.splitlines():
|
|
1357
|
+
line = raw.strip()
|
|
1358
|
+
if not line:
|
|
1359
|
+
continue
|
|
1360
|
+
m = re.match(r"^(?:[-*]\s*)?(\d+)[\.)]?\s+(.*)$", line)
|
|
1361
|
+
if m:
|
|
1362
|
+
idx, label = m.group(1), m.group(2).strip().rstrip()
|
|
1363
|
+
if len(label) >= 1:
|
|
1364
|
+
options[idx] = label
|
|
1365
|
+
return options
|
|
1366
|
+
|
|
1367
|
+
# LLM-backed field detection (gpt-4.1-mini) with graceful fallbacks
|
|
1368
|
+
class _FieldDetect(BaseModel):
|
|
1369
|
+
field: Optional[str] = None
|
|
1370
|
+
|
|
1371
|
+
async def _detect_field_from_prev_question(
|
|
1372
|
+
prev_text: str, schema: Optional[Dict[str, Any]]
|
|
1373
|
+
) -> Optional[str]:
|
|
1374
|
+
if not prev_text or not isinstance(schema, dict):
|
|
1375
|
+
return None
|
|
1376
|
+
props = list((schema.get("properties") or {}).keys())
|
|
1377
|
+
if not props:
|
|
1378
|
+
return None
|
|
1379
|
+
|
|
1380
|
+
question = prev_text.strip()
|
|
1381
|
+
instruction = (
|
|
1382
|
+
"You are a strict classifier. Given the assistant's last question and a list of "
|
|
1383
|
+
"permitted schema field keys, choose exactly one key that the question is asking the user to answer. "
|
|
1384
|
+
"If none apply, return null."
|
|
1385
|
+
)
|
|
1386
|
+
user_prompt = (
|
|
1387
|
+
f"Schema field keys (choose exactly one of these): {props}\n"
|
|
1388
|
+
f"Assistant question:\n{question}\n\n"
|
|
1389
|
+
'Return strictly JSON like: {"field": "<one_of_the_keys_or_null>"}'
|
|
1390
|
+
)
|
|
1391
|
+
|
|
1392
|
+
# Try llm_provider.parse_structured_output with mini
|
|
1393
|
+
try:
|
|
1394
|
+
if hasattr(
|
|
1395
|
+
self.agent_service.llm_provider, "parse_structured_output"
|
|
1396
|
+
):
|
|
1397
|
+
try:
|
|
1398
|
+
result = await self.agent_service.llm_provider.parse_structured_output(
|
|
1399
|
+
prompt=user_prompt,
|
|
1400
|
+
system_prompt=instruction,
|
|
1401
|
+
model_class=_FieldDetect,
|
|
1402
|
+
model="gpt-4.1-mini",
|
|
1403
|
+
)
|
|
1404
|
+
except TypeError:
|
|
1405
|
+
# Provider may not accept 'model' kwarg
|
|
1406
|
+
result = await self.agent_service.llm_provider.parse_structured_output(
|
|
1407
|
+
prompt=user_prompt,
|
|
1408
|
+
system_prompt=instruction,
|
|
1409
|
+
model_class=_FieldDetect,
|
|
1410
|
+
)
|
|
1411
|
+
sel = None
|
|
1412
|
+
try:
|
|
1413
|
+
sel = getattr(result, "field", None)
|
|
1414
|
+
except Exception:
|
|
1415
|
+
sel = None
|
|
1416
|
+
if sel is None:
|
|
1417
|
+
try:
|
|
1418
|
+
d = result.model_dump()
|
|
1419
|
+
sel = d.get("field")
|
|
1420
|
+
except Exception:
|
|
1421
|
+
sel = None
|
|
1422
|
+
if sel in props:
|
|
1423
|
+
return sel
|
|
1424
|
+
except Exception as e:
|
|
1425
|
+
logger.debug(
|
|
1426
|
+
f"LLM parse_structured_output field detection failed: {e}"
|
|
1427
|
+
)
|
|
1428
|
+
|
|
1429
|
+
# Fallback: use generate_response with output_model=_FieldDetect
|
|
1430
|
+
try:
|
|
1431
|
+
async for r in self.agent_service.generate_response(
|
|
1432
|
+
agent_name=agent_name,
|
|
1433
|
+
user_id=user_id,
|
|
1434
|
+
query=user_text,
|
|
1435
|
+
images=images,
|
|
1436
|
+
memory_context="",
|
|
1437
|
+
output_format="text",
|
|
1438
|
+
prompt=f"{instruction}\n\n{user_prompt}",
|
|
1439
|
+
output_model=_FieldDetect,
|
|
1440
|
+
):
|
|
1441
|
+
fd = r
|
|
1442
|
+
sel = None
|
|
1443
|
+
try:
|
|
1444
|
+
sel = fd.field # type: ignore[attr-defined]
|
|
1445
|
+
except Exception:
|
|
1446
|
+
try:
|
|
1447
|
+
d = fd.model_dump()
|
|
1448
|
+
sel = d.get("field")
|
|
1449
|
+
except Exception:
|
|
1450
|
+
sel = None
|
|
1451
|
+
if sel in props:
|
|
1452
|
+
return sel
|
|
1453
|
+
break
|
|
1454
|
+
except Exception as e:
|
|
1455
|
+
logger.debug(f"LLM generate_response field detection failed: {e}")
|
|
1456
|
+
|
|
1457
|
+
# Final heuristic fallback (keeps system working if LLM unavailable)
|
|
1458
|
+
t = question.lower()
|
|
1459
|
+
for key in props:
|
|
1460
|
+
if key in t:
|
|
1461
|
+
return key
|
|
1462
|
+
return None
|
|
1463
|
+
|
|
1464
|
+
# Resolve active capture from args or agent config
|
|
1465
|
+
active_capture_name = capture_name
|
|
1466
|
+
active_capture_schema = capture_schema
|
|
1467
|
+
if not active_capture_name or not active_capture_schema:
|
|
1468
|
+
try:
|
|
1469
|
+
cap_cfg = self.agent_service.get_agent_capture(agent_name)
|
|
1470
|
+
if cap_cfg:
|
|
1471
|
+
active_capture_name = active_capture_name or cap_cfg.get("name")
|
|
1472
|
+
active_capture_schema = active_capture_schema or cap_cfg.get(
|
|
1473
|
+
"schema"
|
|
1474
|
+
)
|
|
1475
|
+
except Exception:
|
|
1476
|
+
pass
|
|
1477
|
+
|
|
1478
|
+
latest_by_name: Dict[str, Dict[str, Any]] = {}
|
|
1479
|
+
if self.memory_provider:
|
|
1480
|
+
try:
|
|
1481
|
+
docs = self.memory_provider.find(
|
|
1482
|
+
collection="captures",
|
|
1483
|
+
query={"user_id": user_id},
|
|
1484
|
+
sort=[("timestamp", -1)],
|
|
1485
|
+
limit=100,
|
|
1486
|
+
)
|
|
1487
|
+
for d in docs or []:
|
|
1488
|
+
name = (d or {}).get("capture_name")
|
|
1489
|
+
if not name or name in latest_by_name:
|
|
1490
|
+
continue
|
|
1491
|
+
latest_by_name[name] = {
|
|
1492
|
+
"data": (d or {}).get("data", {}) or {},
|
|
1493
|
+
"mode": (d or {}).get("mode", "once"),
|
|
1494
|
+
"agent": (d or {}).get("agent_name"),
|
|
1495
|
+
}
|
|
1496
|
+
except Exception:
|
|
1497
|
+
pass
|
|
1498
|
+
|
|
1499
|
+
# Incremental save: use prev_assistant's numbered list to map numeric reply -> labels
|
|
1500
|
+
incremental: Dict[str, Any] = {}
|
|
1501
|
+
try:
|
|
1502
|
+
if (
|
|
1503
|
+
self.memory_provider
|
|
1504
|
+
and active_capture_name
|
|
1505
|
+
and isinstance(active_capture_schema, dict)
|
|
1506
|
+
):
|
|
1507
|
+
props = (active_capture_schema or {}).get("properties", {})
|
|
1508
|
+
required_fields = list(
|
|
1509
|
+
(active_capture_schema or {}).get("required", []) or []
|
|
1510
|
+
)
|
|
1511
|
+
all_fields = list(props.keys())
|
|
1512
|
+
optional_fields = [
|
|
1513
|
+
f for f in all_fields if f not in set(required_fields)
|
|
1514
|
+
]
|
|
1515
|
+
|
|
1516
|
+
active_data_existing = (
|
|
1517
|
+
latest_by_name.get(active_capture_name, {}) or {}
|
|
1518
|
+
).get("data", {}) or {}
|
|
1519
|
+
|
|
1520
|
+
def _missing(fields: List[str]) -> List[str]:
|
|
1521
|
+
return [
|
|
1522
|
+
f
|
|
1523
|
+
for f in fields
|
|
1524
|
+
if not _non_empty(active_data_existing.get(f))
|
|
1525
|
+
]
|
|
1526
|
+
|
|
1527
|
+
missing_required = _missing(required_fields)
|
|
1528
|
+
missing_optional = _missing(optional_fields)
|
|
1529
|
+
|
|
1530
|
+
target_field: Optional[
|
|
1531
|
+
str
|
|
1532
|
+
] = await _detect_field_from_prev_question(
|
|
1533
|
+
prev_assistant, active_capture_schema
|
|
1534
|
+
)
|
|
1535
|
+
if not target_field:
|
|
1536
|
+
# If exactly one required missing, target it; else if none required missing and exactly one optional missing, target it.
|
|
1537
|
+
if len(missing_required) == 1:
|
|
1538
|
+
target_field = missing_required[0]
|
|
1539
|
+
elif len(missing_required) == 0 and len(missing_optional) == 1:
|
|
1540
|
+
target_field = missing_optional[0]
|
|
1541
|
+
|
|
1542
|
+
if target_field and target_field in props:
|
|
1543
|
+
f_schema = props.get(target_field, {}) or {}
|
|
1544
|
+
f_type = f_schema.get("type")
|
|
1545
|
+
number_to_label = _extract_numbered_options(prev_assistant)
|
|
1546
|
+
|
|
1547
|
+
if number_to_label:
|
|
1548
|
+
nums = _parse_numbers_list(user_text)
|
|
1549
|
+
labels = [
|
|
1550
|
+
number_to_label[n] for n in nums if n in number_to_label
|
|
1551
|
+
]
|
|
1552
|
+
if labels:
|
|
1553
|
+
if f_type == "array":
|
|
1554
|
+
incremental[target_field] = labels
|
|
1555
|
+
else:
|
|
1556
|
+
incremental[target_field] = labels[0]
|
|
1557
|
+
|
|
1558
|
+
if target_field not in incremental:
|
|
1559
|
+
if f_type == "number":
|
|
1560
|
+
m = re.search(r"\b([0-9]+(?:\.[0-9]+)?)\b", user_text)
|
|
1561
|
+
if m:
|
|
1562
|
+
try:
|
|
1563
|
+
incremental[target_field] = float(m.group(1))
|
|
1564
|
+
except Exception:
|
|
1565
|
+
pass
|
|
1566
|
+
elif f_type == "array":
|
|
1567
|
+
parts = [
|
|
1568
|
+
p.strip()
|
|
1569
|
+
for p in re.split(r"[,\n;]+", user_text)
|
|
1570
|
+
if p.strip()
|
|
1571
|
+
]
|
|
1572
|
+
if parts:
|
|
1573
|
+
incremental[target_field] = parts
|
|
1574
|
+
else:
|
|
1575
|
+
if user_text.strip():
|
|
1576
|
+
incremental[target_field] = user_text.strip()
|
|
1577
|
+
|
|
1578
|
+
if incremental:
|
|
1579
|
+
cleaned = {
|
|
1580
|
+
k: v for k, v in incremental.items() if _non_empty(v)
|
|
1581
|
+
}
|
|
1582
|
+
if cleaned:
|
|
1583
|
+
try:
|
|
1584
|
+
await self.memory_provider.save_capture(
|
|
1585
|
+
user_id=user_id,
|
|
1586
|
+
capture_name=active_capture_name,
|
|
1587
|
+
agent_name=agent_name,
|
|
1588
|
+
data=cleaned,
|
|
1589
|
+
schema=active_capture_schema,
|
|
1590
|
+
)
|
|
1591
|
+
except Exception as se:
|
|
1592
|
+
logger.error(f"Error saving incremental capture: {se}")
|
|
1593
|
+
|
|
1594
|
+
except Exception as e:
|
|
1595
|
+
logger.debug(f"Incremental extraction skipped: {e}")
|
|
1596
|
+
|
|
1597
|
+
# Build capture context, merging in incremental immediately (avoid read lag)
|
|
1598
|
+
def _get_active_data(name: Optional[str]) -> Dict[str, Any]:
|
|
1599
|
+
if not name:
|
|
1600
|
+
return {}
|
|
1601
|
+
base = (latest_by_name.get(name, {}) or {}).get("data", {}) or {}
|
|
1602
|
+
if incremental:
|
|
1603
|
+
base = {**base, **incremental}
|
|
1604
|
+
return base
|
|
1605
|
+
|
|
1606
|
+
lines: List[str] = []
|
|
1607
|
+
if active_capture_name and isinstance(active_capture_schema, dict):
|
|
1608
|
+
props = (active_capture_schema or {}).get("properties", {})
|
|
1609
|
+
required_fields = list(
|
|
1610
|
+
(active_capture_schema or {}).get("required", []) or []
|
|
1611
|
+
)
|
|
1612
|
+
all_fields = list(props.keys())
|
|
1613
|
+
optional_fields = [
|
|
1614
|
+
f for f in all_fields if f not in set(required_fields)
|
|
1615
|
+
]
|
|
1616
|
+
|
|
1617
|
+
active_data = _get_active_data(active_capture_name)
|
|
1618
|
+
|
|
1619
|
+
def _missing_from(data: Dict[str, Any], fields: List[str]) -> List[str]:
|
|
1620
|
+
return [f for f in fields if not _non_empty(data.get(f))]
|
|
1621
|
+
|
|
1622
|
+
missing_required = _missing_from(active_data, required_fields)
|
|
1623
|
+
missing_optional = _missing_from(active_data, optional_fields)
|
|
1624
|
+
|
|
1625
|
+
required_complete = (
|
|
1626
|
+
len(missing_required) == 0 and len(required_fields) > 0
|
|
1627
|
+
)
|
|
1628
|
+
form_complete = required_complete and len(missing_optional) == 0
|
|
1629
|
+
|
|
1630
|
+
lines.append(
|
|
1631
|
+
"CAPTURED FORM STATE (Authoritative; do not re-ask filled values):"
|
|
1632
|
+
)
|
|
1633
|
+
lines.append(f"- form_name: {active_capture_name}")
|
|
1634
|
+
|
|
1635
|
+
if active_data:
|
|
1636
|
+
pairs = [
|
|
1637
|
+
f"{k}: {v}" for k, v in active_data.items() if _non_empty(v)
|
|
1638
|
+
]
|
|
1639
|
+
lines.append(
|
|
1640
|
+
f"- filled_fields: {', '.join(pairs) if pairs else '(none)'}"
|
|
1641
|
+
)
|
|
1642
|
+
else:
|
|
1643
|
+
lines.append("- filled_fields: (none)")
|
|
1644
|
+
|
|
1645
|
+
lines.append(
|
|
1646
|
+
f"- missing_required_fields: {', '.join(missing_required) if missing_required else '(none)'}"
|
|
1647
|
+
)
|
|
1648
|
+
lines.append(
|
|
1649
|
+
f"- missing_optional_fields: {', '.join(missing_optional) if missing_optional else '(none)'}"
|
|
1650
|
+
)
|
|
1651
|
+
lines.append("")
|
|
106
1652
|
|
|
107
|
-
|
|
1653
|
+
if latest_by_name:
|
|
1654
|
+
lines.append("OTHER CAPTURED USER DATA (for reference):")
|
|
1655
|
+
for cname, info in latest_by_name.items():
|
|
1656
|
+
if cname == active_capture_name:
|
|
1657
|
+
continue
|
|
1658
|
+
data = info.get("data", {}) or {}
|
|
1659
|
+
if data:
|
|
1660
|
+
pairs = [f"{k}: {v}" for k, v in data.items() if _non_empty(v)]
|
|
1661
|
+
lines.append(
|
|
1662
|
+
f"- {cname}: {', '.join(pairs) if pairs else '(none)'}"
|
|
1663
|
+
)
|
|
1664
|
+
else:
|
|
1665
|
+
lines.append(f"- {cname}: (none)")
|
|
1666
|
+
|
|
1667
|
+
if lines:
|
|
1668
|
+
capture_context = "\n".join(lines) + "\n\n"
|
|
1669
|
+
# Update sticky session completion flag
|
|
1670
|
+
try:
|
|
1671
|
+
self._update_sticky_required_complete(user_id, required_complete)
|
|
1672
|
+
except Exception:
|
|
1673
|
+
pass
|
|
1674
|
+
|
|
1675
|
+
# Merge contexts + flow rules
|
|
1676
|
+
combined_context = ""
|
|
1677
|
+
if capture_context:
|
|
1678
|
+
combined_context += capture_context
|
|
1679
|
+
if memory_context:
|
|
1680
|
+
combined_context += f"CONVERSATION HISTORY (Use for continuity; not authoritative for facts):\n{memory_context}\n\n"
|
|
1681
|
+
if kb_context:
|
|
1682
|
+
combined_context += kb_context + "\n"
|
|
1683
|
+
if combined_context:
|
|
1684
|
+
combined_context += (
|
|
1685
|
+
"PRIORITIZATION GUIDE:\n"
|
|
1686
|
+
"- Prefer Captured User Data for user-specific fields.\n"
|
|
1687
|
+
"- Prefer KB/tools for facts.\n"
|
|
1688
|
+
"- History is for tone and continuity.\n\n"
|
|
1689
|
+
"FORM FLOW RULES:\n"
|
|
1690
|
+
"- Ask exactly one field per turn.\n"
|
|
1691
|
+
"- If any required fields are missing, ask the next missing required field.\n"
|
|
1692
|
+
"- If all required fields are filled but optional fields are missing, ask the next missing optional field.\n"
|
|
1693
|
+
"- Do NOT re-ask or verify values present in Captured User Data (auto-saved, authoritative).\n"
|
|
1694
|
+
"- Do NOT provide summaries until no required or optional fields are missing.\n\n"
|
|
1695
|
+
)
|
|
108
1696
|
|
|
109
|
-
# Generate response
|
|
1697
|
+
# 8) Generate response
|
|
110
1698
|
if output_format == "audio":
|
|
111
1699
|
async for audio_chunk in self.agent_service.generate_response(
|
|
112
1700
|
agent_name=agent_name,
|
|
113
1701
|
user_id=user_id,
|
|
114
|
-
query=
|
|
115
|
-
|
|
1702
|
+
query=user_text,
|
|
1703
|
+
images=images,
|
|
1704
|
+
memory_context=combined_context,
|
|
116
1705
|
output_format="audio",
|
|
117
1706
|
audio_voice=audio_voice,
|
|
118
|
-
audio_input_format=audio_input_format,
|
|
119
1707
|
audio_output_format=audio_output_format,
|
|
120
|
-
audio_instructions=audio_instructions,
|
|
121
1708
|
prompt=prompt,
|
|
122
1709
|
):
|
|
123
1710
|
yield audio_chunk
|
|
124
|
-
|
|
125
1711
|
if self.memory_provider:
|
|
126
1712
|
await self._store_conversation(
|
|
127
1713
|
user_id=user_id,
|
|
@@ -130,78 +1716,130 @@ class QueryService(QueryServiceInterface):
|
|
|
130
1716
|
)
|
|
131
1717
|
else:
|
|
132
1718
|
full_text_response = ""
|
|
1719
|
+
capture_data: Optional[BaseModel] = None
|
|
1720
|
+
|
|
1721
|
+
# Resolve agent capture if not provided
|
|
1722
|
+
if not capture_schema or not capture_name:
|
|
1723
|
+
try:
|
|
1724
|
+
cap = self.agent_service.get_agent_capture(agent_name)
|
|
1725
|
+
if cap:
|
|
1726
|
+
capture_name = cap.get("name")
|
|
1727
|
+
capture_schema = cap.get("schema")
|
|
1728
|
+
except Exception:
|
|
1729
|
+
pass
|
|
1730
|
+
|
|
1731
|
+
# Only run final structured output when no required or optional fields are missing
|
|
1732
|
+
if capture_schema and capture_name and form_complete:
|
|
1733
|
+
try:
|
|
1734
|
+
DynamicModel = self._build_model_from_json_schema(
|
|
1735
|
+
capture_name, capture_schema
|
|
1736
|
+
)
|
|
1737
|
+
async for result in self.agent_service.generate_response(
|
|
1738
|
+
agent_name=agent_name,
|
|
1739
|
+
user_id=user_id,
|
|
1740
|
+
query=user_text,
|
|
1741
|
+
images=images,
|
|
1742
|
+
memory_context=combined_context,
|
|
1743
|
+
output_format="text",
|
|
1744
|
+
prompt=(
|
|
1745
|
+
(
|
|
1746
|
+
prompt
|
|
1747
|
+
+ "\n\nUsing the captured user data above, return only the JSON for the requested schema. Do not invent values."
|
|
1748
|
+
)
|
|
1749
|
+
if prompt
|
|
1750
|
+
else "Using the captured user data above, return only the JSON for the requested schema. Do not invent values."
|
|
1751
|
+
),
|
|
1752
|
+
output_model=DynamicModel,
|
|
1753
|
+
):
|
|
1754
|
+
capture_data = result # type: ignore
|
|
1755
|
+
break
|
|
1756
|
+
except Exception as e:
|
|
1757
|
+
logger.error(f"Error during capture structured output: {e}")
|
|
1758
|
+
|
|
133
1759
|
async for chunk in self.agent_service.generate_response(
|
|
134
1760
|
agent_name=agent_name,
|
|
135
1761
|
user_id=user_id,
|
|
136
1762
|
query=user_text,
|
|
137
|
-
|
|
1763
|
+
images=images,
|
|
1764
|
+
memory_context=combined_context,
|
|
138
1765
|
output_format="text",
|
|
139
1766
|
prompt=prompt,
|
|
1767
|
+
output_model=output_model,
|
|
140
1768
|
):
|
|
141
1769
|
yield chunk
|
|
142
|
-
|
|
1770
|
+
if output_model is None:
|
|
1771
|
+
full_text_response += chunk
|
|
143
1772
|
|
|
144
1773
|
if self.memory_provider and full_text_response:
|
|
145
1774
|
await self._store_conversation(
|
|
146
1775
|
user_id=user_id,
|
|
147
1776
|
user_message=user_text,
|
|
148
|
-
assistant_message=full_text_response
|
|
1777
|
+
assistant_message=full_text_response,
|
|
149
1778
|
)
|
|
150
1779
|
|
|
1780
|
+
# Save final capture data if the model returned it
|
|
1781
|
+
if (
|
|
1782
|
+
self.memory_provider
|
|
1783
|
+
and capture_schema
|
|
1784
|
+
and capture_name
|
|
1785
|
+
and capture_data is not None
|
|
1786
|
+
):
|
|
1787
|
+
try:
|
|
1788
|
+
data_dict = (
|
|
1789
|
+
capture_data.model_dump()
|
|
1790
|
+
if hasattr(capture_data, "model_dump")
|
|
1791
|
+
else capture_data.dict()
|
|
1792
|
+
)
|
|
1793
|
+
await self.memory_provider.save_capture(
|
|
1794
|
+
user_id=user_id,
|
|
1795
|
+
capture_name=capture_name,
|
|
1796
|
+
agent_name=agent_name,
|
|
1797
|
+
data=data_dict,
|
|
1798
|
+
schema=capture_schema,
|
|
1799
|
+
)
|
|
1800
|
+
except Exception as e:
|
|
1801
|
+
logger.error(f"Error saving capture: {e}")
|
|
1802
|
+
|
|
151
1803
|
except Exception as e:
|
|
152
|
-
|
|
1804
|
+
import traceback
|
|
1805
|
+
|
|
1806
|
+
error_msg = (
|
|
1807
|
+
"I apologize for the technical difficulty. Please try again later."
|
|
1808
|
+
)
|
|
1809
|
+
logger.error(f"Error in query processing: {e}\n{traceback.format_exc()}")
|
|
1810
|
+
|
|
153
1811
|
if output_format == "audio":
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
1812
|
+
try:
|
|
1813
|
+
async for chunk in self.agent_service.llm_provider.tts(
|
|
1814
|
+
text=error_msg,
|
|
1815
|
+
voice=audio_voice,
|
|
1816
|
+
response_format=audio_output_format,
|
|
1817
|
+
):
|
|
1818
|
+
yield chunk
|
|
1819
|
+
except Exception as tts_e:
|
|
1820
|
+
logger.error(f"Error during TTS for error message: {tts_e}")
|
|
1821
|
+
yield error_msg + f" (TTS Error: {tts_e})"
|
|
160
1822
|
else:
|
|
161
1823
|
yield error_msg
|
|
162
1824
|
|
|
163
|
-
print(f"Error in query processing: {str(e)}")
|
|
164
|
-
import traceback
|
|
165
|
-
print(traceback.format_exc())
|
|
166
|
-
|
|
167
1825
|
async def delete_user_history(self, user_id: str) -> None:
    """Remove every stored conversation entry for the given user.

    Deletion failures are logged and swallowed; the call never raises.
    """
    # Guard clause: without a memory provider there is nothing to delete.
    if not self.memory_provider:
        logger.debug("No memory provider; skip delete_user_history")
        return
    try:
        await self.memory_provider.delete(user_id)
    except Exception as e:
        logger.error(f"Error deleting user history for {user_id}: {e}")
|
|
178
1834
|
|
|
179
1835
|
async def get_user_history(
    self,
    user_id: str,
    page_num: int = 1,
    page_size: int = 20,
    sort_order: str = "desc",
) -> Dict[str, Any]:
    """Get paginated message history for a user.

    Args:
        user_id: User whose conversation history is fetched.
        page_num: 1-based page number.
        page_size: Number of conversation entries per page.
        sort_order: "asc" sorts oldest-first; any other value sorts newest-first.

    Returns:
        Dict with keys: data (list of entries), total, page, page_size,
        total_pages, and error (None on success, a message string on failure).
    """
    # Without a memory provider there is nothing to read; report as an error
    # while keeping the same result shape callers expect on success.
    if not self.memory_provider:
        return {
            "data": [],
            "total": 0,
            "page": page_num,
            "page_size": page_size,
            "total_pages": 0,
            "error": "Memory provider not available",
        }
    try:
        skip = (page_num - 1) * page_size
        total = self.memory_provider.count_documents(
            collection="conversations", query={"user_id": user_id}
        )
        # Ceiling division; explicit 0 for an empty collection.
        total_pages = (total + page_size - 1) // page_size if total > 0 else 0

        conversations = self.memory_provider.find(
            collection="conversations",
            query={"user_id": user_id},
            sort=[("timestamp", 1 if sort_order == "asc" else -1)],
            skip=skip,
            limit=page_size,
        )

        formatted: List[Dict[str, Any]] = []
        for conv in conversations:
            # assumes "timestamp" is datetime-like (exposes .timestamp());
            # exposed to callers as epoch seconds, or None when absent.
            ts = conv.get("timestamp")
            ts_epoch = int(ts.timestamp()) if ts else None
            formatted.append(
                {
                    "id": str(conv.get("_id")),
                    "user_message": conv.get("user_message"),
                    "assistant_message": conv.get("assistant_message"),
                    "timestamp": ts_epoch,
                }
            )

        return {
            "data": formatted,
            "total": total,
            "page": page_num,
            "page_size": page_size,
            "total_pages": total_pages,
            "error": None,
        }
    except Exception as e:
        import traceback

        logger.error(
            f"Error retrieving user history for {user_id}: {e}\n{traceback.format_exc()}"
        )
        # Mirror the empty-result shape so callers handle both paths uniformly.
        return {
            "data": [],
            "total": 0,
            "page": page_num,
            "page_size": page_size,
            "total_pages": 0,
            "error": f"Error retrieving history: {str(e)}",
        }
|
|
272
1902
|
|
|
273
1903
|
async def _store_conversation(
    self, user_id: str, user_message: str, assistant_message: str
) -> None:
    """Store conversation history in memory provider.

    Persists the user/assistant pair as a single exchange; storage errors
    are logged and swallowed so response delivery is never interrupted.
    """
    if not self.memory_provider:
        # No provider configured: silently skip persistence.
        return
    try:
        await self.memory_provider.store(
            user_id,
            [
                {"role": "user", "content": user_message},
                {"role": "assistant", "content": assistant_message},
            ],
        )
    except Exception as e:
        logger.error(f"Store conversation error for {user_id}: {e}")
|
|
277
1919
|
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
1920
|
+
# --- Realtime persistence helpers (used by client/server using realtime service) ---
|
|
1921
|
+
async def realtime_begin_turn(
    self, user_id: str
) -> Optional[str]:  # pragma: no cover
    """Open a streaming turn on the memory provider, if it supports one.

    Returns the provider's turn id, or None when no provider is configured
    or the provider lacks `begin_stream_turn`.
    """
    provider = self.memory_provider
    if not provider or not hasattr(provider, "begin_stream_turn"):
        return None
    return await provider.begin_stream_turn(user_id)  # type: ignore[attr-defined]
|
|
1929
|
+
|
|
1930
|
+
async def realtime_update_user(
    self, user_id: str, turn_id: str, delta: str
) -> None:  # pragma: no cover
    """Append a user-transcript delta to an in-progress streaming turn.

    No-op when there is no memory provider or it lacks `update_stream_user`.
    """
    provider = self.memory_provider
    if provider and hasattr(provider, "update_stream_user"):
        await provider.update_stream_user(user_id, turn_id, delta)  # type: ignore[attr-defined]
|
|
1938
|
+
|
|
1939
|
+
async def realtime_update_assistant(
    self, user_id: str, turn_id: str, delta: str
) -> None:  # pragma: no cover
    """Append an assistant-transcript delta to an in-progress streaming turn.

    No-op when there is no memory provider or it lacks
    `update_stream_assistant`.
    """
    provider = self.memory_provider
    if provider and hasattr(provider, "update_stream_assistant"):
        await provider.update_stream_assistant(user_id, turn_id, delta)  # type: ignore[attr-defined]
|
|
1947
|
+
|
|
1948
|
+
async def realtime_finalize_turn(
    self, user_id: str, turn_id: str
) -> None:  # pragma: no cover
    """Close out an in-progress streaming turn on the memory provider.

    No-op when there is no memory provider or it lacks
    `finalize_stream_turn`.
    """
    provider = self.memory_provider
    if provider and hasattr(provider, "finalize_stream_turn"):
        await provider.finalize_stream_turn(user_id, turn_id)  # type: ignore[attr-defined]
|
|
1956
|
+
|
|
1957
|
+
def _build_model_from_json_schema(
    self, name: str, schema: Dict[str, Any]
) -> Type[BaseModel]:
    """Create a Pydantic model dynamically from a JSON Schema subset.

    Supports the scalar types string/integer/number/boolean, arrays with an
    ``items`` schema, objects (mapped to ``Dict[str, Any]``), and nullable
    unions expressed as ``{"type": [..., "null"]}``.

    Args:
        name: Name given to the generated model class.
        schema: JSON Schema dict with ``properties`` and optional ``required``.

    Returns:
        A dynamically created Pydantic model class.
    """
    # Imported lazily so pydantic is only required when captures are used.
    from pydantic import create_model

    def py_type(js: Dict[str, Any]):
        # Map a (sub)schema's "type" to the corresponding Python type.
        t = js.get("type")
        if isinstance(t, list):
            # Union type list: wrap the first non-null member in Optional;
            # an all-"null" list degrades to Optional[Any].
            non_null = [x for x in t if x != "null"]
            if not non_null:
                return Optional[Any]
            base = py_type({"type": non_null[0]})
            return Optional[base]
        if t == "string":
            return str
        if t == "integer":
            return int
        if t == "number":
            return float
        if t == "boolean":
            return bool
        if t == "array":
            # Item type defaults to string when "items" is omitted.
            items = js.get("items", {"type": "string"})
            return List[py_type(items)]
        if t == "object":
            return Dict[str, Any]
        # Unknown or missing type: accept anything.
        return Any

    properties: Dict[str, Any] = schema.get("properties", {})
    required = set(schema.get("required", []))
    fields: Dict[str, Any] = {}
    for field_name, field_schema in properties.items():
        typ = py_type(field_schema)
        default = field_schema.get("default")
        if field_name in required and default is None:
            # Required with no default: Ellipsis marks the field mandatory.
            fields[field_name] = (typ, ...)
        else:
            # Optional fields (or required ones with an explicit default)
            # fall back to the schema default, which may be None.
            fields[field_name] = (typ, default)

    Model = create_model(name, **fields)  # type: ignore
    return Model
|