open_swarm-0.1.1743070217-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- open_swarm-0.1.1743070217.dist-info/METADATA +258 -0
- open_swarm-0.1.1743070217.dist-info/RECORD +89 -0
- open_swarm-0.1.1743070217.dist-info/WHEEL +5 -0
- open_swarm-0.1.1743070217.dist-info/entry_points.txt +3 -0
- open_swarm-0.1.1743070217.dist-info/licenses/LICENSE +21 -0
- open_swarm-0.1.1743070217.dist-info/top_level.txt +1 -0
- swarm/__init__.py +3 -0
- swarm/agent/__init__.py +7 -0
- swarm/agent/agent.py +49 -0
- swarm/apps.py +53 -0
- swarm/auth.py +56 -0
- swarm/consumers.py +141 -0
- swarm/core.py +326 -0
- swarm/extensions/__init__.py +1 -0
- swarm/extensions/blueprint/__init__.py +36 -0
- swarm/extensions/blueprint/agent_utils.py +45 -0
- swarm/extensions/blueprint/blueprint_base.py +562 -0
- swarm/extensions/blueprint/blueprint_discovery.py +112 -0
- swarm/extensions/blueprint/blueprint_utils.py +17 -0
- swarm/extensions/blueprint/common_utils.py +12 -0
- swarm/extensions/blueprint/django_utils.py +203 -0
- swarm/extensions/blueprint/interactive_mode.py +102 -0
- swarm/extensions/blueprint/modes/rest_mode.py +37 -0
- swarm/extensions/blueprint/output_utils.py +95 -0
- swarm/extensions/blueprint/spinner.py +91 -0
- swarm/extensions/cli/__init__.py +0 -0
- swarm/extensions/cli/blueprint_runner.py +251 -0
- swarm/extensions/cli/cli_args.py +88 -0
- swarm/extensions/cli/commands/__init__.py +0 -0
- swarm/extensions/cli/commands/blueprint_management.py +31 -0
- swarm/extensions/cli/commands/config_management.py +15 -0
- swarm/extensions/cli/commands/edit_config.py +77 -0
- swarm/extensions/cli/commands/list_blueprints.py +22 -0
- swarm/extensions/cli/commands/validate_env.py +57 -0
- swarm/extensions/cli/commands/validate_envvars.py +39 -0
- swarm/extensions/cli/interactive_shell.py +41 -0
- swarm/extensions/cli/main.py +36 -0
- swarm/extensions/cli/selection.py +43 -0
- swarm/extensions/cli/utils/discover_commands.py +32 -0
- swarm/extensions/cli/utils/env_setup.py +15 -0
- swarm/extensions/cli/utils.py +105 -0
- swarm/extensions/config/__init__.py +6 -0
- swarm/extensions/config/config_loader.py +208 -0
- swarm/extensions/config/config_manager.py +258 -0
- swarm/extensions/config/server_config.py +49 -0
- swarm/extensions/config/setup_wizard.py +103 -0
- swarm/extensions/config/utils/__init__.py +0 -0
- swarm/extensions/config/utils/logger.py +36 -0
- swarm/extensions/launchers/__init__.py +1 -0
- swarm/extensions/launchers/build_launchers.py +14 -0
- swarm/extensions/launchers/build_swarm_wrapper.py +12 -0
- swarm/extensions/launchers/swarm_api.py +68 -0
- swarm/extensions/launchers/swarm_cli.py +304 -0
- swarm/extensions/launchers/swarm_wrapper.py +29 -0
- swarm/extensions/mcp/__init__.py +1 -0
- swarm/extensions/mcp/cache_utils.py +36 -0
- swarm/extensions/mcp/mcp_client.py +341 -0
- swarm/extensions/mcp/mcp_constants.py +7 -0
- swarm/extensions/mcp/mcp_tool_provider.py +110 -0
- swarm/llm/chat_completion.py +195 -0
- swarm/messages.py +132 -0
- swarm/migrations/0010_initial_chat_models.py +51 -0
- swarm/migrations/__init__.py +0 -0
- swarm/models.py +45 -0
- swarm/repl/__init__.py +1 -0
- swarm/repl/repl.py +87 -0
- swarm/serializers.py +12 -0
- swarm/settings.py +189 -0
- swarm/tool_executor.py +239 -0
- swarm/types.py +126 -0
- swarm/urls.py +89 -0
- swarm/util.py +124 -0
- swarm/utils/color_utils.py +40 -0
- swarm/utils/context_utils.py +272 -0
- swarm/utils/general_utils.py +162 -0
- swarm/utils/logger.py +61 -0
- swarm/utils/logger_setup.py +25 -0
- swarm/utils/message_sequence.py +173 -0
- swarm/utils/message_utils.py +95 -0
- swarm/utils/redact.py +68 -0
- swarm/views/__init__.py +41 -0
- swarm/views/api_views.py +46 -0
- swarm/views/chat_views.py +76 -0
- swarm/views/core_views.py +118 -0
- swarm/views/message_views.py +40 -0
- swarm/views/model_views.py +135 -0
- swarm/views/utils.py +457 -0
- swarm/views/web_views.py +149 -0
- swarm/wsgi.py +16 -0
--- /dev/null
+++ b/swarm/extensions/mcp/mcp_client.py
@@ -0,0 +1,341 @@
+"""
+MCP Client Module
+
+Manages connections and interactions with MCP servers using the MCP Python SDK.
+Redirects MCP server stderr to log files unless debug mode is enabled.
+"""
+
+import asyncio
+import logging
+import os
+from typing import Any, Dict, List, Callable
+from contextlib import contextmanager
+import sys
+import json # Added for result parsing
+
+# Attempt to import mcp types carefully
+try:
+    from mcp import ClientSession, StdioServerParameters # type: ignore
+    from mcp.client.stdio import stdio_client # type: ignore
+    MCP_AVAILABLE = True
+except ImportError:
+    MCP_AVAILABLE = False
+    # Define dummy classes if mcp is not installed
+    class ClientSession: pass
+    class StdioServerParameters: pass
+    def stdio_client(*args, **kwargs): raise ImportError("mcp library not installed")
+
+from ...types import Tool # Use Tool from swarm.types
+from .cache_utils import get_cache
+from ...settings import Settings # Import Swarm settings
+
+# Use Swarm's settings for logging configuration
+swarm_settings = Settings()
+logger = logging.getLogger(__name__)
+logger.setLevel(swarm_settings.log_level.upper()) # Use log level from settings
+# Ensure handler is added only if needed, respecting potential global config
+if not logger.handlers and not logging.getLogger('swarm').handlers:
+    handler = logging.StreamHandler()
+    # Use log format from settings
+    formatter = logging.Formatter(swarm_settings.log_format.value)
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+class MCPClient:
+    """
+    Manages connections and interactions with MCP servers using the MCP Python SDK.
+    """
+
+    def __init__(self, server_config: Dict[str, Any], timeout: int = 15, debug: bool = False):
+        """
+        Initialize the MCPClient with server configuration.
+
+        Args:
+            server_config (dict): Configuration dictionary for the MCP server.
+            timeout (int): Timeout for operations in seconds.
+            debug (bool): If True, MCP server stderr goes to console; otherwise, suppressed.
+        """
+        if not MCP_AVAILABLE:
+            raise ImportError("The 'mcp-client' library is required for MCP functionality but is not installed.")
+
+        self.command = server_config.get("command", "npx")
+        self.args = server_config.get("args", [])
+        self.env = {**os.environ.copy(), **server_config.get("env", {})}
+        self.timeout = timeout
+        self.debug = debug or swarm_settings.debug # Use instance debug or global debug
+        self._tool_cache: Dict[str, Tool] = {}
+        self.cache = get_cache()
+
+        # Validate command and args types
+        if not isinstance(self.command, str):
+            raise TypeError(f"MCP server command must be a string, got {type(self.command)}")
+        if not isinstance(self.args, list) or not all(isinstance(a, str) for a in self.args):
+            raise TypeError(f"MCP server args must be a list of strings, got {self.args}")
+
+
+        logger.info(f"Initialized MCPClient with command={self.command}, args={self.args}, debug={self.debug}")
+
+    @contextmanager
+    def _redirect_stderr(self):
+        """Redirects stderr to /dev/null if not in debug mode."""
+        if not self.debug:
+            original_stderr = sys.stderr
+            devnull = None
+            try:
+                devnull = open(os.devnull, "w")
+                sys.stderr = devnull
+                yield
+            except Exception:
+                # Restore stderr even if there was an error opening /dev/null or during yield
+                if devnull: devnull.close()
+                sys.stderr = original_stderr
+                raise # Re-raise the exception
+            finally:
+                if devnull: devnull.close()
+                sys.stderr = original_stderr
+        else:
+            # If debug is True, don't redirect
+            yield
+
+    async def list_tools(self) -> List[Tool]:
+        """
+        Discover tools from the MCP server and cache their schemas.
+
+        Returns:
+            List[Tool]: A list of discovered tools with schemas.
+        """
+        logger.debug(f"Entering list_tools for command={self.command}, args={self.args}")
+
+        # Attempt to retrieve tools from cache
+        # Create a more robust cache key
+        args_string = json.dumps(self.args, sort_keys=True) # Serialize args consistently
+        cache_key = f"mcp_tools_{self.command}_{args_string}"
+        cached_tools_data = self.cache.get(cache_key)
+
+        if cached_tools_data:
+            logger.debug("Retrieved tools data from cache")
+            tools = []
+            for tool_data in cached_tools_data:
+                tool_name = tool_data["name"]
+                # Create Tool instance, ensuring func is a callable wrapper
+                tool = Tool(
+                    name=tool_name,
+                    description=tool_data["description"],
+                    input_schema=tool_data.get("input_schema", {"type": "object", "properties": {}}),
+                    func=self._create_tool_callable(tool_name), # Use the factory method
+                )
+                self._tool_cache[tool_name] = tool # Store in instance cache too
+                tools.append(tool)
+            logger.debug(f"Returning {len(tools)} cached tools")
+            return tools
+
+        # If not in cache, discover from server
+        server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env)
+        logger.debug("Opening stdio_client connection")
+        try:
+            async with stdio_client(server_params) as (read, write):
+                logger.debug("Opening ClientSession")
+                async with ClientSession(read, write) as session:
+                    logger.info("Initializing session for tool discovery")
+                    await asyncio.wait_for(session.initialize(), timeout=self.timeout)
+                    logger.info("Requesting tool list from MCP server...")
+                    tools_response = await asyncio.wait_for(session.list_tools(), timeout=self.timeout)
+                    logger.debug(f"Tool list received: {tools_response}")
+
+                    if not hasattr(tools_response, 'tools') or not isinstance(tools_response.tools, list):
+                        logger.error(f"Invalid tool list response from MCP server: {tools_response}")
+                        return []
+
+                    serialized_tools = []
+                    tools = []
+                    for tool_proto in tools_response.tools:
+                        if not hasattr(tool_proto, 'name') or not tool_proto.name:
+                            logger.warning(f"Skipping tool with missing name in response: {tool_proto}")
+                            continue
+
+                        # Ensure inputSchema exists and is a dict, default if not
+                        input_schema = getattr(tool_proto, 'inputSchema', None)
+                        if not isinstance(input_schema, dict):
+                            input_schema = {"type": "object", "properties": {}}
+
+                        description = getattr(tool_proto, 'description', "") or "" # Ensure description is string
+
+                        serialized_tool_data = {
+                            'name': tool_proto.name,
+                            'description': description,
+                            'input_schema': input_schema,
+                        }
+                        serialized_tools.append(serialized_tool_data)
+
+                        # Create Tool instance for returning
+                        discovered_tool = Tool(
+                            name=tool_proto.name,
+                            description=description,
+                            input_schema=input_schema,
+                            func=self._create_tool_callable(tool_proto.name),
+                        )
+                        self._tool_cache[tool_proto.name] = discovered_tool # Cache instance
+                        tools.append(discovered_tool)
+                        logger.debug(f"Discovered tool: {tool_proto.name} with schema: {input_schema}")
+
+                    # Cache the serialized data
+                    self.cache.set(cache_key, serialized_tools, 3600)
+                    logger.debug(f"Cached {len(serialized_tools)} tools.")
+
+                    logger.debug(f"Returning {len(tools)} tools from MCP server")
+                    return tools
+
+        except asyncio.TimeoutError:
+            logger.error(f"Timeout after {self.timeout}s waiting for tool list")
+            raise RuntimeError("Tool list request timed out")
+        except Exception as e:
+            logger.error(f"Error listing tools: {e}", exc_info=True)
+            raise RuntimeError(f"Failed to list tools: {e}") from e
+
+    async def _do_list_resources(self) -> Any:
+        """Internal method to list resources with timeout."""
+        server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env)
+        logger.debug("Opening stdio_client connection for resources")
+        try:
+            async with stdio_client(server_params) as (read, write):
+                logger.debug("Opening ClientSession for resources")
+                async with ClientSession(read, write) as session:
+                    with self._redirect_stderr(): # Suppress stderr if not debugging
+                        logger.debug("Initializing session before listing resources")
+                        await asyncio.wait_for(session.initialize(), timeout=self.timeout)
+                        logger.info("Requesting resource list from MCP server...")
+                        resources_response = await asyncio.wait_for(session.list_resources(), timeout=self.timeout)
+                        logger.debug("Resource list received from MCP server")
+                        return resources_response
+        except asyncio.TimeoutError:
+            logger.error(f"Timeout listing resources after {self.timeout}s")
+            raise RuntimeError("Resource list request timed out")
+        except Exception as e:
+            logger.error(f"Error listing resources: {e}", exc_info=True)
+            raise RuntimeError(f"Failed to list resources: {e}") from e
+
+    def _create_tool_callable(self, tool_name: str) -> Callable[..., Any]:
+        """
+        Dynamically create an async callable function for the specified tool.
+        This callable will establish a connection and execute the tool on demand.
+        """
+        async def dynamic_tool_func(**kwargs) -> Any:
+            logger.debug(f"Creating tool callable for '{tool_name}'")
+            server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env)
+            try:
+                async with stdio_client(server_params) as (read, write):
+                    async with ClientSession(read, write) as session:
+                        # Initialize session first
+                        logger.debug(f"Initializing session for tool '{tool_name}'")
+                        await asyncio.wait_for(session.initialize(), timeout=self.timeout)
+
+                        # Validate input if schema is available in instance cache
+                        if tool_name in self._tool_cache:
+                            tool = self._tool_cache[tool_name]
+                            self._validate_input_schema(tool.input_schema, kwargs)
+                        else:
+                            logger.warning(f"Schema for tool '{tool_name}' not found in cache for validation.")
+
+                        logger.info(f"Calling tool '{tool_name}' with arguments: {kwargs}")
+                        # Execute the tool call
+                        result_proto = await asyncio.wait_for(session.call_tool(tool_name, kwargs), timeout=self.timeout)
+
+                        # Process result (assuming result_proto has a 'result' attribute)
+                        result_data = getattr(result_proto, 'result', None)
+                        if result_data is None:
+                            logger.warning(f"Tool '{tool_name}' executed but returned no result data.")
+                            return None # Or raise error?
+
+                        # Attempt to parse if it looks like JSON, otherwise return as is
+                        if isinstance(result_data, str):
+                            try:
+                                parsed_result = json.loads(result_data)
+                                logger.info(f"Tool '{tool_name}' executed successfully (result parsed as JSON).")
+                                return parsed_result
+                            except json.JSONDecodeError:
+                                logger.info(f"Tool '{tool_name}' executed successfully (result returned as string).")
+                                return result_data # Return raw string if not JSON
+                        else:
+                            logger.info(f"Tool '{tool_name}' executed successfully (result type: {type(result_data)}).")
+                            return result_data # Return non-string result directly
+
+            except asyncio.TimeoutError:
+                logger.error(f"Timeout after {self.timeout}s executing tool '{tool_name}'")
+                raise RuntimeError(f"Tool '{tool_name}' execution timed out")
+            except Exception as e:
+                logger.error(f"Failed to execute tool '{tool_name}': {e}", exc_info=True)
+                raise RuntimeError(f"Tool execution failed: {e}") from e
+
+        return dynamic_tool_func
+
+    def _validate_input_schema(self, schema: Dict[str, Any], kwargs: Dict[str, Any]):
+        """
+        Validate the provided arguments against the input schema.
+        """
+        # Ensure schema is a dictionary, default to no-op if not
+        if not isinstance(schema, dict):
+            logger.warning(f"Invalid schema format for validation: {type(schema)}. Skipping.")
+            return
+
+        required_params = schema.get("required", [])
+        # Ensure required_params is a list
+        if not isinstance(required_params, list):
+            logger.warning(f"Invalid 'required' list in schema: {type(required_params)}. Skipping requirement check.")
+            required_params = []
+
+        for param in required_params:
+            if param not in kwargs:
+                raise ValueError(f"Missing required parameter: '{param}'")
+
+        # Optional: Add type validation based on schema['properties'][param]['type']
+        properties = schema.get("properties", {})
+        if isinstance(properties, dict):
+            for key, value in kwargs.items():
+                if key in properties:
+                    expected_type = properties[key].get("type")
+                    # Basic type mapping (add more as needed)
+                    type_map = {"string": str, "integer": int, "number": (int, float), "boolean": bool, "array": list, "object": dict}
+                    if expected_type in type_map:
+                        if not isinstance(value, type_map[expected_type]):
+                            logger.warning(f"Type mismatch for parameter '{key}'. Expected '{expected_type}', got '{type(value).__name__}'. Attempting to proceed.")
+                            # Allow proceeding but log warning, or raise ValueError for strict validation
+
+        logger.debug(f"Validated input against schema: {schema} with arguments: {kwargs}")
+
+    async def list_resources(self) -> Any:
+        """
+        Discover resources from the MCP server using the internal method with enforced timeout.
+        """
+        return await self._do_list_resources() # Timeout handled in _do_list_resources
+
+    async def get_resource(self, resource_uri: str) -> Any:
+        """
+        Retrieve a specific resource from the MCP server.
+
+        Args:
+            resource_uri (str): The URI of the resource to retrieve.
+
+        Returns:
+            Any: The resource retrieval response.
+        """
+        server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env)
+        logger.debug("Opening stdio_client connection for resource retrieval")
+        try:
+            async with stdio_client(server_params) as (read, write):
+                logger.debug("Opening ClientSession for resource retrieval")
+                async with ClientSession(read, write) as session:
+                    with self._redirect_stderr(): # Suppress stderr if not debugging
+                        logger.debug(f"Initializing session for resource retrieval of {resource_uri}")
+                        await asyncio.wait_for(session.initialize(), timeout=self.timeout)
+                        logger.info(f"Retrieving resource '{resource_uri}' from MCP server")
+                        response = await asyncio.wait_for(session.read_resource(resource_uri), timeout=self.timeout)
+                        logger.info(f"Resource '{resource_uri}' retrieved successfully")
+                        # Process response if needed (e.g., getattr(response, 'content', None))
+                        return response
+        except asyncio.TimeoutError:
+            logger.error(f"Timeout retrieving resource '{resource_uri}' after {self.timeout}s")
+            raise RuntimeError(f"Resource '{resource_uri}' retrieval timed out")
+        except Exception as e:
+            logger.error(f"Failed to retrieve resource '{resource_uri}': {e}", exc_info=True)
+            raise RuntimeError(f"Resource retrieval failed: {e}") from e
+
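For orientation, a minimal sketch of driving MCPClient directly. The server command and arguments are illustrative placeholders (the npm filesystem server shown is one public MCP server, not something shipped in this package), and the tool keyword arguments are hypothetical:

import asyncio
from swarm.extensions.mcp.mcp_client import MCPClient

# Illustrative config; any MCP stdio server entry with this shape should work.
server_config = {
    "command": "npx",
    "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
    "env": {},
}

async def main():
    client = MCPClient(server_config=server_config, timeout=15, debug=True)
    tools = await client.list_tools()  # spawns the server and discovers its tools
    for tool in tools:
        print(tool.name, "-", tool.description)
    if tools:
        # Each Tool.func is an async wrapper that reconnects and calls the
        # tool on demand; the kwargs must satisfy the tool's input_schema
        # (the "path" argument here is hypothetical).
        result = await tools[0].func(path="/tmp")
        print(result)

asyncio.run(main())

Note that each Tool.func opens a fresh stdio connection per call rather than holding one open, which keeps the callable self-contained at the cost of per-call startup latency.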
--- /dev/null
+++ b/swarm/extensions/mcp/mcp_tool_provider.py
@@ -0,0 +1,110 @@
+"""
+MCPToolProvider Module for Open-Swarm
+
+This module is responsible for discovering tools from MCP (Model Context Protocol) servers
+and integrating them into the Open-Swarm framework as `Tool` instances.
+"""
+
+import logging
+import json
+import re # Standard library for regular expressions
+from typing import List, Dict, Any
+
+from ...settings import Settings # Use Swarm settings
+from ...types import Tool, Agent
+from .mcp_client import MCPClient
+from .cache_utils import get_cache
+
+# Use Swarm's settings for logging configuration
+swarm_settings = Settings()
+logger = logging.getLogger(__name__)
+logger.setLevel(swarm_settings.log_level.upper())
+# Ensure handler is added only if needed
+if not logger.handlers and not logging.getLogger('swarm').handlers:
+    handler = logging.StreamHandler()
+    formatter = logging.Formatter(swarm_settings.log_format.value)
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+
+class MCPToolProvider:
+    """
+    MCPToolProvider discovers tools from an MCP server and converts them into `Tool` instances.
+    Uses caching to avoid repeated discovery.
+    """
+    _instances: Dict[str, "MCPToolProvider"] = {}
+
+    @classmethod
+    def get_instance(cls, server_name: str, server_config: Dict[str, Any], timeout: int = 15, debug: bool = False) -> "MCPToolProvider":
+        """Get or create an instance for the given server name."""
+        config_key = json.dumps(server_config, sort_keys=True)
+        instance_key = f"{server_name}_{config_key}_{timeout}_{debug}"
+
+        if instance_key not in cls._instances:
+            logger.debug(f"Creating new MCPToolProvider instance for key: {instance_key}")
+            cls._instances[instance_key] = cls(server_name, server_config, timeout, debug)
+        else:
+            logger.debug(f"Reusing existing MCPToolProvider instance for key: {instance_key}")
+        return cls._instances[instance_key]
+
+    def __init__(self, server_name: str, server_config: Dict[str, Any], timeout: int = 15, debug: bool = False):
+        """
+        Initialize an MCPToolProvider instance. Use get_instance() for shared instances.
+        """
+        self.server_name = server_name
+        effective_debug = debug or swarm_settings.debug
+        try:
+            self.client = MCPClient(server_config=server_config, timeout=timeout, debug=effective_debug)
+        except ImportError as e:
+            logger.error(f"Failed to initialize MCPClient for '{server_name}': {e}. MCP features will be unavailable.")
+            self.client = None
+        except Exception as e:
+            logger.error(f"Error initializing MCPClient for '{server_name}': {e}", exc_info=True)
+            self.client = None
+
+        self.cache = get_cache()
+        logger.debug(f"Initialized MCPToolProvider for server '{self.server_name}' with timeout {timeout}s.")
+
+    async def discover_tools(self, agent: Agent) -> List[Tool]:
+        """
+        Discover tools from the MCP server using the MCPClient.
+
+        Args:
+            agent (Agent): The agent for which tools are being discovered.
+
+        Returns:
+            List[Tool]: A list of discovered `Tool` instances with prefixed names.
+        """
+        if not self.client:
+            logger.warning(f"MCPClient for '{self.server_name}' not initialized. Cannot discover tools.")
+            return []
+
+        logger.debug(f"Starting tool discovery via MCPClient for server '{self.server_name}'.")
+        try:
+            tools = await self.client.list_tools()
+            logger.debug(f"Discovered {len(tools)} tools from MCP server '{self.server_name}'.")
+
+            separator = "__"
+            prefixed_tools = []
+            for tool in tools:
+                prefixed_name = f"{self.server_name}{separator}{tool.name}"
+                # Validate prefixed name against OpenAI pattern
+                if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", prefixed_name):
+                    logger.warning(f"Generated MCP tool name '{prefixed_name}' might violate OpenAI pattern. Skipping.")
+                    continue
+
+                prefixed_tool = Tool(
+                    name=prefixed_name,
+                    description=tool.description,
+                    input_schema=tool.input_schema,
+                    func=tool.func # Callable already targets this client/tool
+                )
+                prefixed_tools.append(prefixed_tool)
+                logger.debug(f"Added prefixed tool: {prefixed_tool.name}")
+
+            return prefixed_tools
+
+        except Exception as e:
+            logger.error(f"Failed to discover tools from MCP server '{self.server_name}': {e}", exc_info=True)
+            return []
+
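A similar sketch for MCPToolProvider, showing the server__tool name prefixing. The server name and config are illustrative, and the Agent construction is hypothetical since its required fields live in swarm.types:

import asyncio
from swarm.types import Agent
from swarm.extensions.mcp.mcp_tool_provider import MCPToolProvider

server_config = {"command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]}

async def main():
    # get_instance() memoizes providers per (name, config, timeout, debug).
    provider = MCPToolProvider.get_instance("filesystem", server_config, timeout=15)
    agent = Agent(name="assistant")  # hypothetical minimal construction
    tools = await provider.discover_tools(agent)
    # Names come back prefixed, e.g. "filesystem__read_file", so tools from
    # different servers cannot collide; names failing OpenAI's
    # ^[a-zA-Z0-9_-]{1,64}$ pattern are skipped.
    print([t.name for t in tools])

asyncio.run(main())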
--- /dev/null
+++ b/swarm/llm/chat_completion.py
@@ -0,0 +1,195 @@
+"""
+Chat Completion Module
+
+This module handles chat completion logic for the Swarm framework, including message preparation,
+tool call repair, and interaction with the OpenAI API. Located in llm/ for LLM-specific functionality.
+"""
+
+import os
+import json
+import logging
+from typing import List, Optional, Dict, Any, Union, AsyncGenerator # Added AsyncGenerator
+from collections import defaultdict
+
+import asyncio
+from openai import AsyncOpenAI, OpenAIError
+# Make sure ChatCompletionMessage is correctly imported if it's defined elsewhere
+# Assuming it might be part of the base model or a common types module
+# For now, let's assume it's implicitly handled or use a dict directly
+# from ..types import ChatCompletionMessage, Agent # If defined in types
+from ..types import Agent # Import Agent
+from ..utils.redact import redact_sensitive_data
+from ..utils.general_utils import serialize_datetime
+from ..utils.message_utils import filter_duplicate_system_messages, update_null_content
+from ..utils.context_utils import get_token_count, truncate_message_history
+from ..utils.message_sequence import repair_message_payload
+
+# Configure module-level logging
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+if not logger.handlers:
+    stream_handler = logging.StreamHandler()
+    formatter = logging.Formatter("[%(levelname)s] %(asctime)s - %(name)s - %(message)s")
+    stream_handler.setFormatter(formatter)
+    logger.addHandler(stream_handler)
+
+
+async def get_chat_completion(
+    client: AsyncOpenAI,
+    agent: Agent,
+    history: List[Dict[str, Any]],
+    context_variables: dict,
+    current_llm_config: Dict[str, Any],
+    max_context_tokens: int,
+    max_context_messages: int,
+    tools: Optional[List[Dict[str, Any]]] = None, # <-- Added tools parameter
+    tool_choice: Optional[str] = "auto", # <-- Added tool_choice parameter
+    model_override: Optional[str] = None,
+    stream: bool = False,
+    debug: bool = False
+) -> Union[Dict[str, Any], AsyncGenerator[Any, None]]: # Adjusted return type hint
+    """
+    Retrieve a chat completion from the LLM for the given agent and history.
+
+    Args:
+        client: AsyncOpenAI client instance.
+        agent: The agent processing the completion.
+        history: List of previous messages in the conversation.
+        context_variables: Variables to include in the agent's context.
+        current_llm_config: Current LLM configuration dictionary.
+        max_context_tokens: Maximum token limit for context.
+        max_context_messages: Maximum message limit for context.
+        tools: Optional list of tools in OpenAI format.
+        tool_choice: Tool choice mode (e.g., "auto", "none").
+        model_override: Optional model to use instead of default.
+        stream: If True, stream the response; otherwise, return complete.
+        debug: If True, log detailed debugging information.
+
+    Returns:
+        Union[Dict[str, Any], AsyncGenerator[Any, None]]: The LLM's response message (as dict) or stream.
+    """
+    if not agent:
+        logger.error("Cannot generate chat completion: Agent is None")
+        raise ValueError("Agent is required")
+
+    logger.debug(f"Generating chat completion for agent '{agent.name}'")
+    active_model = model_override or current_llm_config.get("model", "default")
+    client_kwargs = {
+        "api_key": current_llm_config.get("api_key"),
+        "base_url": current_llm_config.get("base_url")
+    }
+    client_kwargs = {k: v for k, v in client_kwargs.items() if v is not None}
+    redacted_kwargs = redact_sensitive_data(client_kwargs, sensitive_keys=["api_key"])
+    logger.debug(f"Using client with model='{active_model}', base_url='{client_kwargs.get('base_url', 'default')}', api_key={redacted_kwargs['api_key']}")
+
+    context_variables = defaultdict(str, context_variables)
+    instructions = agent.instructions(context_variables) if callable(agent.instructions) else agent.instructions
+    if not isinstance(instructions, str):
+        logger.warning(f"Invalid instructions type for '{agent.name}': {type(instructions)}. Converting to string.")
+        instructions = str(instructions)
+    messages = repair_message_payload([{"role": "system", "content": instructions}], debug=debug)
+
+    if not isinstance(history, list):
+        logger.error(f"Invalid history type for '{agent.name}': {type(history)}. Expected list.")
+        history = []
+    seen_ids = set()
+    for msg in history:
+        msg_id = msg.get("id", hash(json.dumps(msg, sort_keys=True, default=serialize_datetime)))
+        if msg_id not in seen_ids:
+            seen_ids.add(msg_id)
+            if "tool_calls" in msg and msg["tool_calls"] is not None and not isinstance(msg["tool_calls"], list):
+                logger.warning(f"Invalid tool_calls in history for '{msg.get('sender', 'unknown')}': {msg['tool_calls']}. Setting to None.")
+                msg["tool_calls"] = None
+            # Ensure content: None becomes content: "" for API compatibility
+            if "content" in msg and msg["content"] is None:
+                msg["content"] = ""
+            messages.append(msg)
+    messages = filter_duplicate_system_messages(messages)
+    messages = truncate_message_history(messages, active_model, max_context_tokens, max_context_messages)
+    messages = repair_message_payload(messages, debug=debug) # Ensure tool calls are paired post-truncation
+    # Final content None -> "" check after repair
+    messages = update_null_content(messages)
+
+    logger.debug(f"Prepared {len(messages)} messages for '{agent.name}'")
+    if debug:
+        logger.debug(f"Messages: {json.dumps(messages, indent=2, default=str)}")
+
+    create_params = {
+        "model": active_model,
+        "messages": messages,
+        "stream": stream,
+        "temperature": current_llm_config.get("temperature", 0.7),
+        # --- Pass tools and tool_choice ---
+        "tools": tools if tools else None,
+        "tool_choice": tool_choice if tools else None, # Only set tool_choice if tools are provided
+    }
+    if getattr(agent, "response_format", None):
+        create_params["response_format"] = agent.response_format
+    create_params = {k: v for k, v in create_params.items() if v is not None} # Clean None values
+
+    tool_info_log = f", tools_count={len(tools)}" if tools else ", tools=None"
+    logger.debug(f"Chat completion params: model='{active_model}', messages_count={len(messages)}, stream={stream}{tool_info_log}, tool_choice={create_params.get('tool_choice')}")
+
+    try:
+        logger.debug(f"Calling OpenAI API for '{agent.name}' with model='{active_model}'")
+        # Temporary workaround for potential env var conflicts if client doesn't isolate well
+        prev_openai_api_key = os.environ.pop("OPENAI_API_KEY", None)
+        try:
+            completion = await client.chat.completions.create(**create_params)
+            if stream:
+                return completion # Return stream object directly
+
+            # --- Handle Non-Streaming Response ---
+            if completion.choices and len(completion.choices) > 0 and completion.choices[0].message:
+                message_dict = completion.choices[0].message.model_dump(exclude_none=True)
+                log_msg = message_dict.get("content", "No content")[:50] if message_dict.get("content") else "No content"
+                if message_dict.get("tool_calls"): log_msg += f" (+{len(message_dict['tool_calls'])} tool calls)"
+                logger.debug(f"OpenAI completion received for '{agent.name}': {log_msg}...")
+                return message_dict # Return the message dictionary
+            else:
+                logger.warning(f"No valid message in completion for '{agent.name}'")
+                return {"role": "assistant", "content": "No response generated"} # Return dict
+        finally:
+            if prev_openai_api_key is not None:
+                os.environ["OPENAI_API_KEY"] = prev_openai_api_key
+    except OpenAIError as e:
+        logger.error(f"Chat completion failed for '{agent.name}': {e}")
+        raise
+    except Exception as e: # Catch broader errors during API call
+        logger.error(f"Unexpected error during chat completion for '{agent.name}': {e}", exc_info=True)
+        raise # Re-raise
+
+
+async def get_chat_completion_message(
+    client: AsyncOpenAI,
+    agent: Agent,
+    history: List[Dict[str, Any]],
+    context_variables: dict,
+    current_llm_config: Dict[str, Any],
+    max_context_tokens: int,
+    max_context_messages: int,
+    tools: Optional[List[Dict[str, Any]]] = None, # <-- Added tools
+    tool_choice: Optional[str] = "auto", # <-- Added tool_choice
+    model_override: Optional[str] = None,
+    stream: bool = False,
+    debug: bool = False
+) -> Union[Dict[str, Any], AsyncGenerator[Any, None]]: # Return dict or stream
+    """
+    Wrapper to retrieve and validate a chat completion message (returns dict or stream).
+
+    Args:
+        Same as get_chat_completion.
+
+    Returns:
+        Union[Dict[str, Any], AsyncGenerator[Any, None]]: Validated LLM response message as dict or the stream.
+    """
+    logger.debug(f"Fetching chat completion message for '{agent.name}'")
+    completion_result = await get_chat_completion(
+        client, agent, history, context_variables, current_llm_config,
+        max_context_tokens, max_context_messages,
+        tools=tools, tool_choice=tool_choice, # Pass through
+        model_override=model_override, stream=stream, debug=debug
+    )
+    # If streaming, completion_result is already the generator
+    # If not streaming, it's the message dictionary
+    return completion_result
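Finally, a minimal sketch of calling the completion helpers. The model name, API key, and Agent fields are placeholders; only the call shape follows the signatures above:

import asyncio
from openai import AsyncOpenAI
from swarm.types import Agent
from swarm.llm.chat_completion import get_chat_completion_message

async def main():
    llm_config = {"model": "gpt-4o", "api_key": "sk-...", "temperature": 0.7}  # placeholders
    client = AsyncOpenAI(api_key=llm_config["api_key"])
    agent = Agent(name="assistant", instructions="You are a helpful assistant.")  # hypothetical fields
    history = [{"role": "user", "content": "Hello!"}]
    message = await get_chat_completion_message(
        client, agent, history, context_variables={},
        current_llm_config=llm_config,
        max_context_tokens=8000, max_context_messages=50,
    )
    # Non-streaming calls return the assistant message as a plain dict;
    # with stream=True the raw OpenAI stream object is returned instead.
    print(message.get("content"))

asyncio.run(main())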