open-swarm 0.1.1743070217__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. open_swarm-0.1.1743070217.dist-info/METADATA +258 -0
  2. open_swarm-0.1.1743070217.dist-info/RECORD +89 -0
  3. open_swarm-0.1.1743070217.dist-info/WHEEL +5 -0
  4. open_swarm-0.1.1743070217.dist-info/entry_points.txt +3 -0
  5. open_swarm-0.1.1743070217.dist-info/licenses/LICENSE +21 -0
  6. open_swarm-0.1.1743070217.dist-info/top_level.txt +1 -0
  7. swarm/__init__.py +3 -0
  8. swarm/agent/__init__.py +7 -0
  9. swarm/agent/agent.py +49 -0
  10. swarm/apps.py +53 -0
  11. swarm/auth.py +56 -0
  12. swarm/consumers.py +141 -0
  13. swarm/core.py +326 -0
  14. swarm/extensions/__init__.py +1 -0
  15. swarm/extensions/blueprint/__init__.py +36 -0
  16. swarm/extensions/blueprint/agent_utils.py +45 -0
  17. swarm/extensions/blueprint/blueprint_base.py +562 -0
  18. swarm/extensions/blueprint/blueprint_discovery.py +112 -0
  19. swarm/extensions/blueprint/blueprint_utils.py +17 -0
  20. swarm/extensions/blueprint/common_utils.py +12 -0
  21. swarm/extensions/blueprint/django_utils.py +203 -0
  22. swarm/extensions/blueprint/interactive_mode.py +102 -0
  23. swarm/extensions/blueprint/modes/rest_mode.py +37 -0
  24. swarm/extensions/blueprint/output_utils.py +95 -0
  25. swarm/extensions/blueprint/spinner.py +91 -0
  26. swarm/extensions/cli/__init__.py +0 -0
  27. swarm/extensions/cli/blueprint_runner.py +251 -0
  28. swarm/extensions/cli/cli_args.py +88 -0
  29. swarm/extensions/cli/commands/__init__.py +0 -0
  30. swarm/extensions/cli/commands/blueprint_management.py +31 -0
  31. swarm/extensions/cli/commands/config_management.py +15 -0
  32. swarm/extensions/cli/commands/edit_config.py +77 -0
  33. swarm/extensions/cli/commands/list_blueprints.py +22 -0
  34. swarm/extensions/cli/commands/validate_env.py +57 -0
  35. swarm/extensions/cli/commands/validate_envvars.py +39 -0
  36. swarm/extensions/cli/interactive_shell.py +41 -0
  37. swarm/extensions/cli/main.py +36 -0
  38. swarm/extensions/cli/selection.py +43 -0
  39. swarm/extensions/cli/utils/discover_commands.py +32 -0
  40. swarm/extensions/cli/utils/env_setup.py +15 -0
  41. swarm/extensions/cli/utils.py +105 -0
  42. swarm/extensions/config/__init__.py +6 -0
  43. swarm/extensions/config/config_loader.py +208 -0
  44. swarm/extensions/config/config_manager.py +258 -0
  45. swarm/extensions/config/server_config.py +49 -0
  46. swarm/extensions/config/setup_wizard.py +103 -0
  47. swarm/extensions/config/utils/__init__.py +0 -0
  48. swarm/extensions/config/utils/logger.py +36 -0
  49. swarm/extensions/launchers/__init__.py +1 -0
  50. swarm/extensions/launchers/build_launchers.py +14 -0
  51. swarm/extensions/launchers/build_swarm_wrapper.py +12 -0
  52. swarm/extensions/launchers/swarm_api.py +68 -0
  53. swarm/extensions/launchers/swarm_cli.py +304 -0
  54. swarm/extensions/launchers/swarm_wrapper.py +29 -0
  55. swarm/extensions/mcp/__init__.py +1 -0
  56. swarm/extensions/mcp/cache_utils.py +36 -0
  57. swarm/extensions/mcp/mcp_client.py +341 -0
  58. swarm/extensions/mcp/mcp_constants.py +7 -0
  59. swarm/extensions/mcp/mcp_tool_provider.py +110 -0
  60. swarm/llm/chat_completion.py +195 -0
  61. swarm/messages.py +132 -0
  62. swarm/migrations/0010_initial_chat_models.py +51 -0
  63. swarm/migrations/__init__.py +0 -0
  64. swarm/models.py +45 -0
  65. swarm/repl/__init__.py +1 -0
  66. swarm/repl/repl.py +87 -0
  67. swarm/serializers.py +12 -0
  68. swarm/settings.py +189 -0
  69. swarm/tool_executor.py +239 -0
  70. swarm/types.py +126 -0
  71. swarm/urls.py +89 -0
  72. swarm/util.py +124 -0
  73. swarm/utils/color_utils.py +40 -0
  74. swarm/utils/context_utils.py +272 -0
  75. swarm/utils/general_utils.py +162 -0
  76. swarm/utils/logger.py +61 -0
  77. swarm/utils/logger_setup.py +25 -0
  78. swarm/utils/message_sequence.py +173 -0
  79. swarm/utils/message_utils.py +95 -0
  80. swarm/utils/redact.py +68 -0
  81. swarm/views/__init__.py +41 -0
  82. swarm/views/api_views.py +46 -0
  83. swarm/views/chat_views.py +76 -0
  84. swarm/views/core_views.py +118 -0
  85. swarm/views/message_views.py +40 -0
  86. swarm/views/model_views.py +135 -0
  87. swarm/views/utils.py +457 -0
  88. swarm/views/web_views.py +149 -0
  89. swarm/wsgi.py +16 -0
