swarms 7.6.0__py3-none-any.whl → 7.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- swarms/agents/__init__.py +9 -2
- swarms/agents/agent_judge.py +119 -0
- swarms/agents/flexion_agent.py +625 -0
- swarms/agents/gkp_agent.py +581 -0
- swarms/agents/reasoning_agents.py +32 -0
- swarms/prompts/agent_judge_prompt.py +38 -0
- swarms/structs/__init__.py +7 -2
- swarms/structs/agent.py +37 -8
- swarms/structs/agent_builder.py +6 -8
- swarms/structs/concurrent_workflow.py +1 -1
- swarms/structs/deep_research_swarm.py +482 -0
- swarms/structs/dynamic_conversational_swarm.py +226 -0
- swarms/structs/hiearchical_swarm.py +1 -1
- swarms/structs/hybrid_hiearchical_peer_swarm.py +273 -0
- swarms/structs/majority_voting.py +1 -1
- swarms/structs/mixture_of_agents.py +1 -1
- swarms/structs/multi_agent_orchestrator.py +1 -1
- swarms/structs/output_types.py +3 -0
- swarms/structs/rearrange.py +1 -1
- swarms/structs/sequential_workflow.py +1 -1
- swarms/structs/swarm_router.py +12 -1
- swarms/structs/swarms_api.py +1 -1
- swarms/telemetry/main.py +7 -3
- swarms/tools/mcp_integration.py +554 -0
- swarms/tools/tool_schema_base_model.py +57 -0
- {swarms-7.6.0.dist-info → swarms-7.6.2.dist-info}/METADATA +79 -1
- {swarms-7.6.0.dist-info → swarms-7.6.2.dist-info}/RECORD +30 -24
- swarms/structs/agent_security.py +0 -318
- swarms/structs/airflow_swarm.py +0 -430
- swarms/structs/output_type.py +0 -18
- {swarms-7.6.0.dist-info → swarms-7.6.2.dist-info}/LICENSE +0 -0
- {swarms-7.6.0.dist-info → swarms-7.6.2.dist-info}/WHEEL +0 -0
- {swarms-7.6.0.dist-info → swarms-7.6.2.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,554 @@
|
|
1
|
+
from contextlib import AsyncExitStack
|
2
|
+
from types import TracebackType
|
3
|
+
from typing import (
|
4
|
+
Any,
|
5
|
+
Callable,
|
6
|
+
Coroutine,
|
7
|
+
List,
|
8
|
+
Literal,
|
9
|
+
Optional,
|
10
|
+
TypedDict,
|
11
|
+
cast,
|
12
|
+
)
|
13
|
+
|
14
|
+
from mcp import ClientSession, StdioServerParameters
|
15
|
+
from mcp.client.sse import sse_client
|
16
|
+
from mcp.client.stdio import stdio_client
|
17
|
+
from mcp.types import (
|
18
|
+
CallToolResult,
|
19
|
+
EmbeddedResource,
|
20
|
+
ImageContent,
|
21
|
+
PromptMessage,
|
22
|
+
TextContent,
|
23
|
+
)
|
24
|
+
from mcp.types import (
|
25
|
+
Tool as MCPTool,
|
26
|
+
)
|
27
|
+
|
28
|
+
|
29
|
+
def convert_mcp_prompt_message_to_message(
    message: PromptMessage,
) -> str:
    """Flatten an MCP prompt message into its plain text.

    Only text content is supported, and only the "user" and "assistant"
    roles are accepted; both roles yield the same string.

    Args:
        message: MCP prompt message to convert.

    Returns:
        The message's text content as a string.

    Raises:
        ValueError: If the content is not text, or the role is neither
            "user" nor "assistant".
    """
    content = message.content
    # Content type is validated before the role, matching the original order.
    if content.type != "text":
        raise ValueError(
            f"Unsupported prompt message content type: {message.content.type}"
        )
    if message.role not in ("user", "assistant"):
        raise ValueError(
            f"Unsupported prompt message role: {message.role}"
        )
    return str(content.text)
|
55
|
+
|
56
|
+
|
57
|
+
async def load_mcp_prompt(
    session: ClientSession,
    name: str,
    arguments: Optional[dict[str, Any]] = None,
) -> List[str]:
    """Fetch a named prompt from an MCP session and flatten it to strings.

    Args:
        session: An initialized MCP client session.
        name: Name of the prompt to request.
        arguments: Optional prompt arguments forwarded to the server.

    Returns:
        One string per prompt message, in server order.
    """
    prompt = await session.get_prompt(name, arguments)
    return list(map(convert_mcp_prompt_message_to_message, prompt.messages))
|
69
|
+
|
70
|
+
|
71
|
+
# Stdio transport defaults: text encoding and codec error handler used for
# messages exchanged with the server process.
DEFAULT_ENCODING = "utf-8"
DEFAULT_ENCODING_ERROR_HANDLER = "strict"

# SSE transport defaults forwarded to ``sse_client``.
# NOTE(review): units presumably seconds — confirm against the mcp package.
DEFAULT_HTTP_TIMEOUT = 5
DEFAULT_SSE_READ_TIMEOUT = 300  # 5 * 60
|
76
|
+
|
77
|
+
|
78
|
+
class StdioConnection(TypedDict):
|
79
|
+
transport: Literal["stdio"]
|
80
|
+
|
81
|
+
command: str
|
82
|
+
"""The executable to run to start the server."""
|
83
|
+
|
84
|
+
args: list[str]
|
85
|
+
"""Command line arguments to pass to the executable."""
|
86
|
+
|
87
|
+
env: dict[str, str] | None
|
88
|
+
"""The environment to use when spawning the process."""
|
89
|
+
|
90
|
+
encoding: str
|
91
|
+
"""The text encoding used when sending/receiving messages to the server."""
|
92
|
+
|
93
|
+
encoding_error_handler: Literal["strict", "ignore", "replace"]
|
94
|
+
"""
|
95
|
+
The text encoding error handler.
|
96
|
+
|
97
|
+
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
|
98
|
+
explanations of possible values
|
99
|
+
"""
|
100
|
+
|
101
|
+
|
102
|
+
class SSEConnection(TypedDict):
|
103
|
+
transport: Literal["sse"]
|
104
|
+
|
105
|
+
url: str
|
106
|
+
"""The URL of the SSE endpoint to connect to."""
|
107
|
+
|
108
|
+
headers: dict[str, Any] | None
|
109
|
+
"""HTTP headers to send to the SSE endpoint"""
|
110
|
+
|
111
|
+
timeout: float
|
112
|
+
"""HTTP timeout"""
|
113
|
+
|
114
|
+
sse_read_timeout: float
|
115
|
+
"""SSE read timeout"""
|
116
|
+
|
117
|
+
|
118
|
+
NonTextContent = ImageContent | EmbeddedResource


def _convert_call_tool_result(
    call_tool_result: CallToolResult,
) -> tuple[str | list[str], list[NonTextContent] | None]:
    """Split an MCP tool result into text output and non-text artifacts.

    Args:
        call_tool_result: The raw result returned by ``ClientSession.call_tool``.

    Returns:
        A ``(text, artifacts)`` pair. ``text`` is a single string when the
        result holds exactly one text item; otherwise it is a (possibly empty)
        list of strings. ``artifacts`` lists every non-text item, or is
        ``None`` when there are none.

    Raises:
        ValueError: If the server flagged the call as an error. Any text the
            server returned is included in the message.
    """
    text_contents: list[TextContent] = []
    non_text_contents: list[NonTextContent] = []
    for content in call_tool_result.content:
        if isinstance(content, TextContent):
            text_contents.append(content)
        else:
            non_text_contents.append(content)

    texts = [content.text for content in text_contents]

    if call_tool_result.isError:
        # Keep the server's error explanation instead of discarding it.
        detail = "\n".join(texts)
        raise ValueError(
            f"Error calling tool: {detail}" if detail else "Error calling tool"
        )

    tool_content: str | list[str] = texts[0] if len(texts) == 1 else texts
    return tool_content, non_text_contents or None
|
142
|
+
|
143
|
+
|
144
|
+
def convert_mcp_tool_to_function(
    session: ClientSession,
    tool: MCPTool,
) -> Callable[
    ...,
    Coroutine[
        Any, Any, tuple[str | list[str], list[NonTextContent] | None]
    ],
]:
    """Convert an MCP tool to a callable function.

    NOTE: this tool can be executed only in a context of an active MCP client session.

    The returned coroutine function takes the tool's arguments as keyword
    arguments and returns the ``(text, artifacts)`` pair produced by
    ``_convert_call_tool_result``. It carries the tool's metadata:
    ``__name__``/``__qualname__`` are the tool name, ``__doc__`` is its
    description (or ``""``), and ``schema`` is its ``inputSchema``.

    Args:
        session: MCP client session
        tool: MCP tool to convert

    Returns:
        a callable function
    """

    async def call_tool(
        **arguments: Any,
    ) -> tuple[str | list[str], list[NonTextContent] | None]:
        """Execute the tool with the given arguments."""
        call_tool_result = await session.call_tool(tool.name, arguments)
        return _convert_call_tool_result(call_tool_result)

    # Add metadata as attributes to the function. __qualname__ is set too so
    # reprs and tracebacks show the tool name rather than the closure path.
    call_tool.__name__ = tool.name
    call_tool.__qualname__ = tool.name
    call_tool.__doc__ = tool.description or ""
    call_tool.schema = tool.inputSchema

    return call_tool
|
180
|
+
|
181
|
+
|
182
|
+
async def load_mcp_tools(session: ClientSession) -> list[Callable]:
    """List every tool the session's server exposes and wrap each one.

    Args:
        session: An initialized MCP client session.

    Returns:
        One coroutine function per tool, in the order the server listed them.
    """
    listing = await session.list_tools()
    wrapped: list[Callable] = []
    for mcp_tool in listing.tools:
        wrapped.append(convert_mcp_tool_to_function(session, mcp_tool))
    return wrapped
|
189
|
+
|
190
|
+
|
191
|
+
class MultiServerMCPClient:
    """Client for connecting to multiple MCP servers and loading tools from them.

    Every transport and session is entered on one ``AsyncExitStack``.
    Leaving the ``async with`` block, or a failure while entering it, closes
    all of them and clears the cached sessions and tools.
    """

    def __init__(
        self,
        connections: Optional[
            dict[str, StdioConnection | SSEConnection]
        ] = None,
    ) -> None:
        """Initialize a MultiServerMCPClient with MCP servers connections.

        Connections are only opened when the client is entered with
        ``async with``.

        Args:
            connections: A dictionary mapping server names to connection configurations.
                Each configuration can be either a StdioConnection or SSEConnection.
                If None, no initial connections are established.

        Example:

            ```python
            async with MultiServerMCPClient(
                {
                    "math": {
                        "command": "python",
                        # Make sure to update to the full absolute path to your math_server.py file
                        "args": ["/path/to/math_server.py"],
                        "transport": "stdio",
                    },
                    "weather": {
                        # make sure you start your weather server on port 8000
                        "url": "http://localhost:8000/sse",
                        "transport": "sse",
                    }
                }
            ) as client:
                all_tools = client.get_tools()
                ...
            ```
        """
        self.connections = connections
        # Owns every transport/session context so they can be closed together.
        self.exit_stack = AsyncExitStack()
        # Initialized sessions keyed by server name.
        self.sessions: dict[str, ClientSession] = {}
        # Tool callables loaded from each server, keyed by server name.
        self.server_name_to_tools: dict[str, list[Callable]] = {}

    async def _initialize_session_and_load_tools(
        self, server_name: str, session: ClientSession
    ) -> None:
        """Initialize a session, load its tools, and register both.

        The session and its tools are recorded only after both steps succeed,
        so a failure here does not leave a half-registered server behind.

        Args:
            server_name: Name to identify this server connection
            session: The ClientSession to initialize
        """
        await session.initialize()
        server_tools = await load_mcp_tools(session)

        self.sessions[server_name] = session
        self.server_name_to_tools[server_name] = server_tools

    async def connect_to_server(
        self,
        server_name: str,
        *,
        transport: Literal["stdio", "sse"] = "stdio",
        **kwargs,
    ) -> None:
        """Connect to an MCP server using either stdio or SSE.

        This is a generic method that calls either connect_to_server_via_stdio or connect_to_server_via_sse
        based on the provided transport parameter. Keyword arguments not used
        by the chosen transport are ignored.

        Args:
            server_name: Name to identify this server connection
            transport: Type of transport to use ("stdio" or "sse"), defaults to "stdio"
            **kwargs: Additional arguments to pass to the specific connection method

        Raises:
            ValueError: If transport is not recognized
            ValueError: If required parameters for the specified transport are missing
        """
        if transport == "sse":
            if "url" not in kwargs:
                raise ValueError(
                    "'url' parameter is required for SSE connection"
                )
            await self.connect_to_server_via_sse(
                server_name,
                url=kwargs["url"],
                headers=kwargs.get("headers"),
                timeout=kwargs.get("timeout", DEFAULT_HTTP_TIMEOUT),
                sse_read_timeout=kwargs.get(
                    "sse_read_timeout", DEFAULT_SSE_READ_TIMEOUT
                ),
            )
        elif transport == "stdio":
            if "command" not in kwargs:
                raise ValueError(
                    "'command' parameter is required for stdio connection"
                )
            if "args" not in kwargs:
                raise ValueError(
                    "'args' parameter is required for stdio connection"
                )
            await self.connect_to_server_via_stdio(
                server_name,
                command=kwargs["command"],
                args=kwargs["args"],
                env=kwargs.get("env"),
                encoding=kwargs.get("encoding", DEFAULT_ENCODING),
                encoding_error_handler=kwargs.get(
                    "encoding_error_handler",
                    DEFAULT_ENCODING_ERROR_HANDLER,
                ),
            )
        else:
            raise ValueError(
                f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
            )

    async def connect_to_server_via_stdio(
        self,
        server_name: str,
        *,
        command: str,
        args: list[str],
        env: dict[str, str] | None = None,
        encoding: str = DEFAULT_ENCODING,
        encoding_error_handler: Literal[
            "strict", "ignore", "replace"
        ] = DEFAULT_ENCODING_ERROR_HANDLER,
    ) -> None:
        """Connect to a specific MCP server using stdio

        The transport and session are entered on ``self.exit_stack`` and are
        closed when the client exits.

        Args:
            server_name: Name to identify this server connection
            command: Command to execute
            args: Arguments for the command
            env: Environment variables for the command
            encoding: Character encoding
            encoding_error_handler: How to handle encoding errors
        """
        server_params = StdioServerParameters(
            command=command,
            args=args,
            env=env,
            encoding=encoding,
            encoding_error_handler=encoding_error_handler,
        )

        # Create and store the connection
        read, write = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        session = cast(
            ClientSession,
            await self.exit_stack.enter_async_context(
                ClientSession(read, write)
            ),
        )

        await self._initialize_session_and_load_tools(
            server_name, session
        )

    async def connect_to_server_via_sse(
        self,
        server_name: str,
        *,
        url: str,
        headers: dict[str, Any] | None = None,
        timeout: float = DEFAULT_HTTP_TIMEOUT,
        sse_read_timeout: float = DEFAULT_SSE_READ_TIMEOUT,
    ) -> None:
        """Connect to a specific MCP server using SSE

        The transport and session are entered on ``self.exit_stack`` and are
        closed when the client exits.

        Args:
            server_name: Name to identify this server connection
            url: URL of the SSE server
            headers: HTTP headers to send to the SSE endpoint
            timeout: HTTP timeout
            sse_read_timeout: SSE read timeout
        """
        # Keyword arguments so a reordering of sse_client's optional
        # parameters cannot silently swap the two timeouts.
        read, write = await self.exit_stack.enter_async_context(
            sse_client(
                url=url,
                headers=headers,
                timeout=timeout,
                sse_read_timeout=sse_read_timeout,
            )
        )
        session = cast(
            ClientSession,
            await self.exit_stack.enter_async_context(
                ClientSession(read, write)
            ),
        )

        await self._initialize_session_and_load_tools(
            server_name, session
        )

    def get_tools(self) -> list[Callable]:
        """Get a list of all tools from all connected servers."""
        all_tools: list[Callable] = []
        for server_tools in self.server_name_to_tools.values():
            all_tools.extend(server_tools)
        return all_tools

    async def get_prompt(
        self,
        server_name: str,
        prompt_name: str,
        arguments: Optional[dict[str, Any]] = None,
    ) -> List[str]:
        """Get a prompt from a given MCP server.

        Raises:
            KeyError: If no session is connected under ``server_name``.
        """
        try:
            session = self.sessions[server_name]
        except KeyError:
            raise KeyError(
                f"No connected MCP server named {server_name!r}"
            ) from None
        return await load_mcp_prompt(session, prompt_name, arguments)

    async def _close(self) -> None:
        """Close every transport/session on the exit stack and forget them."""
        try:
            await self.exit_stack.aclose()
        finally:
            # The sessions are closed now, so stop handing them (and the tools
            # bound to them) out through get_tools()/get_prompt().
            self.sessions.clear()
            self.server_name_to_tools.clear()

    async def __aenter__(self) -> "MultiServerMCPClient":
        try:
            connections = self.connections or {}
            for server_name, connection in connections.items():
                connection_dict = connection.copy()
                transport = connection_dict.pop("transport")
                if transport == "stdio":
                    await self.connect_to_server_via_stdio(
                        server_name, **connection_dict
                    )
                elif transport == "sse":
                    await self.connect_to_server_via_sse(
                        server_name, **connection_dict
                    )
                else:
                    raise ValueError(
                        f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
                    )
            return self
        except BaseException:
            # __aexit__ is not called when __aenter__ fails, so release any
            # already-opened connections here. BaseException (not Exception)
            # so that task cancellation mid-connect also cleans up.
            await self._close()
            raise

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self._close()
|
438
|
+
|
439
|
+
|
440
|
+
# #!/usr/bin/env python3
|
441
|
+
# import asyncio
|
442
|
+
# import os
|
443
|
+
# import json
|
444
|
+
# from typing import List, Any, Callable
|
445
|
+
|
446
|
+
# # # Import our MCP client module
|
447
|
+
# # from mcp_client import MultiServerMCPClient
|
448
|
+
|
449
|
+
|
450
|
+
# async def main():
|
451
|
+
# """Test script for demonstrating MCP client usage."""
|
452
|
+
# print("Starting MCP Client test...")
|
453
|
+
|
454
|
+
# # Create a connection to multiple MCP servers
|
455
|
+
# # You'll need to update these paths to match your setup
|
456
|
+
# async with MultiServerMCPClient(
|
457
|
+
# {
|
458
|
+
# "math": {
|
459
|
+
# "transport": "stdio",
|
460
|
+
# "command": "python",
|
461
|
+
# "args": ["/path/to/math_server.py"],
|
462
|
+
# "env": {"DEBUG": "1"},
|
463
|
+
# },
|
464
|
+
# "search": {
|
465
|
+
# "transport": "sse",
|
466
|
+
# "url": "http://localhost:8000/sse",
|
467
|
+
# "headers": {
|
468
|
+
# "Authorization": f"Bearer {os.environ.get('API_KEY', '')}"
|
469
|
+
# },
|
470
|
+
# },
|
471
|
+
# }
|
472
|
+
# ) as client:
|
473
|
+
# # Get all available tools
|
474
|
+
# tools = client.get_tools()
|
475
|
+
# print(f"Found {len(tools)} tools across all servers")
|
476
|
+
|
477
|
+
# # Print tool information
|
478
|
+
# for i, tool in enumerate(tools):
|
479
|
+
# print(f"\nTool {i+1}: {tool.__name__}")
|
480
|
+
# print(f" Description: {tool.__doc__}")
|
481
|
+
# if hasattr(tool, "schema") and tool.schema:
|
482
|
+
# print(
|
483
|
+
# f" Schema: {json.dumps(tool.schema, indent=2)[:100]}..."
|
484
|
+
# )
|
485
|
+
|
486
|
+
# # Example: Use a specific tool if available
|
487
|
+
# calculator_tool = next(
|
488
|
+
# (t for t in tools if t.__name__ == "calculator"), None
|
489
|
+
# )
|
490
|
+
# if calculator_tool:
|
491
|
+
# print("\n\nTesting calculator tool:")
|
492
|
+
# try:
|
493
|
+
# # Call the tool as an async function
|
494
|
+
# result, artifacts = await calculator_tool(
|
495
|
+
# expression="2 + 2 * 3"
|
496
|
+
# )
|
497
|
+
# print(f" Calculator result: {result}")
|
498
|
+
# if artifacts:
|
499
|
+
# print(
|
500
|
+
# f" With {len(artifacts)} additional artifacts"
|
501
|
+
# )
|
502
|
+
# except Exception as e:
|
503
|
+
# print(f" Error using calculator: {e}")
|
504
|
+
|
505
|
+
# # Example: Load a prompt from a server
|
506
|
+
# try:
|
507
|
+
# print("\n\nTesting prompt loading:")
|
508
|
+
# prompt_messages = await client.get_prompt(
|
509
|
+
# "math",
|
510
|
+
# "calculation_introduction",
|
511
|
+
# {"user_name": "Test User"},
|
512
|
+
# )
|
513
|
+
# print(
|
514
|
+
# f" Loaded prompt with {len(prompt_messages)} messages:"
|
515
|
+
# )
|
516
|
+
# for i, msg in enumerate(prompt_messages):
|
517
|
+
# print(f" Message {i+1}: {msg[:50]}...")
|
518
|
+
# except Exception as e:
|
519
|
+
# print(f" Error loading prompt: {e}")
|
520
|
+
|
521
|
+
|
522
|
+
# async def create_custom_tool():
|
523
|
+
# """Example of creating a custom tool function."""
|
524
|
+
|
525
|
+
# # Define a tool function with metadata
|
526
|
+
# async def add_numbers(a: float, b: float) -> tuple[str, None]:
|
527
|
+
# """Add two numbers together."""
|
528
|
+
# result = a + b
|
529
|
+
# return f"The sum of {a} and {b} is {result}", None
|
530
|
+
|
531
|
+
# # Add metadata to the function
|
532
|
+
# add_numbers.__name__ = "add_numbers"
|
533
|
+
# add_numbers.__doc__ = (
|
534
|
+
# "Add two numbers together and return the result."
|
535
|
+
# )
|
536
|
+
# add_numbers.schema = {
|
537
|
+
# "type": "object",
|
538
|
+
# "properties": {
|
539
|
+
# "a": {"type": "number", "description": "First number"},
|
540
|
+
# "b": {"type": "number", "description": "Second number"},
|
541
|
+
# },
|
542
|
+
# "required": ["a", "b"],
|
543
|
+
# }
|
544
|
+
|
545
|
+
# # Use the tool
|
546
|
+
# result, _ = await add_numbers(a=5, b=7)
|
547
|
+
# print(f"\nCustom tool result: {result}")
|
548
|
+
|
549
|
+
|
550
|
+
# if __name__ == "__main__":
|
551
|
+
# # Run both examples
|
552
|
+
# loop = asyncio.get_event_loop()
|
553
|
+
# loop.run_until_complete(main())
|
554
|
+
# loop.run_until_complete(create_custom_tool())
|
@@ -0,0 +1,57 @@
|
|
1
|
+
from typing import Any, Dict, List, Optional
|
2
|
+
from pydantic import BaseModel
|
3
|
+
|
4
|
+
|
5
|
+
class PropertySchema(BaseModel):
    # Schema for a single tool-function property, with fields named after
    # the JSON-schema keywords they mirror.
    type: str
    description: str | None = None
    enum: List[str] | None = None
    # Element schema for array-typed properties (presumably JSON-schema
    # ``items``), kept as a raw dict rather than a nested model.
    items: Dict[str, Any] | None = None
    # Nested property schemas when this property is itself an object.
    properties: Dict[str, "PropertySchema"] | None = None
    # Names of the nested properties that must be supplied.
    required: List[str] | None = None
|
12
|
+
|
13
|
+
|
14
|
+
class ParameterSchema(BaseModel):
    # Top-level ``parameters`` schema of a function definition: property
    # schemas keyed by parameter name, plus the names that are required.
    type: str
    properties: Dict[str, PropertySchema]
    required: List[str] | None = None
|
18
|
+
|
19
|
+
|
20
|
+
class FunctionDefinition(BaseModel):
    # A function the model may call, identified by name.
    name: str
    # Human-readable explanation of what the function does.
    description: str
    # Schema describing the function's arguments.
    parameters: ParameterSchema
|
24
|
+
|
25
|
+
|
26
|
+
class Tool(BaseModel):
    # One tool entry: a ``type`` tag paired with its function definition.
    type: str
    function: FunctionDefinition
|
29
|
+
|
30
|
+
|
31
|
+
class ToolSet(BaseModel):
    # Container holding an ordered list of tool entries.
    tools: list[Tool]
|
33
|
+
|
34
|
+
|
35
|
+
# model = ToolSet(
|
36
|
+
# tools=[
|
37
|
+
# Tool(
|
38
|
+
# type="function",
|
39
|
+
# function=FunctionDefinition(
|
40
|
+
# name="test",
|
41
|
+
# description="test",
|
42
|
+
# parameters=ParameterSchema(
|
43
|
+
# type="object",
|
44
|
+
# properties={
|
45
|
+
# "weather_tool": PropertySchema(
|
46
|
+
# type="string",
|
47
|
+
# description="Get the weather in a given location",
|
48
|
+
# )
|
49
|
+
# },
|
50
|
+
# required=["weather_tool"],
|
51
|
+
# ),
|
52
|
+
# ),
|
53
|
+
# ),
|
54
|
+
# ]
|
55
|
+
# )
|
56
|
+
|
57
|
+
# print(model.model_dump_json(indent=4))
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: swarms
|
3
|
-
Version: 7.6.0
|
3
|
+
Version: 7.6.2
|
4
4
|
Summary: Swarms - TGSC
|
5
5
|
Home-page: https://github.com/kyegomez/swarms
|
6
6
|
License: MIT
|
@@ -365,6 +365,84 @@ agent.run(
|
|
365
365
|
```
|
366
366
|
|
367
367
|
|
368
|
+
## Structured Outputs
|
369
|
+
|
370
|
+
1. Describe the function's parameters as a JSON schema (names, types, descriptions, required fields)
|
371
|
+
2. Wrap that schema in a function definition with a name and description
|
372
|
+
3. Put the function definition into a list of tool dictionaries (passed as `tools_list_dictionary`)
|
373
|
+
4. Initialize the agent
|
374
|
+
5. Run the agent
|
375
|
+
6. Print the output
|
376
|
+
7. Convert the string output to a dictionary with `str_to_dict`
|
377
|
+
|
378
|
+
```python
|
379
|
+
|
380
|
+
from dotenv import load_dotenv
|
381
|
+
|
382
|
+
from swarms import Agent
|
383
|
+
from swarms.prompts.finance_agent_sys_prompt import (
|
384
|
+
FINANCIAL_AGENT_SYS_PROMPT,
|
385
|
+
)
|
386
|
+
from swarms.utils.str_to_dict import str_to_dict
|
387
|
+
|
388
|
+
load_dotenv()
|
389
|
+
|
390
|
+
tools = [
|
391
|
+
{
|
392
|
+
"type": "function",
|
393
|
+
"function": {
|
394
|
+
"name": "get_stock_price",
|
395
|
+
"description": "Retrieve the current stock price and related information for a specified company.",
|
396
|
+
"parameters": {
|
397
|
+
"type": "object",
|
398
|
+
"properties": {
|
399
|
+
"ticker": {
|
400
|
+
"type": "string",
|
401
|
+
"description": "The stock ticker symbol of the company, e.g. AAPL for Apple Inc.",
|
402
|
+
},
|
403
|
+
"include_history": {
|
404
|
+
"type": "boolean",
|
405
|
+
"description": "Indicates whether to include historical price data along with the current price.",
|
406
|
+
},
|
407
|
+
"time": {
|
408
|
+
"type": "string",
|
409
|
+
"format": "date-time",
|
410
|
+
"description": "Optional parameter to specify the time for which the stock data is requested, in ISO 8601 format.",
|
411
|
+
},
|
412
|
+
},
|
413
|
+
"required": [
|
414
|
+
"ticker",
|
415
|
+
"include_history",
|
416
|
+
"time",
|
417
|
+
],
|
418
|
+
},
|
419
|
+
},
|
420
|
+
}
|
421
|
+
]
|
422
|
+
|
423
|
+
|
424
|
+
# Initialize the agent
|
425
|
+
agent = Agent(
|
426
|
+
agent_name="Financial-Analysis-Agent",
|
427
|
+
agent_description="Personal finance advisor agent",
|
428
|
+
system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
|
429
|
+
max_loops=1,
|
430
|
+
tools_list_dictionary=tools,
|
431
|
+
)
|
432
|
+
|
433
|
+
out = agent.run(
|
434
|
+
"What is the current stock price for Apple Inc. (AAPL)? Include historical price data.",
|
435
|
+
)
|
436
|
+
|
437
|
+
print(out)
|
438
|
+
|
439
|
+
print(type(out))
|
440
|
+
|
441
|
+
print(str_to_dict(out))
|
442
|
+
|
443
|
+
print(type(str_to_dict(out)))
|
444
|
+
```
|
445
|
+
|
368
446
|
-------
|
369
447
|
|
370
448
|
### Misc Agent Settings
|