open-swarm 0.1.1743070217 (open_swarm-0.1.1743070217-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- open_swarm-0.1.1743070217.dist-info/METADATA +258 -0
- open_swarm-0.1.1743070217.dist-info/RECORD +89 -0
- open_swarm-0.1.1743070217.dist-info/WHEEL +5 -0
- open_swarm-0.1.1743070217.dist-info/entry_points.txt +3 -0
- open_swarm-0.1.1743070217.dist-info/licenses/LICENSE +21 -0
- open_swarm-0.1.1743070217.dist-info/top_level.txt +1 -0
- swarm/__init__.py +3 -0
- swarm/agent/__init__.py +7 -0
- swarm/agent/agent.py +49 -0
- swarm/apps.py +53 -0
- swarm/auth.py +56 -0
- swarm/consumers.py +141 -0
- swarm/core.py +326 -0
- swarm/extensions/__init__.py +1 -0
- swarm/extensions/blueprint/__init__.py +36 -0
- swarm/extensions/blueprint/agent_utils.py +45 -0
- swarm/extensions/blueprint/blueprint_base.py +562 -0
- swarm/extensions/blueprint/blueprint_discovery.py +112 -0
- swarm/extensions/blueprint/blueprint_utils.py +17 -0
- swarm/extensions/blueprint/common_utils.py +12 -0
- swarm/extensions/blueprint/django_utils.py +203 -0
- swarm/extensions/blueprint/interactive_mode.py +102 -0
- swarm/extensions/blueprint/modes/rest_mode.py +37 -0
- swarm/extensions/blueprint/output_utils.py +95 -0
- swarm/extensions/blueprint/spinner.py +91 -0
- swarm/extensions/cli/__init__.py +0 -0
- swarm/extensions/cli/blueprint_runner.py +251 -0
- swarm/extensions/cli/cli_args.py +88 -0
- swarm/extensions/cli/commands/__init__.py +0 -0
- swarm/extensions/cli/commands/blueprint_management.py +31 -0
- swarm/extensions/cli/commands/config_management.py +15 -0
- swarm/extensions/cli/commands/edit_config.py +77 -0
- swarm/extensions/cli/commands/list_blueprints.py +22 -0
- swarm/extensions/cli/commands/validate_env.py +57 -0
- swarm/extensions/cli/commands/validate_envvars.py +39 -0
- swarm/extensions/cli/interactive_shell.py +41 -0
- swarm/extensions/cli/main.py +36 -0
- swarm/extensions/cli/selection.py +43 -0
- swarm/extensions/cli/utils/discover_commands.py +32 -0
- swarm/extensions/cli/utils/env_setup.py +15 -0
- swarm/extensions/cli/utils.py +105 -0
- swarm/extensions/config/__init__.py +6 -0
- swarm/extensions/config/config_loader.py +208 -0
- swarm/extensions/config/config_manager.py +258 -0
- swarm/extensions/config/server_config.py +49 -0
- swarm/extensions/config/setup_wizard.py +103 -0
- swarm/extensions/config/utils/__init__.py +0 -0
- swarm/extensions/config/utils/logger.py +36 -0
- swarm/extensions/launchers/__init__.py +1 -0
- swarm/extensions/launchers/build_launchers.py +14 -0
- swarm/extensions/launchers/build_swarm_wrapper.py +12 -0
- swarm/extensions/launchers/swarm_api.py +68 -0
- swarm/extensions/launchers/swarm_cli.py +304 -0
- swarm/extensions/launchers/swarm_wrapper.py +29 -0
- swarm/extensions/mcp/__init__.py +1 -0
- swarm/extensions/mcp/cache_utils.py +36 -0
- swarm/extensions/mcp/mcp_client.py +341 -0
- swarm/extensions/mcp/mcp_constants.py +7 -0
- swarm/extensions/mcp/mcp_tool_provider.py +110 -0
- swarm/llm/chat_completion.py +195 -0
- swarm/messages.py +132 -0
- swarm/migrations/0010_initial_chat_models.py +51 -0
- swarm/migrations/__init__.py +0 -0
- swarm/models.py +45 -0
- swarm/repl/__init__.py +1 -0
- swarm/repl/repl.py +87 -0
- swarm/serializers.py +12 -0
- swarm/settings.py +189 -0
- swarm/tool_executor.py +239 -0
- swarm/types.py +126 -0
- swarm/urls.py +89 -0
- swarm/util.py +124 -0
- swarm/utils/color_utils.py +40 -0
- swarm/utils/context_utils.py +272 -0
- swarm/utils/general_utils.py +162 -0
- swarm/utils/logger.py +61 -0
- swarm/utils/logger_setup.py +25 -0
- swarm/utils/message_sequence.py +173 -0
- swarm/utils/message_utils.py +95 -0
- swarm/utils/redact.py +68 -0
- swarm/views/__init__.py +41 -0
- swarm/views/api_views.py +46 -0
- swarm/views/chat_views.py +76 -0
- swarm/views/core_views.py +118 -0
- swarm/views/message_views.py +40 -0
- swarm/views/model_views.py +135 -0
- swarm/views/utils.py +457 -0
- swarm/views/web_views.py +149 -0
- swarm/wsgi.py +16 -0
swarm/consumers.py
ADDED
@@ -0,0 +1,141 @@
import json
import os
import uuid
from channels.generic.websocket import AsyncWebsocketConsumer
from openai import AsyncOpenAI
from django.template.loader import render_to_string
from channels.db import database_sync_to_async
from swarm.models import ChatConversation, ChatMessage

# In-memory conversation storage (populated lazily)
IN_MEMORY_CONVERSATIONS = {}

class DjangoChatConsumer(AsyncWebsocketConsumer):
    async def connect(self):
        self.user = self.scope["user"]
        self.conversation_id = self.scope['url_route']['kwargs']['conversation_id']

        if self.user.is_authenticated:
            self.messages = await self.fetch_conversation(self.conversation_id)
            await self.accept()
        else:
            await self.close()

    async def disconnect(self, close_code):
        if self.user.is_authenticated:
            await self.save_conversation(self.conversation_id, self.messages)

            # Delete conversation from DB and memory if empty
            if not self.messages:
                await self.delete_conversation(self.conversation_id)

            # Clean up in-memory cache to avoid leaks
            if self.conversation_id in IN_MEMORY_CONVERSATIONS:
                del IN_MEMORY_CONVERSATIONS[self.conversation_id]

    async def receive(self, text_data):
        text_data_json = json.loads(text_data)
        message_text = text_data_json["message"]

        if not message_text.strip():
            return

        self.messages.append(
            {
                "role": "user",
                "content": message_text,
            }
        )

        user_message_html = render_to_string(
            "websocket_partials/user_message.html",
            {"message_text": message_text},
        )
        await self.send(text_data=user_message_html)

        message_id = uuid.uuid4().hex
        contents_div_id = f"message-response-{message_id}"
        system_message_html = render_to_string(
            "websocket_partials/system_message.html",
            {"contents_div_id": contents_div_id},
        )
        await self.send(text_data=system_message_html)

        client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        stream = await client.chat.completions.create(
            model=os.getenv("OPENAI_MODEL"),
            messages=self.messages,
            stream=True,
        )

        full_message = ""
        async for chunk in stream:
            message_chunk = chunk.choices[0].delta.content
            if message_chunk:
                full_message += message_chunk
                chunk_html = f'<div hx-swap-oob="beforeend:#{contents_div_id}">{message_chunk}</div>'
                await self.send(text_data=chunk_html)

        self.messages.append(
            {
                "role": "assistant",
                "content": full_message,
            }
        )

        final_message = render_to_string(
            "websocket_partials/final_system_message.html",
            {
                "contents_div_id": contents_div_id,
                "message": full_message,
            },
        )
        await client.close()
        await self.send(text_data=final_message)

    @database_sync_to_async
    def fetch_conversation(self, conversation_id):
        """
        Fetch conversation messages from memory or DB. If missing from memory, load from DB.
        """
        if conversation_id in IN_MEMORY_CONVERSATIONS:
            return IN_MEMORY_CONVERSATIONS[conversation_id]

        try:
            chat = ChatConversation.objects.get(conversation_id=conversation_id, user=self.user)
            messages = list(chat.messages.values("sender", "content", "timestamp"))
            IN_MEMORY_CONVERSATIONS[conversation_id] = messages  # Cache it
            return messages
        except ChatConversation.DoesNotExist:
            return []

    @database_sync_to_async
    def save_conversation(self, conversation_id, new_messages):
        """
        Save messages to the DB and update in-memory cache.
        """
        chat, _ = ChatConversation.objects.get_or_create(conversation_id=conversation_id, user=self.user)

        for message in new_messages:
            ChatMessage.objects.create(
                conversation=chat,
                sender=message["role"],
                content=message["content"]
            )

        # Sync in-memory store
        IN_MEMORY_CONVERSATIONS[conversation_id] = new_messages

    @database_sync_to_async
    def delete_conversation(self, conversation_id):
        """
        Delete the conversation from DB if empty.
        """
        try:
            chat = ChatConversation.objects.get(conversation_id=conversation_id, user=self.user)
            if not chat.messages.exists():  # Check if there are any messages before deleting
                chat.delete()
                if conversation_id in IN_MEMORY_CONVERSATIONS:
                    del IN_MEMORY_CONVERSATIONS[conversation_id]  # Cleanup memory cache
        except ChatConversation.DoesNotExist:
            pass
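The consumer above is only half of the wiring: connect() reads a conversation_id kwarg from the URL route, so the ASGI routing must capture it under that exact name. A minimal routing sketch follows; the module, URL pattern, and pattern text are assumptions for illustration and are not part of this wheel.

# Hypothetical Channels routing -- not included in the package diff.
from django.urls import re_path
from swarm.consumers import DjangoChatConsumer

websocket_urlpatterns = [
    # The kwarg must be named `conversation_id`, matching
    # self.scope['url_route']['kwargs']['conversation_id'] in connect().
    re_path(r"^ws/chat/(?P<conversation_id>[\w-]+)/$", DjangoChatConsumer.as_asgi()),
]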
swarm/core.py
ADDED
@@ -0,0 +1,326 @@
import os
import json
import logging
import asyncio
import re  # Import re for tool name validation in provider
from typing import List, Dict, Optional, Union, AsyncGenerator, Any, Callable
from openai import AsyncOpenAI, OpenAIError
import uuid

from .types import Agent, LLMConfig, Response, ToolCall, ToolResult, ChatMessage, Tool
from .settings import Settings
from .extensions.config.config_loader import load_server_config, load_llm_config, get_server_params  # Import load_server_config
from .utils.redact import redact_sensitive_data
from .llm.chat_completion import get_chat_completion_message
from .extensions.mcp.mcp_tool_provider import MCPToolProvider
from .utils.context_utils import get_token_count

settings = Settings()
logger = logging.getLogger(__name__)
logger.setLevel(settings.log_level.upper())
if not logger.handlers and not logging.getLogger().handlers:
    log_handler = logging.StreamHandler()
    formatter = logging.Formatter(settings.log_format.value)
    log_handler.setFormatter(formatter)
    logger.addHandler(log_handler)

logger.debug(f"Swarm Core initialized with log level: {settings.log_level.upper()}")

# --- FIX: Define correct separator ---
MCP_TOOL_SEPARATOR = "__"

# --- Helper Function: Discover and Merge Agent Tools ---
async def discover_and_merge_agent_tools(agent: Agent, config: Dict[str, Any], timeout: int, debug: bool) -> List[Tool]:
    """
    Discovers tools from MCP servers listed in agent.mcp_servers and merges
    them with the agent's static functions. Returns a list of Tool objects.
    """
    merged_tools: Dict[str, Tool] = {}

    # 1. Process static functions
    if hasattr(agent, 'functions') and agent.functions:
        for func in agent.functions:
            if isinstance(func, Tool):
                if func.name in merged_tools: logger.warning(f"Duplicate tool name '{func.name}'. Overwriting.")
                merged_tools[func.name] = func
            elif callable(func):
                tool_name = getattr(func, '__name__', f'callable_{uuid.uuid4().hex[:6]}')
                if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", tool_name):  # Validate static tool name
                    logger.warning(f"Static function name '{tool_name}' violates OpenAI pattern. Skipping.")
                    continue
                if tool_name in merged_tools: logger.warning(f"Duplicate static tool name '{tool_name}'. Overwriting.")

                docstring = getattr(func, '__doc__', None)
                description = docstring.strip() if docstring else f"Executes the {tool_name} function."

                input_schema = {"type": "object", "properties": {}}
                merged_tools[tool_name] = Tool(name=tool_name, func=func, description=description, input_schema=input_schema)
            else: logger.warning(f"Ignoring non-callable item in agent functions list: {func}")
    logger.debug(f"Agent '{agent.name}': Processed {len(merged_tools)} static tools.")

    # 2. Discover tools from MCP servers
    if agent.mcp_servers:
        mcp_server_configs = config.get("mcpServers", {})
        discovery_tasks = []
        for server_name in agent.mcp_servers:
            if server_name not in mcp_server_configs:
                logger.warning(f"Config for MCP server '{server_name}' for agent '{agent.name}' not found. Skipping.")
                continue
            server_config = mcp_server_configs[server_name]
            if not get_server_params(server_config, server_name):
                logger.error(f"Invalid config for MCP server '{server_name}'. Cannot discover.")
                continue
            try:
                provider = MCPToolProvider.get_instance(server_name=server_name, server_config=server_config, timeout=timeout, debug=debug)
                if provider.client: discovery_tasks.append(provider.discover_tools(agent))
                else: logger.error(f"MCPClient failed init for '{server_name}'.")
            except Exception as e: logger.error(f"Error getting MCP instance for '{server_name}': {e}", exc_info=True)

        if discovery_tasks:
            logger.debug(f"Awaiting discovery from {len(discovery_tasks)} MCP providers.")
            results = await asyncio.gather(*discovery_tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, Exception): logger.error(f"MCP discovery error: {result}")
                elif isinstance(result, list):
                    for mcp_tool in result:
                        if mcp_tool.name in merged_tools: logger.warning(f"Duplicate tool name '{mcp_tool.name}' (MCP vs Static/Other MCP). Overwriting.")
                        merged_tools[mcp_tool.name] = mcp_tool  # Name already prefixed by provider
                else: logger.warning(f"Unexpected result type during MCP discovery: {type(result)}")

    final_tool_list = list(merged_tools.values())
    logger.info(f"Agent '{agent.name}': Final merged tool count: {len(final_tool_list)}")
    if debug: logger.debug(f"Agent '{agent.name}': Final tools: {[t.name for t in final_tool_list]}")
    return final_tool_list

# --- Helper Function: Format Tools for LLM ---
def format_tools_for_llm(tools: List[Tool]) -> List[Dict[str, Any]]:
    """Formats the Tool list into the structure expected by OpenAI API."""
    if not tools: return []
    formatted = []
    for tool in tools:
        parameters = tool.input_schema or {"type": "object", "properties": {}}
        if not isinstance(parameters, dict) or "type" not in parameters:
            logger.warning(f"Invalid schema for tool '{tool.name}'. Using default. Schema: {parameters}")
            parameters = {"type": "object", "properties": {}}
        elif parameters.get("type") == "object" and "properties" not in parameters:
            parameters["properties"] = {}

        # Validate tool name again before formatting
        if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", tool.name):
            logger.error(f"Tool name '{tool.name}' is invalid for OpenAI API. Skipping.")
            continue

        formatted.append({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description or f"Executes the {tool.name} tool.",
                "parameters": parameters,
            },
        })
    return formatted

# --- Swarm Class ---
class Swarm:
    def __init__(
        self,
        llm_profile: str = "default",
        config: Optional[dict] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        agents: Optional[Dict[str, Agent]] = None,
        max_context_tokens: int = 8000,
        max_context_messages: int = 50,
        max_tool_response_tokens: int = 4096,
        max_total_tool_response_tokens: int = 16384,
        max_tool_calls_per_turn: int = 10,
        tool_execution_timeout: int = 120,
        tool_discovery_timeout: int = 15,
        debug: bool = False,
    ):
        self.debug = debug or settings.debug
        if self.debug: logger.setLevel(logging.DEBUG); [h.setLevel(logging.DEBUG) for h in logging.getLogger().handlers if hasattr(h, 'setLevel')]; logger.debug("Debug mode enabled.")
        self.tool_execution_timeout = tool_execution_timeout
        self.tool_discovery_timeout = tool_discovery_timeout
        self.agents = agents or {}; logger.debug(f"Initial agents: {list(self.agents.keys())}")
        # Load config if not provided
        self.config = config if config is not None else load_server_config()
        logger.debug(f"INIT START: Received api_key arg: {'****' if api_key else 'None'}")

        llm_profile_name = os.getenv("DEFAULT_LLM", llm_profile)
        logger.debug(f"INIT: Using LLM profile name: '{llm_profile_name}'")
        try:
            loaded_config_dict = load_llm_config(self.config, llm_profile_name)
        except Exception as e: logger.critical(f"INIT: Failed to load config for profile '{llm_profile_name}': {e}", exc_info=True); raise

        final_config = loaded_config_dict.copy(); log_key_source = final_config.get("_log_key_source", "load_llm_config")
        if api_key is not None: final_config['api_key'] = api_key; log_key_source = "__init__ arg"
        if base_url is not None: final_config['base_url'] = base_url
        if model is not None: final_config['model'] = model
        self.current_llm_config = final_config; self.model = self.current_llm_config.get("model"); self.provider = self.current_llm_config.get("provider")

        self.max_context_tokens = max_context_tokens; self.max_context_messages = max_context_messages
        self.max_tool_response_tokens = max_tool_response_tokens; self.max_total_tool_response_tokens = max_total_tool_response_tokens
        self.max_tool_calls_per_turn = max_tool_calls_per_turn

        client_kwargs = {"api_key": self.current_llm_config.get("api_key"), "base_url": self.current_llm_config.get("base_url")}
        client_kwargs = {k: v for k, v in client_kwargs.items() if v is not None}
        try:
            self.client = AsyncOpenAI(**client_kwargs)
            final_api_key_used = self.current_llm_config.get("api_key")
            logger.info(f"Swarm initialized. LLM Profile: '{llm_profile_name}', Model: '{self.model}', Key Source: {log_key_source}, Key Used: {'****' if final_api_key_used else 'None'}")
            if self.debug: logger.debug(f"AsyncOpenAI client kwargs: {redact_sensitive_data(client_kwargs)}")
        except Exception as e: logger.critical(f"Failed to initialize OpenAI client: {e}", exc_info=True); raise
        self._agent_tools: Dict[str, List[Tool]] = {}

    def register_agent(self, agent: Agent):
        if agent.name in self.agents: logger.warning(f"Agent '{agent.name}' already registered. Overwriting.")
        self.agents[agent.name] = agent; logger.info(f"Agent '{agent.name}' registered.")
        if agent.name in self._agent_tools: del self._agent_tools[agent.name]
        if self.debug: logger.debug(f"Agent details: {agent}")

    async def _get_agent_tools(self, agent: Agent) -> List[Tool]:
        if agent.name not in self._agent_tools:
            logger.debug(f"Tools cache miss for agent '{agent.name}'. Discovering...")
            self._agent_tools[agent.name] = await discover_and_merge_agent_tools(agent, self.config, self.tool_discovery_timeout, self.debug)
        return self._agent_tools[agent.name]

    async def _execute_tool_call(self, agent: Agent, tool_call: ToolCall, context_variables: Dict[str, Any]) -> ToolResult:
        """Executes a single tool call, handling static and MCP tools."""
        function_name = tool_call.function.name  # This is the name LLM used (could be prefixed)
        tool_call_id = tool_call.id
        logger.info(f"Executing tool call '{function_name}' (ID: {tool_call_id}) for agent '{agent.name}'.")
        arguments = {}
        content = f"Error: Tool '{function_name}' execution failed."

        try:
            args_raw = tool_call.function.arguments
            arguments = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
            if not isinstance(arguments, dict):
                logger.error(f"Parsed tool args for {function_name} not dict: {type(arguments)}. Args: {args_raw}")
                raise ValueError("Tool arguments must be a JSON object.")
        except json.JSONDecodeError as e:
            logger.error(f"JSONDecodeError parsing args for {function_name}: {e}. Args: {args_raw}")
            content = f"Error: Invalid JSON args for '{function_name}': {e}"
        except ValueError as e:  # Catch the explicit error from above
            content = str(e)
        except Exception as e:
            logger.error(f"Error processing args for {function_name}: {e}", exc_info=True)
            content = f"Error processing args for '{function_name}'."

        tool_executed = False
        if isinstance(arguments, dict):  # Proceed only if args are valid
            agent_tools = await self._get_agent_tools(agent)
            target_tool: Optional[Tool] = next((t for t in agent_tools if t.name == function_name), None)

            if target_tool and callable(target_tool.func):
                tool_executed = True
                logger.debug(f"Found tool '{function_name}'. Executing...")
                try:
                    if asyncio.iscoroutinefunction(target_tool.func):
                        result = await asyncio.wait_for(target_tool.func(**arguments), timeout=self.tool_execution_timeout)
                    else:
                        # Consider running sync functions in threadpool executor
                        result = target_tool.func(**arguments)

                    # Process result
                    if isinstance(result, Agent):
                        logger.info(f"Handoff signal: Result is Agent '{result.name}'.")
                        # --- FIX: Use correct separator ---
                        content = f"HANDOFF{MCP_TOOL_SEPARATOR}{result.name}"
                    elif isinstance(result, (dict, list, tuple)): content = json.dumps(result, default=str)
                    elif result is None: content = "Tool executed successfully with no return value."
                    else: content = str(result)
                    logger.debug(f"Tool '{function_name}' executed. Raw result type: {type(result)}")

                except asyncio.TimeoutError:
                    logger.error(f"Timeout executing tool '{function_name}'.")
                    content = f"Error: Tool '{function_name}' timed out ({self.tool_execution_timeout}s)."
                except Exception as e:
                    logger.error(f"Error executing tool {function_name}: {e}", exc_info=True)
                    content = f"Error: Tool '{function_name}' failed: {e}"
            # else: Tool not found error handled below

        if not tool_executed and isinstance(arguments, dict):
            logger.error(f"Tool '{function_name}' not found for agent '{agent.name}'. Available: {[t.name for t in await self._get_agent_tools(agent)]}")
            content = f"Error: Tool '{function_name}' not available for agent '{agent.name}'."

        # Truncation
        # --- FIX: Use correct separator ---
        if isinstance(content, str) and not content.startswith(f"HANDOFF{MCP_TOOL_SEPARATOR}"):
            token_count = get_token_count(content, self.current_llm_config.get("model"))
            if token_count > self.max_tool_response_tokens:
                logger.warning(f"Truncating tool response '{function_name}'. Size: {token_count} > Limit: {self.max_tool_response_tokens}")
                content = content[:self.max_tool_response_tokens * 4] + "... (truncated)"

        return ToolResult(tool_call_id=tool_call_id, name=function_name, content=content)

    async def _run_non_streaming(self, agent: Agent, messages: List[Dict[str, Any]], context_variables: Optional[Dict[str, Any]] = None, max_turns: int = 10, debug: bool = False) -> Response:
        current_agent = agent; history = list(messages); context_vars = context_variables.copy() if context_variables else {}; turn = 0
        while turn < max_turns:
            turn += 1; logger.debug(f"Turn {turn} starting with agent '{current_agent.name}'.")
            agent_tools = await self._get_agent_tools(current_agent); formatted_tools = format_tools_for_llm(agent_tools)
            if debug and formatted_tools: logger.debug(f"Tools for '{current_agent.name}': {[t['function']['name'] for t in formatted_tools]}")
            try:
                ai_message_dict = await get_chat_completion_message(client=self.client, agent=current_agent, history=history, context_variables=context_vars, current_llm_config=self.current_llm_config, max_context_tokens=self.max_context_tokens, max_context_messages=self.max_context_messages, tools=formatted_tools or None, tool_choice="auto" if formatted_tools else None, stream=False, debug=debug)
                ai_message_dict["sender"] = current_agent.name; history.append(ai_message_dict)
                tool_calls_raw = ai_message_dict.get("tool_calls")
                if tool_calls_raw:
                    if not isinstance(tool_calls_raw, list): tool_calls_raw = []
                    logger.info(f"Agent '{current_agent.name}' requested {len(tool_calls_raw)} tool calls.")
                    tool_calls_to_execute = []
                    for tc_raw in tool_calls_raw[:self.max_tool_calls_per_turn]:
                        try:
                            if isinstance(tc_raw, dict) and 'function' in tc_raw and isinstance(tc_raw['function'], dict) and 'name' in tc_raw['function'] and 'arguments' in tc_raw['function']: tool_calls_to_execute.append(ToolCall(**tc_raw))
                            else: logger.warning(f"Skipping malformed tool call: {tc_raw}")
                        except Exception as p_err: logger.warning(f"Skipping tool call validation error: {p_err}. Raw: {tc_raw}")
                    if len(tool_calls_raw) > self.max_tool_calls_per_turn: logger.warning(f"Clamping tool calls to {self.max_tool_calls_per_turn}.")

                    tool_tasks = [self._execute_tool_call(current_agent, tc, context_vars) for tc in tool_calls_to_execute]
                    tool_results: List[ToolResult] = await asyncio.gather(*tool_tasks)
                    next_agent_name_from_handoff = None; total_tool_response_tokens = 0
                    for result in tool_results:
                        history.append(result.model_dump(exclude_none=True)); content = result.content
                        if isinstance(content, str):
                            # --- FIX: Use correct separator ---
                            if content.startswith(f"HANDOFF{MCP_TOOL_SEPARATOR}"):
                                parts = content.split(MCP_TOOL_SEPARATOR, 1); potential_next_agent = parts[1] if len(parts) > 1 else None
                                if potential_next_agent and potential_next_agent in self.agents:
                                    if not next_agent_name_from_handoff: next_agent_name_from_handoff = potential_next_agent; logger.info(f"Handoff to '{next_agent_name_from_handoff}' confirmed.")
                                    elif next_agent_name_from_handoff != potential_next_agent: logger.warning(f"Multiple handoffs requested. Using first '{next_agent_name_from_handoff}'.")
                                else: logger.warning(f"Handoff to unknown agent '{potential_next_agent}'. Ignoring.")
                            else: total_tool_response_tokens += get_token_count(content, self.current_llm_config.get("model"))
                    if total_tool_response_tokens > self.max_total_tool_response_tokens: logger.error(f"Total tool tokens ({total_tool_response_tokens}) exceeded limit. Ending run."); history.append({"role": "assistant", "sender": "System", "content": "[System Error: Tool responses token limit exceeded.]"}); break
                    if next_agent_name_from_handoff: current_agent = self.agents[next_agent_name_from_handoff]; context_vars["active_agent_name"] = current_agent.name; logger.debug(f"Activating agent '{current_agent.name}'."); continue
                    else: continue
                else: break  # No tool calls, end interaction
            except OpenAIError as e: logger.error(f"API error turn {turn} for '{current_agent.name}': {e}", exc_info=True); history.append({"role": "assistant", "sender": "System", "content": f"[System Error: API call failed]"}); break
            except Exception as e: logger.error(f"Unexpected error turn {turn} for '{current_agent.name}': {e}", exc_info=True); history.append({"role": "assistant", "sender": "System", "content": f"[System Error: Unexpected error]"}); break
        if turn >= max_turns: logger.warning(f"Reached max turns ({max_turns}).")
        logger.debug(f"Non-streaming run completed. Turns={turn}, History Messages={len(history)}.")
        final_messages_raw = history[len(messages):]; final_messages_typed = [ChatMessage(**msg) for msg in final_messages_raw if isinstance(msg, dict)]
        response_id = f"response-{uuid.uuid4()}"
        return Response(id=response_id, messages=final_messages_typed, agent=current_agent, context_variables=context_vars)

    async def _run_streaming(self, agent: Agent, messages: List[Dict[str, Any]], context_variables: Optional[Dict[str, Any]] = None, max_turns: int = 10, debug: bool = False) -> AsyncGenerator[Dict[str, Any], None]:
        current_agent = agent; history = list(messages); context_vars = context_variables.copy() if context_variables else {}; logger.debug(f"Streaming run starting for '{current_agent.name}'. (Tool exec/handoff N/A)")
        agent_tools = await self._get_agent_tools(current_agent); formatted_tools = format_tools_for_llm(agent_tools)
        if debug and formatted_tools: logger.debug(f"Tools for '{current_agent.name}' (streaming): {[t['function']['name'] for t in formatted_tools]}")
        try:
            stream_generator = get_chat_completion_message(client=self.client, agent=current_agent, history=history, context_variables=context_vars, current_llm_config=self.current_llm_config, max_context_tokens=self.max_context_tokens, max_context_messages=self.max_context_messages, tools=formatted_tools or None, tool_choice="auto" if formatted_tools else None, stream=True, debug=debug)
            async for chunk in stream_generator: yield chunk
            logger.warning("Tool calls/handoffs not processed in streaming.")
        except OpenAIError as e: logger.error(f"API error stream for '{current_agent.name}': {e}", exc_info=True); yield {"error": f"API call failed: {str(e)}"}
        except Exception as e: logger.error(f"Error stream for '{current_agent.name}': {e}", exc_info=True); yield {"error": f"Unexpected error: {str(e)}"}
        logger.debug(f"Streaming run finished for '{current_agent.name}'.")

    async def run(self, agent: Agent, messages: List[Dict[str, Any]], context_variables: Optional[Dict[str, Any]] = None, max_turns: int = 10, stream: bool = False, debug: bool = False) -> Union[Response, AsyncGenerator[Dict[str, Any], None]]:
        effective_debug = debug or self.debug
        if effective_debug != logger.isEnabledFor(logging.DEBUG):
            new_level = logging.DEBUG if effective_debug else settings.log_level.upper(); logger.setLevel(new_level); [h.setLevel(new_level) for h in logger.handlers]; logger.debug(f"Log level set to {new_level}.")
        if not agent: raise ValueError("Agent cannot be None")
        if not isinstance(messages, list): raise TypeError("Messages must be a list")
        logger.info(f"Starting {'STREAMING' if stream else 'NON-STREAMING'} run with agent '{agent.name}'")
        if stream: return self._run_streaming(agent, messages, context_variables, max_turns, effective_debug)
        else: return await self._run_non_streaming(agent, messages, context_variables, max_turns, effective_debug)
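For orientation, a minimal sketch of driving the Swarm class above follows. It is illustrative only and not part of the wheel: the Agent constructor fields (name, instructions, functions) are assumed from how core.py reads them, and it presumes a working "default" LLM profile in the loaded config.

# Hypothetical usage sketch -- assumptions noted above.
import asyncio
from swarm.core import Swarm
from swarm.types import Agent

def get_time() -> str:
    """Return the current UTC time (exposed as a static tool via agent.functions)."""
    from datetime import datetime, timezone
    return datetime.now(timezone.utc).isoformat()

async def main():
    agent = Agent(name="Assistant", instructions="Be concise.", functions=[get_time])
    swarm = Swarm(llm_profile="default")
    swarm.register_agent(agent)
    # Non-streaming: run() returns a Response whose .messages holds the new turns.
    response = await swarm.run(agent, messages=[{"role": "user", "content": "What time is it?"}])
    print(response.messages[-1].content)

asyncio.run(main())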
swarm/extensions/__init__.py
ADDED
@@ -0,0 +1 @@
# This is the __init__.py for the 'extensions' package.
|
@@ -0,0 +1,36 @@
|
|
1
|
+
"""
|
2
|
+
Blueprint discovery and management utilities.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from .blueprint_base import BlueprintBase
|
6
|
+
from .blueprint_discovery import discover_blueprints
|
7
|
+
from .blueprint_utils import filter_blueprints
|
8
|
+
|
9
|
+
# Re-export the necessary message utilities from their new locations
|
10
|
+
# Note: The specific truncation functions like truncate_preserve_pairs might have been
|
11
|
+
# consolidated into truncate_message_history. Adjust if needed.
|
12
|
+
try:
|
13
|
+
from swarm.utils.message_sequence import repair_message_payload, validate_message_sequence
|
14
|
+
from swarm.utils.context_utils import truncate_message_history
|
15
|
+
# If specific old truncation functions are truly needed, they'd have to be
|
16
|
+
# re-implemented or their callers refactored to use truncate_message_history.
|
17
|
+
# Assuming truncate_message_history is the intended replacement for now.
|
18
|
+
# Define aliases if old names are required by downstream code:
|
19
|
+
# truncate_preserve_pairs = truncate_message_history # Example if needed
|
20
|
+
except ImportError as e:
|
21
|
+
# Log an error or warning if imports fail, helpful for debugging setup issues
|
22
|
+
import logging
|
23
|
+
logging.getLogger(__name__).error(f"Failed to import core message utilities: {e}")
|
24
|
+
# Define dummy functions or raise error if critical
|
25
|
+
def repair_message_payload(m, **kwargs): return m
|
26
|
+
def validate_message_sequence(m): return m
|
27
|
+
def truncate_message_history(m, *args, **kwargs): return m
|
28
|
+
|
29
|
+
__all__ = [
|
30
|
+
"BlueprintBase",
|
31
|
+
"discover_blueprints",
|
32
|
+
"filter_blueprints",
|
33
|
+
"repair_message_payload",
|
34
|
+
"validate_message_sequence",
|
35
|
+
"truncate_message_history",
|
36
|
+
]
|
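Downstream code is expected to import through this package rather than reaching into swarm.utils directly; a short, illustrative sketch of that import surface:

# If swarm.utils is unimportable, the last three names silently degrade to the
# pass-through fallbacks defined in the except branch above.
from swarm.extensions.blueprint import (
    BlueprintBase,
    discover_blueprints,
    filter_blueprints,
    repair_message_payload,
    validate_message_sequence,
    truncate_message_history,
)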
swarm/extensions/blueprint/agent_utils.py
ADDED
@@ -0,0 +1,45 @@
"""
Agent utility functions for Swarm blueprints
"""

import logging
import os
from typing import Dict, List, Any, Callable, Optional
import asyncio
from swarm.types import Agent

logger = logging.getLogger(__name__)

def get_agent_name(agent: Agent) -> str:
    """Extract an agent's name, defaulting to its class name if not explicitly set."""
    return getattr(agent, 'name', agent.__class__.__name__)

async def discover_tools_for_agent(agent: Agent, blueprint: Any) -> List[Any]:
    """Asynchronously discover tools available for an agent within a blueprint."""
    return getattr(blueprint, '_discovered_tools', {}).get(get_agent_name(agent), [])

async def discover_resources_for_agent(agent: Agent, blueprint: Any) -> List[Any]:
    """Asynchronously discover resources available for an agent within a blueprint."""
    return getattr(blueprint, '_discovered_resources', {}).get(get_agent_name(agent), [])

def initialize_agents(blueprint: Any) -> None:
    """Initialize agents defined in the blueprint's create_agents method."""
    if not callable(getattr(blueprint, 'create_agents', None)):
        logger.error(f"Blueprint {blueprint.__class__.__name__} has no callable create_agents method.")
        return

    agents = blueprint.create_agents()
    if not isinstance(agents, dict):
        logger.error(f"Blueprint {blueprint.__class__.__name__}.create_agents must return a dict, got {type(agents)}")
        return

    if hasattr(blueprint, 'swarm') and hasattr(blueprint.swarm, 'agents'):
        blueprint.swarm.agents.update(agents)
    else:
        logger.error("Blueprint or its swarm instance lacks an 'agents' attribute to update.")
        return

    if not blueprint.starting_agent and agents:
        first_agent_name = next(iter(agents.keys()))
        blueprint.starting_agent = agents[first_agent_name]
        logger.debug(f"Set default starting agent: {first_agent_name}")