@@ -0,0 +1,341 @@
1
+ """
2
+ MCP Client Module
3
+
4
+ Manages connections and interactions with MCP servers using the MCP Python SDK.
5
+ Redirects MCP server stderr to log files unless debug mode is enabled.
6
+ """
7
+
8
+ import asyncio
9
+ import logging
10
+ import os
11
+ from typing import Any, Dict, List, Callable
12
+ from contextlib import contextmanager
13
+ import sys
14
+ import json # Added for result parsing
15
+
16
+ # Attempt to import mcp types carefully
17
+ try:
18
+ from mcp import ClientSession, StdioServerParameters # type: ignore
19
+ from mcp.client.stdio import stdio_client # type: ignore
20
+ MCP_AVAILABLE = True
21
+ except ImportError:
22
+ MCP_AVAILABLE = False
23
+ # Define dummy classes if mcp is not installed
24
+ class ClientSession: pass
25
+ class StdioServerParameters: pass
26
+ def stdio_client(*args, **kwargs): raise ImportError("mcp library not installed")
27
+
28
+ from ...types import Tool # Use Tool from swarm.types
29
+ from .cache_utils import get_cache
30
+ from ...settings import Settings # Import Swarm settings
31
+
32
# Logging for this module follows Swarm's settings (level and format).
swarm_settings = Settings()
logger = logging.getLogger(__name__)
logger.setLevel(swarm_settings.log_level.upper())
# Attach a console handler only when neither this logger nor the package-level
# 'swarm' logger already has one, so an existing global configuration wins.
if not (logger.handlers or logging.getLogger('swarm').handlers):
    formatter = logging.Formatter(swarm_settings.log_format.value)
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
43
+
44
class MCPClient:
    """
    Manages connections and interactions with MCP servers using the MCP Python SDK.

    Each operation (tool discovery, tool execution, resource listing and
    retrieval) opens its own stdio connection to the server, initializes a
    ``ClientSession``, performs one request bounded by ``self.timeout`` and then
    closes the connection again.
    """

    # JSON-Schema type names mapped to the Python types accepted for them by the
    # best-effort argument type check in _validate_input_schema.
    _JSON_TYPE_MAP: Dict[str, Any] = {
        "string": str,
        "integer": int,
        "number": (int, float),
        "boolean": bool,
        "array": list,
        "object": dict,
        "null": type(None),
    }

    # Shared state for _redirect_stderr. sys.stderr is process-global and the
    # redirection spans awaits, so overlapping redirections (e.g. concurrent
    # coroutines) are reference counted: the first entrant swaps stderr out and
    # the last one to leave restores it.
    _stderr_redirect_depth: int = 0
    _saved_stderr: Any = None
    _devnull: Any = None

    def __init__(self, server_config: Dict[str, Any], timeout: int = 15, debug: bool = False):
        """
        Initialize the MCPClient with server configuration.

        Args:
            server_config (dict): Configuration for the MCP server. Reads "command"
                (default "npx"), "args" (default []), and "env" (merged over os.environ).
            timeout (int): Timeout in seconds applied to each MCP request.
            debug (bool): If True (or Swarm's global debug setting is on), stderr is
                never redirected by this client.

        Raises:
            ImportError: If the MCP SDK could not be imported.
            TypeError: If "command" is not a string or "args" is not a list of strings.
        """
        if not MCP_AVAILABLE:
            raise ImportError("The 'mcp-client' library is required for MCP functionality but is not installed.")

        self.command = server_config.get("command", "npx")
        self.args = server_config.get("args", [])
        # An explicit `"env": null` in the config is treated like an absent env.
        self.env = {**os.environ.copy(), **(server_config.get("env") or {})}
        self.timeout = timeout
        self.debug = debug or swarm_settings.debug  # Instance flag or global debug setting
        self._tool_cache: Dict[str, Tool] = {}
        self.cache = get_cache()

        if not isinstance(self.command, str):
            raise TypeError(f"MCP server command must be a string, got {type(self.command)}")
        if not isinstance(self.args, list) or not all(isinstance(a, str) for a in self.args):
            raise TypeError(f"MCP server args must be a list of strings, got {self.args}")

        logger.info(f"Initialized MCPClient with command={self.command}, args={self.args}, debug={self.debug}")

    @contextmanager
    def _redirect_stderr(self):
        """
        Point ``sys.stderr`` at ``os.devnull`` for the duration of the block, unless debugging.

        Nested or overlapping uses share one devnull handle; the original stream is
        restored when the last user exits, even if the block raises.

        NOTE(review): this swaps only the Python-level ``sys.stderr`` object. Callers
        enter it after ``stdio_client`` has already started the server process, so it
        presumably does not silence what that process writes to its own stderr —
        confirm against the MCP SDK before relying on it for that.
        """
        if self.debug:
            yield
            return

        cls = MCPClient
        if cls._stderr_redirect_depth == 0:
            devnull = open(os.devnull, "w")
            cls._saved_stderr = sys.stderr
            cls._devnull = devnull
            sys.stderr = devnull
        cls._stderr_redirect_depth += 1
        try:
            yield
        finally:
            cls._stderr_redirect_depth -= 1
            if cls._stderr_redirect_depth == 0:
                sys.stderr = cls._saved_stderr
                cls._devnull.close()
                cls._saved_stderr = None
                cls._devnull = None

    def _register_tool(self, name: str, description: str, input_schema: Dict[str, Any]) -> Tool:
        """Build a Tool whose func executes `name` on this server, and remember it for validation."""
        tool = Tool(
            name=name,
            description=description,
            input_schema=input_schema,
            func=self._create_tool_callable(name),
        )
        self._tool_cache[name] = tool
        return tool

    async def list_tools(self) -> List[Tool]:
        """
        Discover tools from the MCP server and cache their schemas.

        The shared cache (keyed on command and args) is consulted first. On a miss
        the server is queried and the serialized tool list is cached for 3600 seconds.

        Returns:
            List[Tool]: Discovered tools; each ``func`` is an async callable that runs
            the tool over a fresh connection.

        Raises:
            RuntimeError: If discovery times out or fails.
        """
        logger.debug(f"Entering list_tools for command={self.command}, args={self.args}")

        # json.dumps gives a stable textual form of args for the cache key.
        args_string = json.dumps(self.args, sort_keys=True)
        cache_key = f"mcp_tools_{self.command}_{args_string}"
        cached_tools_data = self.cache.get(cache_key)

        if cached_tools_data:
            logger.debug("Retrieved tools data from cache")
            tools = [
                self._register_tool(
                    tool_data["name"],
                    tool_data.get("description") or "",
                    tool_data.get("input_schema", {"type": "object", "properties": {}}),
                )
                for tool_data in cached_tools_data
            ]
            logger.debug(f"Returning {len(tools)} cached tools")
            return tools

        server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env)
        logger.debug("Opening stdio_client connection")
        try:
            async with stdio_client(server_params) as (read, write):
                logger.debug("Opening ClientSession")
                async with ClientSession(read, write) as session:
                    logger.info("Initializing session for tool discovery")
                    await asyncio.wait_for(session.initialize(), timeout=self.timeout)
                    logger.info("Requesting tool list from MCP server...")
                    tools_response = await asyncio.wait_for(session.list_tools(), timeout=self.timeout)
                    logger.debug(f"Tool list received: {tools_response}")

                    if not hasattr(tools_response, 'tools') or not isinstance(tools_response.tools, list):
                        logger.error(f"Invalid tool list response from MCP server: {tools_response}")
                        return []

                    serialized_tools = []
                    tools = []
                    for tool_proto in tools_response.tools:
                        name = getattr(tool_proto, 'name', None)
                        if not name:
                            logger.warning(f"Skipping tool with missing name in response: {tool_proto}")
                            continue

                        # Fall back to an empty object schema when inputSchema is absent or malformed.
                        input_schema = getattr(tool_proto, 'inputSchema', None)
                        if not isinstance(input_schema, dict):
                            input_schema = {"type": "object", "properties": {}}

                        description = getattr(tool_proto, 'description', "") or ""

                        serialized_tools.append({
                            'name': name,
                            'description': description,
                            'input_schema': input_schema,
                        })
                        tools.append(self._register_tool(name, description, input_schema))
                        logger.debug(f"Discovered tool: {name} with schema: {input_schema}")

                    self.cache.set(cache_key, serialized_tools, 3600)
                    logger.debug(f"Cached {len(serialized_tools)} tools.")

                    logger.debug(f"Returning {len(tools)} tools from MCP server")
                    return tools

        except asyncio.TimeoutError:
            logger.error(f"Timeout after {self.timeout}s waiting for tool list")
            raise RuntimeError("Tool list request timed out")
        except Exception as e:
            logger.error(f"Error listing tools: {e}", exc_info=True)
            raise RuntimeError(f"Failed to list tools: {e}") from e

    async def _do_list_resources(self) -> Any:
        """
        List resources over a fresh connection, bounding each request by self.timeout.

        Raises:
            RuntimeError: If the request times out or fails.
        """
        server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env)
        logger.debug("Opening stdio_client connection for resources")
        try:
            async with stdio_client(server_params) as (read, write):
                logger.debug("Opening ClientSession for resources")
                async with ClientSession(read, write) as session:
                    with self._redirect_stderr():  # Suppress stderr if not debugging
                        logger.debug("Initializing session before listing resources")
                        await asyncio.wait_for(session.initialize(), timeout=self.timeout)
                        logger.info("Requesting resource list from MCP server...")
                        resources_response = await asyncio.wait_for(session.list_resources(), timeout=self.timeout)
                        logger.debug("Resource list received from MCP server")
                        return resources_response
        except asyncio.TimeoutError:
            logger.error(f"Timeout listing resources after {self.timeout}s")
            raise RuntimeError("Resource list request timed out")
        except Exception as e:
            logger.error(f"Error listing resources: {e}", exc_info=True)
            raise RuntimeError(f"Failed to list resources: {e}") from e

    def _create_tool_callable(self, tool_name: str) -> Callable[..., Any]:
        """
        Create an async callable that executes `tool_name` on demand.

        Each call validates its keyword arguments against the cached schema (if any),
        opens a fresh connection, calls the tool, and converts the response with
        _extract_call_result. Any failure surfaces as RuntimeError.
        """
        async def dynamic_tool_func(**kwargs) -> Any:
            logger.debug(f"Invoking tool callable for '{tool_name}'")
            server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env)
            try:
                # Validate before spawning the server so bad arguments fail fast.
                tool = self._tool_cache.get(tool_name)
                if tool is not None:
                    self._validate_input_schema(tool.input_schema, kwargs)
                else:
                    logger.warning(f"Schema for tool '{tool_name}' not found in cache for validation.")

                async with stdio_client(server_params) as (read, write):
                    async with ClientSession(read, write) as session:
                        logger.debug(f"Initializing session for tool '{tool_name}'")
                        await asyncio.wait_for(session.initialize(), timeout=self.timeout)

                        logger.info(f"Calling tool '{tool_name}' with arguments: {kwargs}")
                        result_proto = await asyncio.wait_for(session.call_tool(tool_name, kwargs), timeout=self.timeout)
                        return self._extract_call_result(tool_name, result_proto)

            except asyncio.TimeoutError:
                logger.error(f"Timeout after {self.timeout}s executing tool '{tool_name}'")
                raise RuntimeError(f"Tool '{tool_name}' execution timed out")
            except Exception as e:
                logger.error(f"Failed to execute tool '{tool_name}': {e}", exc_info=True)
                raise RuntimeError(f"Tool execution failed: {e}") from e

        return dynamic_tool_func

    @staticmethod
    def _extract_call_result(tool_name: str, result_proto: Any) -> Any:
        """
        Convert a ``call_tool`` response into a plain Python value.

        A ``result`` attribute is used when present. Otherwise the text parts of
        ``content`` — the field the MCP SDK's CallToolResult carries (confirm for the
        installed SDK version) — are joined with newlines. String results that parse
        as JSON are returned decoded; other values are returned unchanged; None is
        returned when no usable data is found.
        """
        if getattr(result_proto, 'isError', False):
            logger.warning(f"Tool '{tool_name}' reported an error result.")

        result_data = getattr(result_proto, 'result', None)
        if result_data is None:
            content = getattr(result_proto, 'content', None)
            if isinstance(content, list):
                texts = [item.text for item in content if isinstance(getattr(item, 'text', None), str)]
                if texts:
                    result_data = "\n".join(texts)

        if result_data is None:
            logger.warning(f"Tool '{tool_name}' executed but returned no result data.")
            return None

        if isinstance(result_data, str):
            try:
                parsed_result = json.loads(result_data)
                logger.info(f"Tool '{tool_name}' executed successfully (result parsed as JSON).")
                return parsed_result
            except json.JSONDecodeError:
                logger.info(f"Tool '{tool_name}' executed successfully (result returned as string).")
                return result_data

        logger.info(f"Tool '{tool_name}' executed successfully (result type: {type(result_data)}).")
        return result_data

    @classmethod
    def _matches_json_type(cls, value: Any, json_type: str) -> bool:
        """Return True if `value` is acceptable for the JSON-Schema type name `json_type`."""
        # bool subclasses int in Python, but JSON Schema keeps booleans distinct from numbers.
        if isinstance(value, bool) and json_type in ("integer", "number"):
            return False
        return isinstance(value, cls._JSON_TYPE_MAP[json_type])

    def _validate_input_schema(self, schema: Dict[str, Any], kwargs: Dict[str, Any]):
        """
        Validate the provided arguments against a tool's JSON input schema.

        Missing ``required`` parameters raise ValueError. Type mismatches against
        ``properties[name]["type"]`` (a single type name or a list of names) are only
        logged; unknown type names and non-dict property entries are skipped.

        Raises:
            ValueError: If a required parameter is missing from `kwargs`.
        """
        if not isinstance(schema, dict):
            logger.warning(f"Invalid schema format for validation: {type(schema)}. Skipping.")
            return

        required_params = schema.get("required", [])
        if not isinstance(required_params, list):
            logger.warning(f"Invalid 'required' list in schema: {type(required_params)}. Skipping requirement check.")
            required_params = []

        for param in required_params:
            if param not in kwargs:
                raise ValueError(f"Missing required parameter: '{param}'")

        properties = schema.get("properties", {})
        if isinstance(properties, dict):
            for key, value in kwargs.items():
                prop_schema = properties.get(key)
                if not isinstance(prop_schema, dict):
                    continue
                expected_type = prop_schema.get("type")
                # JSON Schema allows either one type name or a list of them.
                candidates = expected_type if isinstance(expected_type, list) else [expected_type]
                known = [t for t in candidates if isinstance(t, str) and t in self._JSON_TYPE_MAP]
                if known and not any(self._matches_json_type(value, t) for t in known):
                    # Lenient by design: warn and let the server decide.
                    logger.warning(f"Type mismatch for parameter '{key}'. Expected '{expected_type}', got '{type(value).__name__}'. Attempting to proceed.")

        logger.debug(f"Validated input against schema: {schema} with arguments: {kwargs}")

    async def list_resources(self) -> Any:
        """
        Discover resources from the MCP server (timeouts enforced by _do_list_resources).
        """
        return await self._do_list_resources()

    async def get_resource(self, resource_uri: str) -> Any:
        """
        Retrieve a specific resource from the MCP server.

        Args:
            resource_uri (str): The URI of the resource to retrieve.

        Returns:
            Any: The raw response from ``session.read_resource``.

        Raises:
            RuntimeError: If retrieval times out or fails.
        """
        server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env)
        logger.debug("Opening stdio_client connection for resource retrieval")
        try:
            async with stdio_client(server_params) as (read, write):
                logger.debug("Opening ClientSession for resource retrieval")
                async with ClientSession(read, write) as session:
                    with self._redirect_stderr():  # Suppress stderr if not debugging
                        logger.debug(f"Initializing session for resource retrieval of {resource_uri}")
                        await asyncio.wait_for(session.initialize(), timeout=self.timeout)
                        logger.info(f"Retrieving resource '{resource_uri}' from MCP server")
                        response = await asyncio.wait_for(session.read_resource(resource_uri), timeout=self.timeout)
                        logger.info(f"Resource '{resource_uri}' retrieved successfully")
                        return response
        except asyncio.TimeoutError:
            logger.error(f"Timeout retrieving resource '{resource_uri}' after {self.timeout}s")
            raise RuntimeError(f"Resource '{resource_uri}' retrieval timed out")
        except Exception as e:
            logger.error(f"Failed to retrieve resource '{resource_uri}': {e}", exc_info=True)
            raise RuntimeError(f"Resource retrieval failed: {e}") from e
