solana-agent 31.1.4.tar.gz → 31.1.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {solana_agent-31.1.4 → solana_agent-31.1.6}/PKG-INFO +4 -5
  2. {solana_agent-31.1.4 → solana_agent-31.1.6}/README.md +1 -2
  3. {solana_agent-31.1.4 → solana_agent-31.1.6}/pyproject.toml +7 -7
  4. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/adapters/openai_adapter.py +71 -0
  5. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/factories/agent_factory.py +1 -12
  6. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/interfaces/providers/llm.py +17 -0
  7. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/repositories/memory.py +30 -51
  8. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/services/agent.py +168 -116
  9. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/services/query.py +208 -15
  10. {solana_agent-31.1.4 → solana_agent-31.1.6}/LICENSE +0 -0
  11. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/__init__.py +0 -0
  12. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/adapters/__init__.py +0 -0
  13. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/adapters/mongodb_adapter.py +0 -0
  14. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/adapters/pinecone_adapter.py +0 -0
  15. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/cli.py +0 -0
  16. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/client/__init__.py +0 -0
  17. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/client/solana_agent.py +0 -0
  18. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/domains/__init__.py +0 -0
  19. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/domains/agent.py +0 -0
  20. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/domains/routing.py +0 -0
  21. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/factories/__init__.py +0 -0
  22. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/guardrails/pii.py +0 -0
  23. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/interfaces/__init__.py +0 -0
  24. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/interfaces/client/client.py +0 -0
  25. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/interfaces/guardrails/guardrails.py +0 -0
  26. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/interfaces/plugins/plugins.py +0 -0
  27. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/interfaces/providers/data_storage.py +0 -0
  28. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/interfaces/providers/memory.py +0 -0
  29. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/interfaces/providers/vector_storage.py +0 -0
  30. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/interfaces/services/agent.py +0 -0
  31. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/interfaces/services/knowledge_base.py +0 -0
  32. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/interfaces/services/query.py +0 -0
  33. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/interfaces/services/routing.py +0 -0
  34. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/plugins/__init__.py +0 -0
  35. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/plugins/manager.py +0 -0
  36. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/plugins/registry.py +0 -0
  37. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/plugins/tools/__init__.py +0 -0
  38. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/plugins/tools/auto_tool.py +0 -0
  39. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/repositories/__init__.py +0 -0
  40. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/services/__init__.py +0 -0
  41. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/services/knowledge_base.py +0 -0
  42. {solana_agent-31.1.4 → solana_agent-31.1.6}/solana_agent/services/routing.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: solana-agent
3
- Version: 31.1.4
3
+ Version: 31.1.6
4
4
  Summary: AI Agents for Solana
5
5
  License: MIT
6
6
  Keywords: solana,solana ai,solana agent,ai,ai agent,ai agents
@@ -15,10 +15,10 @@ Classifier: Programming Language :: Python :: 3.12
15
15
  Classifier: Programming Language :: Python :: 3.13
16
16
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
17
  Requires-Dist: instructor (==1.11.2)
18
- Requires-Dist: llama-index-core (==0.13.3)
18
+ Requires-Dist: llama-index-core (==0.13.5)
19
19
  Requires-Dist: llama-index-embeddings-openai (==0.5.0)
20
20
  Requires-Dist: logfire (==4.3.6)
21
- Requires-Dist: openai (==1.102.0)
21
+ Requires-Dist: openai (==1.106.1)
22
22
  Requires-Dist: pillow (==11.3.0)
23
23
  Requires-Dist: pinecone[asyncio] (==7.3.0)
24
24
  Requires-Dist: pydantic (>=2)
@@ -52,7 +52,7 @@ Build your AI agents in three lines of code!
52
52
  ## Why?
53
53
  * Three lines of code setup
54
54
  * Simple Agent Definition
55
- * Fast Responses
55
+ * Fast & Streaming Responses
56
56
  * Solana Integration
57
57
  * Multi-Agent Swarm
58
58
  * Multi-Modal (Images & Audio & Text)
