solana-agent 31.0.0__py3-none-any.whl → 31.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- solana_agent/client/solana_agent.py +4 -0
- solana_agent/domains/agent.py +7 -1
- solana_agent/factories/agent_factory.py +37 -6
- solana_agent/interfaces/providers/memory.py +12 -0
- solana_agent/interfaces/services/query.py +2 -0
- solana_agent/repositories/memory.py +150 -90
- solana_agent/services/agent.py +33 -1
- solana_agent/services/query.py +473 -190
- solana_agent/services/routing.py +19 -13
- {solana_agent-31.0.0.dist-info → solana_agent-31.1.1.dist-info}/METADATA +32 -27
- {solana_agent-31.0.0.dist-info → solana_agent-31.1.1.dist-info}/RECORD +14 -14
- {solana_agent-31.0.0.dist-info → solana_agent-31.1.1.dist-info}/LICENSE +0 -0
- {solana_agent-31.0.0.dist-info → solana_agent-31.1.1.dist-info}/WHEEL +0 -0
- {solana_agent-31.0.0.dist-info → solana_agent-31.1.1.dist-info}/entry_points.txt +0 -0
solana_agent/services/query.py
CHANGED
@@ -7,6 +7,7 @@ clean separation of concerns.
 """
 
 import logging
+import re
 from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Type, Union
 
 from pydantic import BaseModel
@@ -22,9 +23,7 @@ from solana_agent.interfaces.providers.memory import (
 from solana_agent.interfaces.services.knowledge_base import (
     KnowledgeBaseService as KnowledgeBaseInterface,
 )
-from solana_agent.interfaces.guardrails.guardrails import (
-    InputGuardrail,
-)
+from solana_agent.interfaces.guardrails.guardrails import InputGuardrail
 
 from solana_agent.services.agent import AgentService
 from solana_agent.services.routing import RoutingService
@@ -44,16 +43,7 @@ class QueryService(QueryServiceInterface):
         kb_results_count: int = 3,
         input_guardrails: List[InputGuardrail] = None,
     ):
-        """Initialize the query service.
-
-        Args:
-            agent_service: Service for AI agent management
-            routing_service: Service for routing queries to appropriate agents
-            memory_provider: Optional provider for memory storage and retrieval
-            knowledge_base: Optional provider for knowledge base interactions
-            kb_results_count: Number of results to retrieve from knowledge base
-            input_guardrails: List of input guardrail instances
-        """
+        """Initialize the query service."""
         self.agent_service = agent_service
         self.routing_service = routing_service
         self.memory_provider = memory_provider
@@ -89,27 +79,12 @@ class QueryService(QueryServiceInterface):
         prompt: Optional[str] = None,
         router: Optional[RoutingServiceInterface] = None,
         output_model: Optional[Type[BaseModel]] = None,
+        capture_schema: Optional[Dict[str, Any]] = None,
+        capture_name: Optional[str] = None,
     ) -> AsyncGenerator[Union[str, bytes, BaseModel], None]:  # pragma: no cover
-        """Process the user request
-
-        Args:
-            user_id: User ID
-            query: Text query or audio bytes
-            images: Optional list of image URLs (str) or image bytes.
-            output_format: Response format ("text" or "audio")
-            audio_voice: Voice for TTS (text-to-speech)
-            audio_instructions: Audio voice instructions
-            audio_output_format: Audio output format
-            audio_input_format: Audio input format
-            prompt: Optional prompt for the agent
-            router: Optional routing service for processing
-            output_model: Optional Pydantic model for structured output
-
-        Yields:
-            Response chunks (text strings or audio bytes)
-        """
+        """Process the user request and generate a response."""
         try:
-            #
+            # 1) Transcribe audio or accept text
             user_text = ""
             if not isinstance(query, str):
                 logger.info(
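The two new `capture_schema` / `capture_name` parameters thread a JSON Schema for a user-facing form through `process()`. A minimal calling sketch, assuming an already-constructed `query_service` and an invented onboarding schema (both hypothetical, for illustration only):

```python
import asyncio

# Hypothetical schema; real ones come from your agent's capture config.
onboarding_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "email": {"type": "string"},
        "ideas": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["name", "email"],
}

async def main() -> None:
    # query_service: an assumed QueryService instance wired to your providers.
    async for chunk in query_service.process(
        user_id="user-123",
        query="Hi, I'd like to sign up",
        output_format="text",
        capture_name="onboarding",
        capture_schema=onboarding_schema,
    ):
        print(chunk, end="")

asyncio.run(main())
```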
@@ -126,123 +101,386 @@ class QueryService(QueryServiceInterface):
                 user_text = query
                 logger.info(f"Received text input length: {len(user_text)}")
 
-            #
+            # 2) Input guardrails
             original_text = user_text
-            processed_text = user_text
             for guardrail in self.input_guardrails:
                 try:
-                    processed_text = await guardrail.process(processed_text)
-                    logger.debug(
-                        f"Applied input guardrail: {guardrail.__class__.__name__}"
-                    )
+                    user_text = await guardrail.process(user_text)
                 except Exception as e:
-                    logger.error(
-
-                        exc_info=True,
-                    )
-            if processed_text != original_text:
+                    logger.debug(f"Guardrail error: {e}")
+            if user_text != original_text:
                 logger.info(
-                    f"Input guardrails modified user text. Original length: {len(original_text)}, New length: {len(
+                    f"Input guardrails modified user text. Original length: {len(original_text)}, New length: {len(user_text)}"
                 )
-                user_text = processed_text  # Use the processed text going forward
-            # --- End Apply Input Guardrails ---
 
-            #
-
-            if not images and user_text.strip().lower() in [
-                "test",
-                "hello",
+            # 3) Greetings shortcut
+            if not images and user_text.strip().lower() in {
                 "hi",
+                "hello",
                 "hey",
                 "ping",
-            ]:
-
-
+                "test",
+            }:
+                greeting = "Hello! How can I help you today?"
                 if output_format == "audio":
                     async for chunk in self.agent_service.llm_provider.tts(
-                        text=
+                        text=greeting,
                         voice=audio_voice,
                         response_format=audio_output_format,
                         instructions=audio_instructions,
                     ):
                         yield chunk
                 else:
-                    yield
-
-                # Store simple interaction in memory (using processed user_text)
+                    yield greeting
                 if self.memory_provider:
-                    await self._store_conversation(user_id,
+                    await self._store_conversation(user_id, original_text, greeting)
                 return
 
-            #
+            # 4) Memory context (conversation history)
             memory_context = ""
             if self.memory_provider:
                 try:
                     memory_context = await self.memory_provider.retrieve(user_id)
-
-
-                    )
-                except Exception as e:
-                    logger.error(f"Error retrieving memory context: {e}", exc_info=True)
+                except Exception:
+                    memory_context = ""
 
-            #
+            # 5) Knowledge base context
             kb_context = ""
             if self.knowledge_base:
                 try:
-                    # Use processed user_text for KB query
                     kb_results = await self.knowledge_base.query(
                         query_text=user_text,
                         top_k=self.kb_results_count,
                         include_content=True,
-                        include_metadata=False,
+                        include_metadata=False,
                     )
-
                     if kb_results:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                        kb_lines = [
+                            "**KNOWLEDGE BASE (CRITICAL: MAKE THIS INFORMATION THE TOP PRIORITY):**"
+                        ]
+                        for i, r in enumerate(kb_results, 1):
+                            kb_lines.append(f"[{i}] {r.get('content', '').strip()}\n")
+                        kb_context = "\n".join(kb_lines)
+                except Exception:
+                    kb_context = ""
+
+            # 6) Route query (and fetch previous assistant message)
+            agent_name = "default"
+            prev_assistant = ""
+            routing_input = user_text
+            if self.memory_provider:
+                try:
+                    prev_docs = self.memory_provider.find(
+                        collection="conversations",
+                        query={"user_id": user_id},
+                        sort=[("timestamp", -1)],
+                        limit=1,
+                    )
+                    if prev_docs:
+                        prev_user_msg = (prev_docs[0] or {}).get(
+                            "user_message", ""
+                        ) or ""
+                        prev_assistant = (prev_docs[0] or {}).get(
+                            "assistant_message", ""
+                        ) or ""
+                        if prev_user_msg:
+                            routing_input = (
+                                f"previous_user_message: {prev_user_msg}\n"
+                                f"current_user_message: {user_text}"
+                            )
+                except Exception:
+                    pass
             try:
-                # Use processed user_text for routing (images generally don't affect routing logic here)
                 if router:
-                    agent_name = await router.route_query(
+                    agent_name = await router.route_query(routing_input)
                 else:
-                    agent_name = await self.routing_service.route_query(
-
+                    agent_name = await self.routing_service.route_query(routing_input)
+            except Exception:
+                agent_name = "default"
+
+            # 7) Captured data context + incremental save using previous assistant message
+            capture_context = ""
+            form_complete = False
+
+            # Helpers
+            def _non_empty(v: Any) -> bool:
+                if v is None:
+                    return False
+                if isinstance(v, str):
+                    s = v.strip().lower()
+                    return s not in {"", "null", "none", "n/a", "na", "undefined", "."}
+                if isinstance(v, (list, dict, tuple, set)):
+                    return len(v) > 0
+                return True
+
+            def _parse_numbers_list(s: str) -> List[str]:
+                nums = re.findall(r"\b(\d+)\b", s)
+                # dedupe keep order
+                seen, out = set(), []
+                for n in nums:
+                    if n not in seen:
+                        seen.add(n)
+                        out.append(n)
+                return out
+
+            def _extract_numbered_options(text: str) -> Dict[str, str]:
+                """Parse previous assistant message for lines like:
+                '1) Foo', '2. Bar', '- 3) Baz', '* 4. Buzz'
+                Returns mapping '1' -> 'Foo', etc.
+                """
+                options: Dict[str, str] = {}
+                if not text:
+                    return options
+                for raw in text.splitlines():
+                    line = raw.strip()
+                    if not line:
+                        continue
+                    # Common Markdown patterns: "1. Label", "1) Label", "- 1) Label", "* 1. Label"
+                    m = re.match(r"^(?:[-*]\s*)?(\d+)[\.)]?\s+(.*)$", line)
+                    if m:
+                        idx, label = m.group(1), m.group(2).strip()
+                        # Strip trailing markdown soft-break spaces
+                        label = label.rstrip()
+                        # Ignore labels that are too short or look like continuations
+                        if len(label) >= 1:
+                            options[idx] = label
+                return options
+
+            def _detect_field_from_prev_question(
+                prev_text: str, schema: Optional[Dict[str, Any]]
+            ) -> Optional[str]:
+                if not prev_text or not isinstance(schema, dict):
+                    return None
+                t = prev_text.lower()
+                # Heuristic synonyms for your onboarding schema
+                patterns = [
+                    ("ideas", ["which ideas attract you", "ideas"]),
+                    ("description", ["please describe yourself", "describe yourself"]),
+                    ("myself", ["tell us about yourself", "about yourself"]),
+                    ("questions", ["do you have any questions"]),
+                    ("rating", ["rating", "1 to 5", "how satisfied", "how happy"]),
+                    ("email", ["email"]),
+                    ("phone", ["phone"]),
+                    ("name", ["name"]),
+                    ("city", ["city"]),
+                    ("state", ["state"]),
+                ]
+                candidates = set((schema.get("properties") or {}).keys())
+                for field, keys in patterns:
+                    if field in candidates and any(key in t for key in keys):
+                        return field
+                # Fallback: property name appears directly
+                for field in candidates:
+                    if field in t:
+                        return field
+                return None
+
+            # Resolve active capture from args or agent config
+            active_capture_name = capture_name
+            active_capture_schema = capture_schema
+            if not active_capture_name or not active_capture_schema:
+                try:
+                    cap_cfg = self.agent_service.get_agent_capture(agent_name)
+                    if cap_cfg:
+                        active_capture_name = active_capture_name or cap_cfg.get("name")
+                        active_capture_schema = active_capture_schema or cap_cfg.get(
+                            "schema"
+                        )
+                except Exception:
+                    pass
+
+            latest_by_name: Dict[str, Dict[str, Any]] = {}
+            if self.memory_provider:
+                try:
+                    docs = self.memory_provider.find(
+                        collection="captures",
+                        query={"user_id": user_id},
+                        sort=[("timestamp", -1)],
+                        limit=100,
+                    )
+                    for d in docs or []:
+                        name = (d or {}).get("capture_name")
+                        if not name or name in latest_by_name:
+                            continue
+                        latest_by_name[name] = {
+                            "data": (d or {}).get("data", {}) or {},
+                            "mode": (d or {}).get("mode", "once"),
+                            "agent": (d or {}).get("agent_name"),
+                        }
+                except Exception:
+                    pass
+
+            # Incremental save: use prev_assistant's numbered list to map numeric reply -> labels
+            incremental: Dict[str, Any] = {}
+            try:
+                if (
+                    self.memory_provider
+                    and active_capture_name
+                    and isinstance(active_capture_schema, dict)
+                ):
+                    props = (active_capture_schema or {}).get("properties", {})
+                    required_fields = list(
+                        (active_capture_schema or {}).get("required", []) or []
+                    )
+                    # Prefer a field detected from prev assistant; else if exactly one required missing, use it
+                    target_field: Optional[str] = _detect_field_from_prev_question(
+                        prev_assistant, active_capture_schema
+                    )
+                    active_data_existing = (
+                        latest_by_name.get(active_capture_name, {}) or {}
+                    ).get("data", {}) or {}
+
+                    def _missing_required() -> List[str]:
+                        return [
+                            f
+                            for f in required_fields
+                            if not _non_empty(active_data_existing.get(f))
+                        ]
+
+                    if not target_field:
+                        missing = _missing_required()
+                        if len(missing) == 1:
+                            target_field = missing[0]
+
+                    if target_field:
+                        f_schema = props.get(target_field, {}) or {}
+                        f_type = f_schema.get("type")
+                        number_to_label = _extract_numbered_options(prev_assistant)
+
+                        if number_to_label:
+                            # Map any numbers in user's reply to their labels
+                            nums = _parse_numbers_list(user_text)
+                            labels = [
+                                number_to_label[n] for n in nums if n in number_to_label
+                            ]
+                            if labels:
+                                if f_type == "array":
+                                    incremental[target_field] = labels
+                                else:
+                                    incremental[target_field] = labels[0]
+
+                        # If we didn't map via options, fallback to type-based parse
+                        if target_field not in incremental:
+                            if f_type == "number":
+                                m = re.search(r"\b([0-9]+(?:\.[0-9]+)?)\b", user_text)
+                                if m:
+                                    try:
+                                        incremental[target_field] = float(m.group(1))
+                                    except Exception:
+                                        pass
+                            elif f_type == "array":
+                                # Accept CSV-style input as array of strings
+                                parts = [
+                                    p.strip()
+                                    for p in re.split(r"[,\n;]+", user_text)
+                                    if p.strip()
+                                ]
+                                if parts:
+                                    incremental[target_field] = parts
+                            else:  # string/default
+                                if user_text.strip():
+                                    incremental[target_field] = user_text.strip()
+
+                    # Filter out empty junk and save
+                    if incremental:
+                        cleaned = {
+                            k: v for k, v in incremental.items() if _non_empty(v)
+                        }
+                        if cleaned:
+                            try:
+                                await self.memory_provider.save_capture(
+                                    user_id=user_id,
+                                    capture_name=active_capture_name,
+                                    agent_name=agent_name,
+                                    data=cleaned,
+                                    schema=active_capture_schema,
+                                )
+                            except Exception as se:
+                                logger.error(f"Error saving incremental capture: {se}")
             except Exception as e:
-                logger.
-
-
+                logger.debug(f"Incremental extraction skipped: {e}")
+
+            # Build capture context, merging in incremental immediately (avoid read lag)
+            def _get_active_data(name: Optional[str]) -> Dict[str, Any]:
+                if not name:
+                    return {}
+                base = (latest_by_name.get(name, {}) or {}).get("data", {}) or {}
+                if incremental:
+                    base = {**base, **incremental}
+                return base
+
+            lines: List[str] = []
+            if active_capture_name and isinstance(active_capture_schema, dict):
+                active_data = _get_active_data(active_capture_name)
+                required_fields = list(
+                    (active_capture_schema or {}).get("required", []) or []
+                )
+                missing = [
+                    f for f in required_fields if not _non_empty(active_data.get(f))
+                ]
+                form_complete = len(missing) == 0 and len(required_fields) > 0
+
+                lines.append(
+                    "CAPTURED FORM STATE (Authoritative; do not re-ask filled values):"
+                )
+                lines.append(f"- form_name: {active_capture_name}")
+                if active_data:
+                    pairs = [
+                        f"{k}: {v}" for k, v in active_data.items() if _non_empty(v)
+                    ]
+                    lines.append(
+                        f"- filled_fields: {', '.join(pairs) if pairs else '(none)'}"
+                    )
+                else:
+                    lines.append("- filled_fields: (none)")
+                lines.append(
+                    f"- missing_required_fields: {', '.join(missing) if missing else '(none)'}"
                 )
+                lines.append("")
+
+            if latest_by_name:
+                lines.append("OTHER CAPTURED USER DATA (for reference):")
+                for cname, info in latest_by_name.items():
+                    if cname == active_capture_name:
+                        continue
+                    data = info.get("data", {}) or {}
+                    if data:
+                        pairs = [f"{k}: {v}" for k, v in data.items() if _non_empty(v)]
+                        lines.append(
+                            f"- {cname}: {', '.join(pairs) if pairs else '(none)'}"
+                        )
+                    else:
+                        lines.append(f"- {cname}: (none)")
 
-
+            if lines:
+                capture_context = "\n".join(lines) + "\n\n"
+
+            # Merge contexts
             combined_context = ""
+            if capture_context:
+                combined_context += capture_context
             if memory_context:
-                combined_context += f"CONVERSATION HISTORY (Use for
+                combined_context += f"CONVERSATION HISTORY (Use for continuity; not authoritative for facts):\n{memory_context}\n\n"
             if kb_context:
-                combined_context +=
-
-
-
-
+                combined_context += kb_context + "\n"
+            if combined_context:
+                combined_context += (
+                    "PRIORITIZATION GUIDE:\n"
+                    "- Prefer Captured User Data for user-specific fields.\n"
+                    "- Prefer KB/tools for facts.\n"
+                    "- History is for tone and continuity.\n\n"
+                    "FORM FLOW RULES:\n"
+                    "- Ask exactly one missing required field per turn.\n"
+                    "- Do NOT re-ask or verify values present in Captured User Data (auto-saved, authoritative).\n"
+                    "- If no required fields are missing, proceed without further capture questions.\n\n"
+                )
 
-            #
-            # Pass the processed user_text and images to the agent service
+            # 8) Generate response
             if output_format == "audio":
                 async for audio_chunk in self.agent_service.generate_response(
                     agent_name=agent_name,
                     user_id=user_id,
-                    query=user_text,
+                    query=user_text,
                     images=images,
                     memory_context=combined_context,
                     output_format="audio",
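Most of this hunk is deterministic parsing around the LLM call. The numbered-options helper, for instance, can be exercised on its own; the sketch below copies the diff's regex verbatim and shows the number-to-label mapping it yields:

```python
import re

def extract_numbered_options(text: str) -> dict:
    # Same pattern the diff introduces: "1. Label", "1) Label", "- 1) Label", "* 1. Label"
    options = {}
    for raw in text.splitlines():
        m = re.match(r"^(?:[-*]\s*)?(\d+)[\.)]?\s+(.*)$", raw.strip())
        if m:
            options[m.group(1)] = m.group(2).strip()
    return options

prev = "Which ideas attract you?\n1) DeFi\n2. Payments\n- 3) Gaming"
print(extract_numbered_options(prev))
# {'1': 'DeFi', '2': 'Payments', '3': 'Gaming'}
```

When a user then replies with bare numbers like "1 and 3", `_parse_numbers_list` extracts `['1', '3']` and the service maps them back to labels before saving the capture.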
@@ -252,22 +490,59 @@ class QueryService(QueryServiceInterface):
                     prompt=prompt,
                 ):
                     yield audio_chunk
-
-                # Store conversation using processed user_text
-                # Note: Storing images in history is not directly supported by current memory provider interface
                 if self.memory_provider:
                     await self._store_conversation(
                         user_id=user_id,
-                        user_message=user_text,
+                        user_message=user_text,
                         assistant_message=self.agent_service.last_text_response,
                     )
             else:
                 full_text_response = ""
+                capture_data: Optional[BaseModel] = None
+
+                # Resolve agent capture if not provided
+                if not capture_schema or not capture_name:
+                    try:
+                        cap = self.agent_service.get_agent_capture(agent_name)
+                        if cap:
+                            capture_name = cap.get("name")
+                            capture_schema = cap.get("schema")
+                    except Exception:
+                        pass
+
+                # If form is complete, ask for structured output JSON
+                if capture_schema and capture_name and form_complete:
+                    try:
+                        DynamicModel = self._build_model_from_json_schema(
+                            capture_name, capture_schema
+                        )
+                        async for result in self.agent_service.generate_response(
+                            agent_name=agent_name,
+                            user_id=user_id,
+                            query=user_text,
+                            images=images,
+                            memory_context=combined_context,
+                            output_format="text",
+                            prompt=(
+                                (
+                                    prompt
+                                    + "\n\nUsing the captured user data above, return only the JSON for the requested schema. Do not invent values."
+                                )
+                                if prompt
+                                else "Using the captured user data above, return only the JSON for the requested schema. Do not invent values."
+                            ),
+                            output_model=DynamicModel,
+                        ):
+                            capture_data = result  # type: ignore
+                            break
+                    except Exception as e:
+                        logger.error(f"Error during capture structured output: {e}")
+
                 async for chunk in self.agent_service.generate_response(
                     agent_name=agent_name,
                     user_id=user_id,
-                    query=user_text,
-                    images=images,
+                    query=user_text,
+                    images=images,
                     memory_context=combined_context,
                     output_format="text",
                     prompt=prompt,
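The structured-output pass above only fires when `form_complete` is true, which hunk 5 derives from the schema's `required` list and the `_non_empty` filter. A standalone sketch of that gate, with an illustrative schema and data:

```python
def non_empty(v) -> bool:
    # Mirrors the diff's _non_empty: empty-ish strings and containers don't count.
    if v is None:
        return False
    if isinstance(v, str):
        return v.strip().lower() not in {"", "null", "none", "n/a", "na", "undefined", "."}
    if isinstance(v, (list, dict, tuple, set)):
        return len(v) > 0
    return True

schema = {"required": ["name", "email"]}            # illustrative
data = {"name": "Ada", "email": "ada@example.com"}  # illustrative
required = list(schema.get("required", []))
missing = [f for f in required if not non_empty(data.get(f))]
form_complete = len(missing) == 0 and len(required) > 0
print(form_complete)  # True
```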
@@ -277,15 +552,36 @@ class QueryService(QueryServiceInterface):
                     if output_model is None:
                         full_text_response += chunk
 
-                # Store conversation using processed user_text
-                # Note: Storing images in history is not directly supported by current memory provider interface
                 if self.memory_provider and full_text_response:
                     await self._store_conversation(
                         user_id=user_id,
-                        user_message=user_text,
+                        user_message=user_text,
                         assistant_message=full_text_response,
                     )
 
+                # Save final capture data if the model returned it
+                if (
+                    self.memory_provider
+                    and capture_schema
+                    and capture_name
+                    and capture_data is not None
+                ):
+                    try:
+                        data_dict = (
+                            capture_data.model_dump()
+                            if hasattr(capture_data, "model_dump")
+                            else capture_data.dict()
+                        )
+                        await self.memory_provider.save_capture(
+                            user_id=user_id,
+                            capture_name=capture_name,
+                            agent_name=agent_name,
+                            data=data_dict,
+                            schema=capture_schema,
+                        )
+                    except Exception as e:
+                        logger.error(f"Error saving capture: {e}")
+
         except Exception as e:
             import traceback
 
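The dump step hedges across Pydantic majors: `model_dump()` exists on v2 models, `.dict()` on v1, and the `hasattr` check picks whichever is available. A minimal sketch with a throwaway model:

```python
from pydantic import BaseModel

class Capture(BaseModel):  # throwaway model for illustration
    name: str
    email: str

cap = Capture(name="Ada", email="ada@example.com")
data_dict = cap.model_dump() if hasattr(cap, "model_dump") else cap.dict()
print(data_dict)  # {'name': 'Ada', 'email': 'ada@example.com'}
```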
@@ -304,52 +600,29 @@ class QueryService(QueryServiceInterface):
                         yield chunk
                 except Exception as tts_e:
                     logger.error(f"Error during TTS for error message: {tts_e}")
-                    # Fallback to yielding text error if TTS fails
                     yield error_msg + f" (TTS Error: {tts_e})"
             else:
                 yield error_msg
 
     async def delete_user_history(self, user_id: str) -> None:
-        """Delete all conversation history for a user.
-
-        Args:
-            user_id: User ID
-        """
+        """Delete all conversation history for a user."""
         if self.memory_provider:
             try:
                 await self.memory_provider.delete(user_id)
-                logger.info(f"Deleted conversation history for user: {user_id}")
             except Exception as e:
-                logger.error(
-                    f"Error deleting user history for {user_id}: {e}", exc_info=True
-                )
+                logger.error(f"Error deleting user history for {user_id}: {e}")
         else:
-            logger.warning(
-                "Attempted to delete user history, but no memory provider is configured."
-            )
+            logger.debug("No memory provider; skip delete_user_history")
 
     async def get_user_history(
         self,
         user_id: str,
         page_num: int = 1,
         page_size: int = 20,
-        sort_order: str = "desc",
+        sort_order: str = "desc",
     ) -> Dict[str, Any]:
-        """Get paginated message history for a user.
-
-        Args:
-            user_id: User ID
-            page_num: Page number (starting from 1)
-            page_size: Number of messages per page
-            sort_order: Sort order ("asc" or "desc")
-
-        Returns:
-            Dictionary with paginated results and metadata.
-        """
+        """Get paginated message history for a user."""
         if not self.memory_provider:
-            logger.warning(
-                "Attempted to get user history, but no memory provider is configured."
-            )
             return {
                 "data": [],
                 "total": 0,
@@ -358,20 +631,13 @@ class QueryService(QueryServiceInterface):
                 "total_pages": 0,
                 "error": "Memory provider not available",
             }
-
         try:
-            # Calculate skip and limit for pagination
             skip = (page_num - 1) * page_size
-
-            # Get total count of documents
             total = self.memory_provider.count_documents(
                 collection="conversations", query={"user_id": user_id}
             )
-
-            # Calculate total pages
             total_pages = (total + page_size - 1) // page_size if total > 0 else 0
 
-            # Get paginated results
             conversations = self.memory_provider.find(
                 collection="conversations",
                 query={"user_id": user_id},
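The retained pagination arithmetic is integer ceil-division, so no floats are involved; a quick worked example:

```python
total, page_size, page_num = 45, 20, 2
total_pages = (total + page_size - 1) // page_size if total > 0 else 0
skip = (page_num - 1) * page_size
print(total_pages, skip)  # 3 20  (pages of 20, 20, 5; page 2 starts at offset 20)
```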
@@ -380,39 +646,27 @@ class QueryService(QueryServiceInterface):
                 limit=page_size,
             )
 
-
-            formatted_conversations = []
+            formatted: List[Dict[str, Any]] = []
             for conv in conversations:
-
-
-
-                    else None
-                )
-                # Assuming the stored format matches what _store_conversation saves
-                # (which currently only stores text messages)
-                formatted_conversations.append(
+                ts = conv.get("timestamp")
+                ts_epoch = int(ts.timestamp()) if ts else None
+                formatted.append(
                     {
                         "id": str(conv.get("_id")),
-                        "user_message": conv.get("user_message"),
-                        "assistant_message": conv.get(
-                            "assistant_message"
-                        ),  # Or how it's stored
-                        "timestamp": timestamp,
+                        "user_message": conv.get("user_message"),
+                        "assistant_message": conv.get("assistant_message"),
+                        "timestamp": ts_epoch,
                     }
                 )
 
-            logger.info(
-                f"Retrieved page {page_num}/{total_pages} of history for user {user_id}"
-            )
             return {
-                "data": formatted_conversations,
+                "data": formatted,
                 "total": total,
                 "page": page_num,
                 "page_size": page_size,
                 "total_pages": total_pages,
                 "error": None,
             }
-
         except Exception as e:
             import traceback
 
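The rewritten formatter normalizes stored datetimes to epoch seconds with `.timestamp()` instead of returning raw datetime objects; for example:

```python
from datetime import datetime, timezone

ts = datetime(2024, 5, 1, 12, 0, tzinfo=timezone.utc)  # illustrative stored value
ts_epoch = int(ts.timestamp()) if ts else None
print(ts_epoch)  # 1714564800
```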
@@ -431,30 +685,59 @@ class QueryService(QueryServiceInterface):
     async def _store_conversation(
         self, user_id: str, user_message: str, assistant_message: str
     ) -> None:
-        """Store conversation history in memory provider.
-
-
-
-
-
-
-
-
-
-            # doesn't explicitly handle image data storage in history.
-                await self.memory_provider.store(
-                    user_id,
-                    [
-                        {"role": "user", "content": user_message},
-                        {"role": "assistant", "content": assistant_message},
-                    ],
-                )
-                logger.info(f"Stored conversation for user {user_id}")
-            except Exception as e:
-                logger.error(
-                    f"Error storing conversation for user {user_id}: {e}", exc_info=True
-                )
-        else:
-            logger.debug(
-                "Memory provider not configured, skipping conversation storage."
+        """Store conversation history in memory provider."""
+        if not self.memory_provider:
+            return
+        try:
+            await self.memory_provider.store(
+                user_id,
+                [
+                    {"role": "user", "content": user_message},
+                    {"role": "assistant", "content": assistant_message},
+                ],
             )
+        except Exception as e:
+            logger.error(f"Store conversation error for {user_id}: {e}")
+
+    def _build_model_from_json_schema(
+        self, name: str, schema: Dict[str, Any]
+    ) -> Type[BaseModel]:
+        """Create a Pydantic model dynamically from a JSON Schema subset."""
+        from pydantic import create_model
+
+        def py_type(js: Dict[str, Any]):
+            t = js.get("type")
+            if isinstance(t, list):
+                non_null = [x for x in t if x != "null"]
+                if not non_null:
+                    return Optional[Any]
+                base = py_type({"type": non_null[0]})
+                return Optional[base]
+            if t == "string":
+                return str
+            if t == "integer":
+                return int
+            if t == "number":
+                return float
+            if t == "boolean":
+                return bool
+            if t == "array":
+                items = js.get("items", {"type": "string"})
+                return List[py_type(items)]
+            if t == "object":
+                return Dict[str, Any]
+            return Any
+
+        properties: Dict[str, Any] = schema.get("properties", {})
+        required = set(schema.get("required", []))
+        fields: Dict[str, Any] = {}
+        for field_name, field_schema in properties.items():
+            typ = py_type(field_schema)
+            default = field_schema.get("default")
+            if field_name in required and default is None:
+                fields[field_name] = (typ, ...)
+            else:
+                fields[field_name] = (typ, default)
+
+        Model = create_model(name, **fields)  # type: ignore
+        return Model