341
+
@@ -0,0 +1,7 @@
1
+ """
2
+ Constants specific to MCP interactions.
3
+ """
4
+
5
# Token that, when present in a tool result, signals a handoff to another agent.
MCP_SEPARATOR: str = ":::"
7
+
@@ -0,0 +1,110 @@
1
+ """
2
+ MCPToolProvider Module for Open-Swarm
3
+
4
+ This module is responsible for discovering tools from MCP (Model Context Protocol) servers
5
+ and integrating them into the Open-Swarm framework as `Tool` instances.
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ import re # Standard library for regular expressions
11
+ from typing import List, Dict, Any
12
+
13
+ from ...settings import Settings # Use Swarm settings
14
+ from ...types import Tool, Agent
15
+ from .mcp_client import MCPClient
16
+ from .cache_utils import get_cache
17
+
18
# Logging for this module follows Swarm's settings (level and format).
swarm_settings = Settings()
logger = logging.getLogger(__name__)
logger.setLevel(swarm_settings.log_level.upper())
# Attach a console handler only when neither this logger nor the package-level
# 'swarm' logger already has one.
if not (logger.handlers or logging.getLogger('swarm').handlers):
    formatter = logging.Formatter(swarm_settings.log_format.value)
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
28
+
29
+
30
class MCPToolProvider:
    """
    MCPToolProvider discovers tools from an MCP server and converts them into `Tool` instances.

    Instances are shared per (server name, config, timeout, debug) via get_instance().
    Exposed tool names are prefixed with the server name so that tools from
    different servers cannot collide.
    """

    # Registry of shared providers keyed by server name, serialized config and options.
    _instances: Dict[str, "MCPToolProvider"] = {}
    # Joins the server name and the server-side tool name in the exposed name.
    _NAME_SEPARATOR = "__"
    # OpenAI function-name rule: 1-64 chars of [a-zA-Z0-9_-]. Applied with
    # fullmatch so a trailing newline cannot slip through as it can with `$`.
    _TOOL_NAME_PATTERN = re.compile(r"[a-zA-Z0-9_-]{1,64}")

    @classmethod
    def get_instance(cls, server_name: str, server_config: Dict[str, Any], timeout: int = 15, debug: bool = False) -> "MCPToolProvider":
        """
        Get or create the shared instance for this server name, config, timeout and debug flag.

        Raises:
            TypeError: If `server_config` is not JSON-serializable (raised by json.dumps).
        """
        config_key = json.dumps(server_config, sort_keys=True)
        instance_key = f"{server_name}_{config_key}_{timeout}_{debug}"

        # Log only the server name: the key embeds server_config, whose "env"
        # section may contain credentials.
        if instance_key not in cls._instances:
            logger.debug(f"Creating new MCPToolProvider instance for server '{server_name}'")
            cls._instances[instance_key] = cls(server_name, server_config, timeout, debug)
        else:
            logger.debug(f"Reusing existing MCPToolProvider instance for server '{server_name}'")
        return cls._instances[instance_key]

    def __init__(self, server_name: str, server_config: Dict[str, Any], timeout: int = 15, debug: bool = False):
        """
        Initialize an MCPToolProvider instance. Use get_instance() for shared instances.

        If the MCPClient cannot be created (missing SDK or bad config), the error is
        logged and `self.client` is None; discover_tools then returns no tools.
        """
        self.server_name = server_name
        effective_debug = debug or swarm_settings.debug
        try:
            self.client = MCPClient(server_config=server_config, timeout=timeout, debug=effective_debug)
        except ImportError as e:
            logger.error(f"Failed to initialize MCPClient for '{server_name}': {e}. MCP features will be unavailable.")
            self.client = None
        except Exception as e:
            logger.error(f"Error initializing MCPClient for '{server_name}': {e}", exc_info=True)
            self.client = None

        self.cache = get_cache()
        logger.debug(f"Initialized MCPToolProvider for server '{self.server_name}' with timeout {timeout}s.")

    async def discover_tools(self, agent: Agent) -> List[Tool]:
        """
        Discover tools from the MCP server using the MCPClient.

        Args:
            agent (Agent): The agent for which tools are being discovered (currently unused).

        Returns:
            List[Tool]: Tools renamed to "<server_name>__<tool_name>". Tools whose
            prefixed name is not a valid OpenAI function name are skipped. Any
            discovery error is logged and yields an empty list.
        """
        if not self.client:
            logger.warning(f"MCPClient for '{self.server_name}' not initialized. Cannot discover tools.")
            return []

        logger.debug(f"Starting tool discovery via MCPClient for server '{self.server_name}'.")
        try:
            tools = await self.client.list_tools()
            logger.debug(f"Discovered {len(tools)} tools from MCP server '{self.server_name}'.")

            prefixed_tools = []
            for tool in tools:
                prefixed_name = f"{self.server_name}{self._NAME_SEPARATOR}{tool.name}"
                if not self._TOOL_NAME_PATTERN.fullmatch(prefixed_name):
                    logger.warning(f"Generated MCP tool name '{prefixed_name}' might violate OpenAI pattern. Skipping.")
                    continue

                prefixed_tool = Tool(
                    name=prefixed_name,
                    description=tool.description,
                    input_schema=tool.input_schema,
                    func=tool.func,  # Already bound to this client and the unprefixed tool name
                )
                prefixed_tools.append(prefixed_tool)
                logger.debug(f"Added prefixed tool: {prefixed_tool.name}")

            return prefixed_tools

        except Exception as e:
            logger.error(f"Failed to discover tools from MCP server '{self.server_name}': {e}", exc_info=True)
            return []
110
+
@@ -0,0 +1,195 @@
1
+ """
2
+ Chat Completion Module
3
+
4
+ This module handles chat completion logic for the Swarm framework, including message preparation,
5
+ tool call repair, and interaction with the OpenAI API. Located in llm/ for LLM-specific functionality.
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import logging
11
+ from typing import List, Optional, Dict, Any, Union, AsyncGenerator # Added AsyncGenerator
12
+ from collections import defaultdict
13
+
14
+ import asyncio
15
+ from openai import AsyncOpenAI, OpenAIError
16
+ # Make sure ChatCompletionMessage is correctly imported if it's defined elsewhere
17
+ # Assuming it might be part of the base model or a common types module
18
+ # For now, let's assume it's implicitly handled or use a dict directly
19
+ # from ..types import ChatCompletionMessage, Agent # If defined in types
20
+ from ..types import Agent # Import Agent
21
+ from ..utils.redact import redact_sensitive_data
22
+ from ..utils.general_utils import serialize_datetime
23
+ from ..utils.message_utils import filter_duplicate_system_messages, update_null_content
24
+ from ..utils.context_utils import get_token_count, truncate_message_history
25
+ from ..utils.message_sequence import repair_message_payload
26
+
27
# Module logger: always at DEBUG, with a console handler attached only when
# the logger has no handlers yet.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
if not logger.handlers:
    formatter = logging.Formatter("[%(levelname)s] %(asctime)s - %(name)s - %(message)s")
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
35
+
36
+
37
async def get_chat_completion(
    client: AsyncOpenAI,
    agent: Agent,
    history: List[Dict[str, Any]],
    context_variables: dict,
    current_llm_config: Dict[str, Any],
    max_context_tokens: int,
    max_context_messages: int,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
    model_override: Optional[str] = None,
    stream: bool = False,
    debug: bool = False
) -> Union[Dict[str, Any], AsyncGenerator[Any, None]]:
    """
    Retrieve a chat completion from the LLM for the given agent and history.

    Builds the message list (agent instructions as the system message, then the
    de-duplicated history), truncates and repairs it, then calls
    ``client.chat.completions.create``. The caller's history dicts are not modified.

    Args:
        client: AsyncOpenAI client instance.
        agent: The agent processing the completion.
        history: List of previous messages in the conversation.
        context_variables: Variables passed to a callable ``agent.instructions``;
            missing keys read as "". None is treated as empty.
        current_llm_config: Current LLM configuration dictionary ("model",
            "api_key", "base_url", "temperature" are read).
        max_context_tokens: Maximum token limit for context.
        max_context_messages: Maximum message limit for context.
        tools: Optional list of tools in OpenAI format.
        tool_choice: Tool choice mode (e.g., "auto", "none"); only sent when tools are given.
        model_override: Optional model to use instead of the configured one.
        stream: If True, return the stream object; otherwise the message dict.
        debug: If True, log the full prepared message list.

    Returns:
        Union[Dict[str, Any], AsyncGenerator[Any, None]]: The response message as a
        dict, or the stream object when ``stream`` is True.

    Raises:
        ValueError: If `agent` is falsy.
        OpenAIError: Re-raised from the API call.
    """
    if not agent:
        logger.error("Cannot generate chat completion: Agent is None")
        raise ValueError("Agent is required")

    logger.debug(f"Generating chat completion for agent '{agent.name}'")
    active_model = model_override or current_llm_config.get("model", "default")
    client_kwargs = {
        "api_key": current_llm_config.get("api_key"),
        "base_url": current_llm_config.get("base_url")
    }
    client_kwargs = {k: v for k, v in client_kwargs.items() if v is not None}
    redacted_kwargs = redact_sensitive_data(client_kwargs, sensitive_keys=["api_key"])
    # api_key was dropped above when unset, so look it up with .get() instead of indexing.
    logger.debug(f"Using client with model='{active_model}', base_url='{client_kwargs.get('base_url', 'default')}', api_key={redacted_kwargs.get('api_key')}")

    context_variables = defaultdict(str, context_variables or {})
    instructions = agent.instructions(context_variables) if callable(agent.instructions) else agent.instructions
    if not isinstance(instructions, str):
        logger.warning(f"Invalid instructions type for '{agent.name}': {type(instructions)}. Converting to string.")
        instructions = str(instructions)
    messages = repair_message_payload([{"role": "system", "content": instructions}], debug=debug)

    if not isinstance(history, list):
        logger.error(f"Invalid history type for '{agent.name}': {type(history)}. Expected list.")
        history = []
    seen_ids = set()
    for original_msg in history:
        # Messages without a usable id are de-duplicated by a hash of their content;
        # the hash is only computed when needed.
        msg_id = original_msg.get("id")
        if msg_id is None:
            msg_id = hash(json.dumps(original_msg, sort_keys=True, default=serialize_datetime))
        if msg_id in seen_ids:
            continue
        seen_ids.add(msg_id)

        # Normalize a shallow copy so the caller's history is left untouched.
        msg = dict(original_msg)
        if "tool_calls" in msg and msg["tool_calls"] is not None and not isinstance(msg["tool_calls"], list):
            logger.warning(f"Invalid tool_calls in history for '{msg.get('sender', 'unknown')}': {msg['tool_calls']}. Setting to None.")
            msg["tool_calls"] = None
        # content: None becomes "" for API compatibility.
        if "content" in msg and msg["content"] is None:
            msg["content"] = ""
        messages.append(msg)

    messages = filter_duplicate_system_messages(messages)
    messages = truncate_message_history(messages, active_model, max_context_tokens, max_context_messages)
    messages = repair_message_payload(messages, debug=debug)  # Re-pair tool calls after truncation
    messages = update_null_content(messages)  # Final None -> "" pass after repair

    logger.debug(f"Prepared {len(messages)} messages for '{agent.name}'")
    if debug:
        logger.debug(f"Messages: {json.dumps(messages, indent=2, default=str)}")

    create_params = {
        "model": active_model,
        "messages": messages,
        "stream": stream,
        "temperature": current_llm_config.get("temperature", 0.7),
        "tools": tools if tools else None,
        "tool_choice": tool_choice if tools else None,  # Only meaningful with tools
    }
    if getattr(agent, "response_format", None):
        create_params["response_format"] = agent.response_format
    create_params = {k: v for k, v in create_params.items() if v is not None}

    tool_info_log = f", tools_count={len(tools)}" if tools else ", tools=None"
    logger.debug(f"Chat completion params: model='{active_model}', messages_count={len(messages)}, stream={stream}{tool_info_log}, tool_choice={create_params.get('tool_choice')}")

    try:
        logger.debug(f"Calling OpenAI API for '{agent.name}' with model='{active_model}'")
        # Workaround kept from the original: hide OPENAI_API_KEY during the call.
        # NOTE(review): this mutates process-global env across an await, so other
        # coroutines briefly see it unset; the client is already constructed by now.
        prev_openai_api_key = os.environ.pop("OPENAI_API_KEY", None)
        try:
            completion = await client.chat.completions.create(**create_params)
            if stream:
                return completion

            if completion.choices and len(completion.choices) > 0 and completion.choices[0].message:
                message_dict = completion.choices[0].message.model_dump(exclude_none=True)
                log_msg = message_dict.get("content", "No content")[:50] if message_dict.get("content") else "No content"
                if message_dict.get("tool_calls"):
                    log_msg += f" (+{len(message_dict['tool_calls'])} tool calls)"
                logger.debug(f"OpenAI completion received for '{agent.name}': {log_msg}...")
                return message_dict
            else:
                logger.warning(f"No valid message in completion for '{agent.name}'")
                return {"role": "assistant", "content": "No response generated"}
        finally:
            if prev_openai_api_key is not None:
                os.environ["OPENAI_API_KEY"] = prev_openai_api_key
    except OpenAIError as e:
        logger.error(f"Chat completion failed for '{agent.name}': {e}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error during chat completion for '{agent.name}': {e}", exc_info=True)
        raise
161
+
162
+
163
async def get_chat_completion_message(
    client: AsyncOpenAI,
    agent: Agent,
    history: List[Dict[str, Any]],
    context_variables: dict,
    current_llm_config: Dict[str, Any],
    max_context_tokens: int,
    max_context_messages: int,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
    model_override: Optional[str] = None,
    stream: bool = False,
    debug: bool = False
) -> Union[Dict[str, Any], AsyncGenerator[Any, None]]:
    """
    Thin wrapper around get_chat_completion; returns its result unchanged.

    Args:
        Same as get_chat_completion.

    Returns:
        Union[Dict[str, Any], AsyncGenerator[Any, None]]: The message dict, or the
        stream object when ``stream`` is True.

    Raises:
        ValueError: If `agent` is falsy (raised by get_chat_completion).
    """
    # getattr keeps a missing agent from failing here with AttributeError before
    # get_chat_completion can raise its intended ValueError.
    logger.debug(f"Fetching chat completion message for '{getattr(agent, 'name', None)}'")
    completion_result = await get_chat_completion(
        client, agent, history, context_variables, current_llm_config,
        max_context_tokens, max_context_messages,
        tools=tools, tool_choice=tool_choice,
        model_override=model_override, stream=stream, debug=debug
    )
    return completion_result