@@ -361,7 +361,6 @@ config = {
361
361
  "instructions": "You provide friendly, helpful customer support responses.",
362
362
  "specialization": "Customer inquiries",
363
363
  "capture_name": "contact_info",
364
- "capture_mode": "once",
365
364
  "capture_schema": {
366
365
  "type": "object",
367
366
  "properties": {
@@ -17,7 +17,7 @@ Build your AI agents in three lines of code!
17
17
  ## Why?
18
18
  * Three lines of code setup
19
19
  * Simple Agent Definition
20
- * Fast Responses
20
+ * Fast & Streaming Responses
21
21
  * Solana Integration
22
22
  * Multi-Agent Swarm
23
23
  * Multi-Modal (Images & Audio & Text)
@@ -326,7 +326,6 @@ config = {
326
326
  "instructions": "You provide friendly, helpful customer support responses.",
327
327
  "specialization": "Customer inquiries",
328
328
  "capture_name": "contact_info",
329
- "capture_mode": "once",
330
329
  "capture_schema": {
331
330
  "type": "object",
332
331
  "properties": {
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "solana-agent"
3
- version = "31.1.4"
3
+ version = "31.1.6"
4
4
  description = "AI Agents for Solana"
5
5
  authors = ["Bevan Hunt <bevan@bevanhunt.com>"]
6
6
  license = "MIT"
@@ -24,13 +24,13 @@ testpaths = ["tests"]
24
24
 
25
25
  [tool.poetry.dependencies]
26
26
  python = ">=3.12,<4.0"
27
- openai = "1.102.0"
27
+ openai = "1.106.1"
28
28
  pydantic = ">=2"
29
29
  pymongo = "4.14.1"
30
30
  zep-cloud = "3.4.3"
31
31
  instructor = "1.11.2"
32
32
  pinecone = { version = "7.3.0", extras = ["asyncio"] }
33
- llama-index-core = "0.13.3"
33
+ llama-index-core = "0.13.5"
34
34
  llama-index-embeddings-openai = "0.5.0"
35
35
  pypdf = "6.0.0"
36
36
  scrubadub = "2.0.1"
@@ -40,17 +40,17 @@ rich = ">=13,<14.0"
40
40
  pillow = "11.3.0"
41
41
 
42
42
  [tool.poetry.group.dev.dependencies]
43
- pytest = "^8.4.0"
43
+ pytest = "^8.4.2"
44
44
  pytest-cov = "^6.1.1"
45
45
  pytest-asyncio = "^1.1.0"
46
- pytest-mock = "^3.14.0"
46
+ pytest-mock = "^3.15.0"
47
47
  pytest-github-actions-annotate-failures = "^0.3.0"
48
48
  sphinx = "^8.2.3"
49
49
  sphinx-rtd-theme = "^3.0.2"
50
50
  myst-parser = "^4.0.1"
51
- sphinx-autobuild = "^2024.10.3"
51
+ sphinx-autobuild = "^2025.08.25"
52
52
  mongomock = "^4.3.0"
53
- ruff = "^0.12.10"
53
+ ruff = "^0.12.12"
54
54
 
55
55
  [tool.poetry.scripts]
56
56
  solana-agent = "solana_agent.cli:app"
@@ -399,6 +399,77 @@ class OpenAIAdapter(LLMProvider):
399
399
  logger.exception(f"Error in generate_text_with_images: {e}")
400
400
  return f"I apologize, but I encountered an unexpected error: {e}"
401
401
 
402
+ async def chat_stream(
403
+ self,
404
+ messages: List[Dict[str, Any]],
405
+ model: Optional[str] = None,
406
+ tools: Optional[List[Dict[str, Any]]] = None,
407
+ ) -> AsyncGenerator[Dict[str, Any], None]: # pragma: no cover
408
+ """Stream chat completions with optional tool calls, yielding normalized events."""
409
+ try:
410
+ request_params: Dict[str, Any] = {
411
+ "messages": messages,
412
+ "model": model or self.text_model,
413
+ "stream": True,
414
+ }
415
+ if tools:
416
+ request_params["tools"] = tools
417
+
418
+ client = self.client
419
+ if self.logfire:
420
+ logfire.instrument_openai(client)
421
+
422
+ stream = await client.chat.completions.create(**request_params)
423
+ async for chunk in stream:
424
+ try:
425
+ if not chunk or not getattr(chunk, "choices", None):
426
+ continue
427
+ ch = chunk.choices[0]
428
+ delta = getattr(ch, "delta", None)
429
+ if delta is None:
430
+ # Some SDKs use 'message' instead of 'delta'
431
+ delta = getattr(ch, "message", None)
432
+ if delta is None:
433
+ # Finish event
434
+ finish = getattr(ch, "finish_reason", None)
435
+ if finish:
436
+ yield {"type": "message_end", "finish_reason": finish}
437
+ continue
438
+
439
+ # Content delta
440
+ content_piece = getattr(delta, "content", None)
441
+ if content_piece:
442
+ yield {"type": "content", "delta": content_piece}
443
+
444
+ # Tool call deltas
445
+ tool_calls = getattr(delta, "tool_calls", None)
446
+ if tool_calls:
447
+ for idx, tc in enumerate(tool_calls):
448
+ try:
449
+ tc_id = getattr(tc, "id", None)
450
+ func = getattr(tc, "function", None)
451
+ name = getattr(func, "name", None) if func else None
452
+ args_piece = (
453
+ getattr(func, "arguments", "") if func else ""
454
+ )
455
+ yield {
456
+ "type": "tool_call_delta",
457
+ "id": tc_id,
458
+ "index": getattr(tc, "index", idx),
459
+ "name": name,
460
+ "arguments_delta": args_piece or "",
461
+ }
462
+ except Exception:
463
+ continue
464
+ except Exception as parse_err:
465
+ logger.debug(f"Error parsing stream chunk: {parse_err}")
466
+ continue
467
+ # End of stream (SDK may not emit finish event in all cases)
468
+ yield {"type": "message_end", "finish_reason": "end_of_stream"}
469
+ except Exception as e:
470
+ logger.exception(f"Error in chat_stream: {e}")
471
+ yield {"type": "error", "error": str(e)}
472
+
402
473
  async def parse_structured_output(
403
474
  self,
404
475
  prompt: str,
@@ -133,12 +133,7 @@ class SolanaAgentFactory:
133
133
  voice=org_config.get("voice", ""),
134
134
  )
135
135
 
136
- # Build capture modes from agent config if provided
137
- capture_modes: Dict[str, str] = {}
138
- for agent in config.get("agents", []):
139
- mode = agent.get("capture_mode")
140
- if mode in {"once", "multiple"} and agent.get("name"):
141
- capture_modes[agent["name"]] = mode
136
+ # capture_mode removed: repository now always upserts/merges per capture
142
137
 
143
138
  # Create repositories
144
139
  memory_provider = None
@@ -148,22 +143,16 @@ class SolanaAgentFactory:
148
143
  "mongo_adapter": db_adapter,
149
144
  "zep_api_key": config["zep"].get("api_key"),
150
145
  }
151
- if capture_modes: # pragma: no cover
152
- mem_kwargs["capture_modes"] = capture_modes
153
146
  memory_provider = MemoryRepository(**mem_kwargs)
154
147
 
155
148
  if "mongo" in config and "zep" not in config:
156
149
  mem_kwargs = {"mongo_adapter": db_adapter}
157
- if capture_modes:
158
- mem_kwargs["capture_modes"] = capture_modes
159
150
  memory_provider = MemoryRepository(**mem_kwargs)
160
151
 
161
152
  if "zep" in config and "mongo" not in config:
162
153
  if "api_key" not in config["zep"]:
163
154
  raise ValueError("Zep API key is required.")
164
155
  mem_kwargs = {"zep_api_key": config["zep"].get("api_key")}
165
- if capture_modes: # pragma: no cover
166
- mem_kwargs["capture_modes"] = capture_modes
167
156
  memory_provider = MemoryRepository(**mem_kwargs)
168
157
 
169
158
  guardrail_config = config.get("guardrails", {})
@@ -33,6 +33,23 @@ class LLMProvider(ABC):
33
33
  """Generate text from the language model."""
34
34
  pass
35
35
 
36
+ @abstractmethod
37
+ async def chat_stream(
38
+ self,
39
+ messages: List[Dict[str, Any]],
40
+ model: Optional[str] = None,
41
+ tools: Optional[List[Dict[str, Any]]] = None,
42
+ ) -> AsyncGenerator[Dict[str, Any], None]:
43
+ """Stream chat completion deltas and tool call deltas.
44
+
45
+ Yields normalized events:
46
+ - {"type": "content", "delta": str}
47
+ - {"type": "tool_call_delta", "id": Optional[str], "index": Optional[int], "name": Optional[str], "arguments_delta": str}
48
+ - {"type": "message_end", "finish_reason": str}
49
+ - {"type": "error", "error": str}
50
+ """
51
+ pass
52
+
36
53
  @abstractmethod
37
54
  async def parse_structured_output(
38
55
  self,
@@ -19,10 +19,7 @@ class MemoryRepository(MemoryProvider):
19
19
  self,
20
20
  mongo_adapter: Optional[MongoDBAdapter] = None,
21
21
  zep_api_key: Optional[str] = None,
22
- capture_modes: Optional[Dict[str, str]] = None,
23
22
  ):
24
- self.capture_modes: Dict[str, str] = capture_modes or {}
25
-
26
23
  # Mongo setup
27
24
  if not mongo_adapter:
28
25
  self.mongo = None
@@ -46,18 +43,15 @@ class MemoryRepository(MemoryProvider):
46
43
  self.mongo.create_index(self.captures_collection, [("capture_name", 1)])
47
44
  self.mongo.create_index(self.captures_collection, [("agent_name", 1)])
48
45
  self.mongo.create_index(self.captures_collection, [("timestamp", 1)])
49
- # Unique only when mode == 'once'
46
+ # Unique per user/agent/capture combo
50
47
  try:
51
48
  self.mongo.create_index(
52
49
  self.captures_collection,
53
50
  [("user_id", 1), ("agent_name", 1), ("capture_name", 1)],
54
51
  unique=True,
55
- partialFilterExpression={"mode": "once"},
56
52
  )
57
53
  except Exception as e:
58
- logger.error(
59
- f"Error creating partial unique index for captures: {e}"
60
- )
54
+ logger.error(f"Error creating unique index for captures: {e}")
61
55
  except Exception as e:
62
56
  logger.error(f"Error initializing MongoDB captures collection: {e}")
63
57
  self.captures_collection = "captures"
@@ -223,54 +217,39 @@ class MemoryRepository(MemoryProvider):
223
217
  raise ValueError("data must be a dictionary")
224
218
 
225
219
  try:
226
- mode = self.capture_modes.get(agent_name, "once") if agent_name else "once"
227
220
  now = datetime.now(timezone.utc)
228
- if mode == "multiple":
229
- doc = {
221
+ key = {
222
+ "user_id": user_id,
223
+ "agent_name": agent_name,
224
+ "capture_name": capture_name,
225
+ }
226
+ existing = self.mongo.find_one(self.captures_collection, key)
227
+ merged_data: Dict[str, Any] = {}
228
+ if existing and isinstance(existing.get("data"), dict):
229
+ merged_data.update(existing.get("data", {}))
230
+ merged_data.update(data or {})
231
+ update_doc = {
232
+ "$set": {
230
233
  "user_id": user_id,
231
234
  "agent_name": agent_name,
232
235
  "capture_name": capture_name,
233
- "data": data or {},
234
- "schema": schema or {},
235
- "mode": "multiple",
236
+ "data": merged_data,
237
+ "schema": (
238
+ schema
239
+ if schema is not None
240
+ else existing.get("schema")
241
+ if existing
242
+ else {}
243
+ ),
236
244
  "timestamp": now,
237
- "created_at": now,
238
- }
239
- return self.mongo.insert_one(self.captures_collection, doc)
240
- else:
241
- key = {
242
- "user_id": user_id,
243
- "agent_name": agent_name,
244
- "capture_name": capture_name,
245
- }
246
- existing = self.mongo.find_one(self.captures_collection, key)
247
- merged_data: Dict[str, Any] = {}
248
- if existing and isinstance(existing.get("data"), dict):
249
- merged_data.update(existing.get("data", {}))
250
- merged_data.update(data or {})
251
- update_doc = {
252
- "$set": {
253
- "user_id": user_id,
254
- "agent_name": agent_name,
255
- "capture_name": capture_name,
256
- "data": merged_data,
257
- "schema": (
258
- schema
259
- if schema is not None
260
- else existing.get("schema")
261
- if existing
262
- else {}
263
- ),
264
- "mode": "once",
265
- "timestamp": now,
266
- },
267
- "$setOnInsert": {"created_at": now},
268
- }
269
- self.mongo.update_one(
270
- self.captures_collection, key, update_doc, upsert=True
271
- )
272
- doc = self.mongo.find_one(self.captures_collection, key)
273
- return str(doc.get("_id")) if doc and doc.get("_id") else None
245
+ },
246
+ "$setOnInsert": {"created_at": now},
247
+ }
248
+ self.mongo.update_one(
249
+ self.captures_collection, key, update_doc, upsert=True
250
+ )
251
+ doc = self.mongo.find_one(self.captures_collection, key)
252
+ return str(doc.get("_id")) if doc and doc.get("_id") else None
274
253
  except Exception as e: # pragma: no cover
275
254
  logger.error(f"MongoDB save_capture error: {e}")
276
255
  return None
@@ -265,56 +265,57 @@ class AgentService(AgentServiceInterface):
265
265
  prompt: Optional[str] = None,
266
266
  output_model: Optional[Type[BaseModel]] = None,
267
267
  ) -> AsyncGenerator[Union[str, bytes, BaseModel], None]: # pragma: no cover
268
- """Generate a response using OpenAI function calling (tools API) or structured output."""
269
-
270
- agent = next((a for a in self.agents if a.name == agent_name), None)
271
- if not agent:
272
- error_msg = f"Agent '{agent_name}' not found."
273
- logger.warning(error_msg)
274
- if output_format == "audio":
275
- async for chunk in self.llm_provider.tts(
276
- error_msg,
277
- instructions=audio_instructions,
278
- response_format=audio_output_format,
279
- voice=audio_voice,
280
- ):
281
- yield chunk
282
- else:
283
- yield error_msg
284
- return
285
-
286
- # Build system prompt and messages
287
- system_prompt = self.get_agent_system_prompt(agent_name)
288
- user_content = str(query)
289
- if images:
290
- user_content += "\n\n[Images attached]"
291
-
292
- # Compose the prompt for generate_text
293
- full_prompt = ""
294
- if memory_context:
295
- full_prompt += f"CONVERSATION HISTORY:\n{memory_context}\n\n Always use your tools to perform actions and don't rely on your memory!\n\n"
296
- if prompt:
297
- full_prompt += f"ADDITIONAL PROMPT:\n{prompt}\n\n"
298
- full_prompt += user_content
299
- full_prompt += f"USER IDENTIFIER: {user_id}"
300
-
301
- # Get OpenAI function schemas for this agent's tools
302
- tools = [
303
- {
304
- "type": "function",
305
- "function": {
306
- "name": tool["name"],
307
- "description": tool.get("description", ""),
308
- "parameters": tool.get("parameters", {}),
309
- "strict": True,
310
- },
311
- }
312
- for tool in self.get_agent_tools(agent_name)
313
- ]
268
+ """Generate a response using tool-calling with full streaming support."""
314
269
 
315
270
  try:
271
+ # Validate agent
272
+ agent = next((a for a in self.agents if a.name == agent_name), None)
273
+ if not agent:
274
+ error_msg = f"Agent '{agent_name}' not found."
275
+ logger.warning(error_msg)
276
+ if output_format == "audio":
277
+ async for chunk in self.llm_provider.tts(
278
+ error_msg,
279
+ instructions=audio_instructions,
280
+ response_format=audio_output_format,
281
+ voice=audio_voice,
282
+ ):
283
+ yield chunk
284
+ else:
285
+ yield error_msg
286
+ return
287
+
288
+ # Build system prompt and messages
289
+ system_prompt = self.get_agent_system_prompt(agent_name)
290
+ user_content = str(query)
291
+ if images:
292
+ user_content += "\n\n[Images attached]"
293
+
294
+ # Compose the prompt for generate_text
295
+ full_prompt = ""
296
+ if memory_context:
297
+ full_prompt += f"CONVERSATION HISTORY:\n{memory_context}\n\n Always use your tools to perform actions and don't rely on your memory!\n\n"
298
+ if prompt:
299
+ full_prompt += f"ADDITIONAL PROMPT:\n{prompt}\n\n"
300
+ full_prompt += user_content
301
+ full_prompt += f"USER IDENTIFIER: {user_id}"
302
+
303
+ # Get OpenAI function schemas for this agent's tools
304
+ tools = [
305
+ {
306
+ "type": "function",
307
+ "function": {
308
+ "name": tool["name"],
309
+ "description": tool.get("description", ""),
310
+ "parameters": tool.get("parameters", {}),
311
+ "strict": True,
312
+ },
313
+ }
314
+ for tool in self.get_agent_tools(agent_name)
315
+ ]
316
+
317
+ # Structured output path
316
318
  if output_model is not None:
317
- # --- Structured output with tool support ---
318
319
  model_instance = await self.llm_provider.parse_structured_output(
319
320
  prompt=full_prompt,
320
321
  system_prompt=system_prompt,
@@ -327,83 +328,131 @@ class AgentService(AgentServiceInterface):
327
328
  yield model_instance
328
329
  return
329
330
 
330
- # --- Streaming text/audio with tool support (as before) ---
331
- response_text = ""
332
- while True:
333
- if not images:
334
- response = await self.llm_provider.generate_text(
335
- prompt=full_prompt,
336
- system_prompt=system_prompt,
337
- api_key=self.api_key,
338
- base_url=self.base_url,
339
- model=self.model,
340
- tools=tools if tools else None,
341
- )
331
+ # Vision fallback (non-streaming for now)
332
+ if images:
333
+ vision_text = await self.llm_provider.generate_text_with_images(
334
+ prompt=full_prompt, images=images, system_prompt=system_prompt
335
+ )
336
+ if output_format == "audio":
337
+ cleaned_audio_buffer = self._clean_for_audio(vision_text)
338
+ async for audio_chunk in self.llm_provider.tts(
339
+ text=cleaned_audio_buffer,
340
+ voice=audio_voice,
341
+ response_format=audio_output_format,
342
+ instructions=audio_instructions,
343
+ ):
344
+ yield audio_chunk
342
345
  else:
343
- response = await self.llm_provider.generate_text_with_images(
344
- prompt=full_prompt,
345
- system_prompt=system_prompt,
346
- api_key=self.api_key,
347
- base_url=self.base_url,
348
- model=self.model,
349
- tools=tools if tools else None,
350
- images=images,
351
- )
352
- if (
353
- not response
354
- or not hasattr(response, "choices")
355
- or not response.choices
346
+ yield vision_text
347
+ return
348
+
349
+ # Build initial messages for chat streaming
350
+ messages: List[Dict[str, Any]] = []
351
+ if system_prompt:
352
+ messages.append({"role": "system", "content": system_prompt})
353
+ messages.append({"role": "user", "content": full_prompt})
354
+
355
+ accumulated_text = ""
356
+
357
+ # Loop to handle tool calls in streaming mode
358
+ while True:
359
+ # Aggregate tool calls by index and merge late IDs
360
+ tool_calls: Dict[int, Dict[str, Any]] = {}
361
+
362
+ async for event in self.llm_provider.chat_stream(
363
+ messages=messages,
364
+ model=self.model,
365
+ tools=tools if tools else None,
356
366
  ):
357
- logger.error("No response or choices from LLM provider.")
358
- response_text = "I apologize, but I could not generate a response."
359
- break
360
-
361
- choice = response.choices[0]
362
- message = getattr(choice, "message", choice)
363
-
364
- if hasattr(message, "tool_calls") and message.tool_calls:
365
- for tool_call in message.tool_calls:
366
- if tool_call.type == "function":
367
- function_name = tool_call.function.name
368
- arguments = json.loads(tool_call.function.arguments)
369
- logger.info(
370
- f"Model requested tool '{function_name}' with args: {arguments}"
371
- )
372
- # Execute the tool (async)
373
- tool_result = await self.execute_tool(
374
- agent_name, function_name, arguments
375
- )
376
- # Add the tool result to the prompt for the next round
377
- full_prompt += (
378
- f"\n\nTool '{function_name}' was called with arguments {arguments}.\n"
379
- f"Result: {tool_result}\n"
367
+ etype = event.get("type")
368
+ if etype == "content":
369
+ delta = event.get("delta", "")
370
+ accumulated_text += delta
371
+ if output_format == "text":
372
+ yield delta
373
+ elif etype == "tool_call_delta":
374
+ tc_id = event.get("id")
375
+ index_raw = event.get("index")
376
+ try:
377
+ index = int(index_raw) if index_raw is not None else 0
378
+ except Exception:
379
+ index = 0
380
+ name = event.get("name")
381
+ args_piece = event.get("arguments_delta", "")
382
+ entry = tool_calls.setdefault(
383
+ index, {"id": None, "name": None, "arguments": ""}
384
+ )
385
+ if tc_id and not entry.get("id"):
386
+ entry["id"] = tc_id
387
+ if name and not entry.get("name"):
388
+ entry["name"] = name
389
+ entry["arguments"] += args_piece
390
+ elif etype == "message_end":
391
+ _ = event.get("finish_reason")
392
+
393
+ # If tool calls were requested, execute them and continue the loop
394
+ if tool_calls:
395
+ assistant_tool_calls: List[Dict[str, Any]] = []
396
+ call_id_map: Dict[int, str] = {}
397
+ for idx, tc in tool_calls.items():
398
+ name = (tc.get("name") or "").strip()
399
+ if not name:
400
+ logger.warning(
401
+ f"Skipping unnamed tool call at index {idx}; cannot send empty function name."
380
402
  )
381
- continue
403
+ continue
404
+ norm_id = tc.get("id") or f"call_{idx}"
405
+ call_id_map[idx] = norm_id
406
+ assistant_tool_calls.append(
407
+ {
408
+ "id": norm_id,
409
+ "type": "function",
410
+ "function": {
411
+ "name": name,
412
+ "arguments": tc.get("arguments") or "{}",
413
+ },
414
+ }
415
+ )
382
416
 
383
- # Otherwise, it's a normal message (final answer)
384
- response_text = message.content
385
- break
417
+ if assistant_tool_calls:
418
+ messages.append(
419
+ {
420
+ "role": "assistant",
421
+ "content": None,
422
+ "tool_calls": assistant_tool_calls,
423
+ }
424
+ )
386
425
 
387
- # Apply output guardrails if any
388
- processed_final_text = response_text
389
- if self.output_guardrails:
390
- for guardrail in self.output_guardrails:
391
- try:
392
- processed_final_text = await guardrail.process(
393
- processed_final_text
426
+ # Execute each tool and append the tool result messages
427
+ for idx, tc in tool_calls.items():
428
+ func_name = (tc.get("name") or "").strip()
429
+ if not func_name:
430
+ continue
431
+ try:
432
+ args = json.loads(tc.get("arguments") or "{}")
433
+ except Exception:
434
+ args = {}
435
+ logger.info(
436
+ f"Streaming: executing tool '{func_name}' with args: {args}"
394
437
  )
395
- except Exception as e:
396
- logger.error(
397
- f"Error applying output guardrail {guardrail.__class__.__name__}: {e}"
438
+ tool_result = await self.execute_tool(
439
+ agent_name, func_name, args
440
+ )
441
+ messages.append(
442
+ {
443
+ "role": "tool",
444
+ "tool_call_id": call_id_map.get(idx, f"call_{idx}"),
445
+ "content": json.dumps(tool_result),
446
+ }
398
447
  )
399
448
 
400
- self.last_text_response = processed_final_text
449
+ accumulated_text = ""
450
+ continue
401
451
 
402
- if output_format == "text":
403
- yield processed_final_text or ""
404
- elif output_format == "audio":
405
- cleaned_audio_buffer = self._clean_for_audio(processed_final_text)
406
- if cleaned_audio_buffer:
452
+ # No tool calls: we've streamed the final answer
453
+ final_text = accumulated_text
454
+ if output_format == "audio":
455
+ cleaned_audio_buffer = self._clean_for_audio(final_text)
407
456
  async for audio_chunk in self.llm_provider.tts(
408
457
  text=cleaned_audio_buffer,
409
458
  voice=audio_voice,
@@ -412,7 +461,10 @@ class AgentService(AgentServiceInterface):
412
461
  ):
413
462
  yield audio_chunk
414
463
  else:
415
- yield ""
464
+ if not final_text:
465
+ yield ""
466
+ self.last_text_response = final_text
467
+ break
416
468
  except Exception as e:
417
469
  import traceback
418
470
 
@@ -8,7 +8,18 @@ clean separation of concerns.
8
8
 
9
9
  import logging
10
10
  import re
11
- from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Type, Union
11
+ import time
12
+ from typing import (
13
+ Any,
14
+ AsyncGenerator,
15
+ Dict,
16
+ List,
17
+ Literal,
18
+ Optional,
19
+ Type,
20
+ Union,
21
+ Tuple,
22
+ )
12
23
 
13
24
  from pydantic import BaseModel
14
25
 
@@ -50,6 +61,151 @@ class QueryService(QueryServiceInterface):
50
61
  self.knowledge_base = knowledge_base
51
62
  self.kb_results_count = kb_results_count
52
63
  self.input_guardrails = input_guardrails or []
64
+ # Per-user sticky sessions (in-memory)
65
+ # { user_id: { 'agent': str, 'started_at': float, 'last_updated': float, 'required_complete': bool } }
66
+ self._sticky_sessions: Dict[str, Dict[str, Any]] = {}
67
+
68
+ def _get_sticky_agent(self, user_id: str) -> Optional[str]:
69
+ sess = self._sticky_sessions.get(user_id)
70
+ return sess.get("agent") if isinstance(sess, dict) else None
71
+
72
+ def _set_sticky_agent(
73
+ self, user_id: str, agent_name: str, required_complete: bool = False
74
+ ) -> None:
75
+ self._sticky_sessions[user_id] = {
76
+ "agent": agent_name,
77
+ "started_at": self._sticky_sessions.get(user_id, {}).get(
78
+ "started_at", time.time()
79
+ ),
80
+ "last_updated": time.time(),
81
+ "required_complete": required_complete,
82
+ }
83
+
84
+ def _update_sticky_required_complete(
85
+ self, user_id: str, required_complete: bool
86
+ ) -> None:
87
+ if user_id in self._sticky_sessions:
88
+ self._sticky_sessions[user_id]["required_complete"] = required_complete
89
+ self._sticky_sessions[user_id]["last_updated"] = time.time()
90
+
91
+ def _clear_sticky_agent(self, user_id: str) -> None:
92
+ if user_id in self._sticky_sessions:
93
+ del self._sticky_sessions[user_id]
94
+
95
+ # LLM-backed switch intent detection (gpt-4.1-mini)
96
+ class _SwitchIntentModel(BaseModel):
97
+ switch: bool = False
98
+ target_agent: Optional[str] = None
99
+ start_new: bool = False
100
+
101
+ async def _detect_switch_intent(
102
+ self, text: str, available_agents: List[str]
103
+ ) -> Tuple[bool, Optional[str], bool]:
104
+ """Detect if the user is asking to switch agents or start a new conversation.
105
+
106
+ Returns: (switch_requested, target_agent_name_or_none, start_new_conversation)
107
+ Implemented as an LLM call to gpt-4.1-mini with structured output.
108
+ """
109
+ if not text:
110
+ return (False, None, False)
111
+
112
+ # Instruction and user prompt for the classifier
113
+ instruction = (
114
+ "You are a strict intent classifier for agent routing. "
115
+ "Decide if the user's message requests switching to another agent or starting a new conversation. "
116
+ "Only return JSON with keys: switch (bool), target_agent (string|null), start_new (bool). "
117
+ "If a target agent is mentioned, it MUST be one of the provided agent names (case-insensitive). "
118
+ "If none clearly applies, set switch=false and start_new=false and target_agent=null."
119
+ )
120
+ user_prompt = (
121
+ f"Available agents (choose only from these if a target is specified): {available_agents}\n\n"
122
+ f"User message:\n{text}\n\n"
123
+ 'Return JSON only, like: {"switch": true|false, "target_agent": "<one_of_available_or_null>", "start_new": true|false}'
124
+ )
125
+
126
+ # Primary: use llm_provider.parse_structured_output
127
+ try:
128
+ if hasattr(self.agent_service.llm_provider, "parse_structured_output"):
129
+ try:
130
+ result = (
131
+ await self.agent_service.llm_provider.parse_structured_output(
132
+ prompt=user_prompt,
133
+ system_prompt=instruction,
134
+ model_class=QueryService._SwitchIntentModel,
135
+ model="gpt-4.1-mini",
136
+ )
137
+ )
138
+ except TypeError:
139
+ # Provider may not accept 'model' kwarg
140
+ result = (
141
+ await self.agent_service.llm_provider.parse_structured_output(
142
+ prompt=user_prompt,
143
+ system_prompt=instruction,
144
+ model_class=QueryService._SwitchIntentModel,
145
+ )
146
+ )
147
+ switch = bool(getattr(result, "switch", False))
148
+ target = getattr(result, "target_agent", None)
149
+ start_new = bool(getattr(result, "start_new", False))
150
+ # Normalize target to available agent name
151
+ if target:
152
+ target_lower = target.lower()
153
+ norm = None
154
+ for a in available_agents:
155
+ if a.lower() == target_lower or target_lower in a.lower():
156
+ norm = a
157
+ break
158
+ target = norm
159
+ if not switch:
160
+ target = None
161
+ return (switch, target, start_new)
162
+ except Exception as e:
163
+ logger.debug(f"LLM switch intent parse_structured_output failed: {e}")
164
+
165
+ # Fallback: generate_response with output_model
166
+ try:
167
+ async for r in self.agent_service.generate_response(
168
+ agent_name="default",
169
+ user_id="router",
170
+ query="",
171
+ images=None,
172
+ memory_context="",
173
+ output_format="text",
174
+ prompt=f"{instruction}\n\n{user_prompt}",
175
+ output_model=QueryService._SwitchIntentModel,
176
+ ):
177
+ result = r
178
+ switch = False
179
+ target = None
180
+ start_new = False
181
+ try:
182
+ switch = bool(result.switch) # type: ignore[attr-defined]
183
+ target = result.target_agent # type: ignore[attr-defined]
184
+ start_new = bool(result.start_new) # type: ignore[attr-defined]
185
+ except Exception:
186
+ try:
187
+ d = result.model_dump()
188
+ switch = bool(d.get("switch", False))
189
+ target = d.get("target_agent")
190
+ start_new = bool(d.get("start_new", False))
191
+ except Exception:
192
+ pass
193
+ if target:
194
+ target_lower = str(target).lower()
195
+ norm = None
196
+ for a in available_agents:
197
+ if a.lower() == target_lower or target_lower in a.lower():
198
+ norm = a
199
+ break
200
+ target = norm
201
+ if not switch:
202
+ target = None
203
+ return (switch, target, start_new)
204
+ except Exception as e:
205
+ logger.debug(f"LLM switch intent generate_response failed: {e}")
206
+
207
+ # Last resort: no switch
208
+ return (False, None, False)
53
209
 
54
210
  async def process(
55
211
  self,
@@ -80,7 +236,7 @@ class QueryService(QueryServiceInterface):
80
236
  router: Optional[RoutingServiceInterface] = None,
81
237
  output_model: Optional[Type[BaseModel]] = None,
82
238
  capture_schema: Optional[Dict[str, Any]] = None,
83
- capture_name: Optional[Dict[str, Any]] = None,
239
+ capture_name: Optional[str] = None,
84
240
  ) -> AsyncGenerator[Union[str, bytes, BaseModel], None]: # pragma: no cover
85
241
  """Process the user request and generate a response."""
86
242
  try:
@@ -164,7 +320,7 @@ class QueryService(QueryServiceInterface):
164
320
  except Exception:
165
321
  kb_context = ""
166
322
 
167
- # 6) Route query (and fetch previous assistant message)
323
+ # 6) Determine agent (sticky session aware; allow explicit switch/new conversation)
168
324
  agent_name = "default"
169
325
  prev_assistant = ""
170
326
  routing_input = user_text
@@ -184,19 +340,52 @@ class QueryService(QueryServiceInterface):
184
340
  "assistant_message", ""
185
341
  ) or ""
186
342
  if prev_user_msg:
187
- routing_input = (
188
- f"previous_user_message: {prev_user_msg}\n"
189
- f"current_user_message: {user_text}"
190
- )
343
+ routing_input = f"previous_user_message: {prev_user_msg}\ncurrent_user_message: {user_text}"
191
344
  except Exception:
192
345
  pass
193
- try:
194
- if router:
195
- agent_name = await router.route_query(routing_input)
196
- else:
197
- agent_name = await self.routing_service.route_query(routing_input)
198
- except Exception:
199
- agent_name = "default"
346
+
347
+ # Get available agents first so the LLM can select a valid target
348
+ agents = self.agent_service.get_all_ai_agents() or {}
349
+ available_agent_names = list(agents.keys())
350
+
351
+ # LLM detects switch intent
352
+ (
353
+ switch_requested,
354
+ requested_agent_raw,
355
+ start_new,
356
+ ) = await self._detect_switch_intent(user_text, available_agent_names)
357
+
358
+ # Normalize requested agent to an exact available key
359
+ requested_agent = None
360
+ if requested_agent_raw:
361
+ raw_lower = requested_agent_raw.lower()
362
+ for a in available_agent_names:
363
+ if a.lower() == raw_lower or raw_lower in a.lower():
364
+ requested_agent = a
365
+ break
366
+
367
+ sticky_agent = self._get_sticky_agent(user_id)
368
+
369
+ if sticky_agent and not switch_requested:
370
+ agent_name = sticky_agent
371
+ else:
372
+ try:
373
+ if start_new:
374
+ # Start fresh
375
+ self._clear_sticky_agent(user_id)
376
+ if requested_agent:
377
+ agent_name = requested_agent
378
+ else:
379
+ # Route if no explicit target
380
+ if router:
381
+ agent_name = await router.route_query(routing_input)
382
+ else:
383
+ agent_name = await self.routing_service.route_query(
384
+ routing_input
385
+ )
386
+ except Exception:
387
+ agent_name = next(iter(agents.keys())) if agents else "default"
388
+ self._set_sticky_agent(user_id, agent_name, required_complete=False)
200
389
 
201
390
  # 7) Captured data context + incremental save using previous assistant message
202
391
  capture_context = ""
@@ -285,7 +474,6 @@ class QueryService(QueryServiceInterface):
285
474
  system_prompt=instruction,
286
475
  model_class=_FieldDetect,
287
476
  )
288
- # Read result
289
477
  sel = None
290
478
  try:
291
479
  sel = getattr(result, "field", None)
@@ -544,6 +732,11 @@ class QueryService(QueryServiceInterface):
544
732
 
545
733
  if lines:
546
734
  capture_context = "\n".join(lines) + "\n\n"
735
+ # Update sticky session completion flag
736
+ try:
737
+ self._update_sticky_required_complete(user_id, required_complete)
738
+ except Exception:
739
+ pass
547
740
 
548
741
  # Merge contexts + flow rules
549
742
  combined_context = ""
File without changes