fast-agent-mcp 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,152 @@
1
+ """
2
+ Interface definitions to prevent circular imports.
3
+ This module defines protocols (interfaces) that can be used to break circular dependencies.
4
+ """
5
+
6
+ from contextlib import asynccontextmanager
7
+ from typing import Any, AsyncGenerator, Callable, Generic, List, Optional, Protocol, Type, TypeVar
8
+
9
+ from mcp import ClientSession
10
+ from mcp.types import CreateMessageRequestParams
11
+ from pydantic import Field
12
+
13
+
14
class ServerRegistryProtocol(Protocol):
    """
    Structural stand-in for ServerRegistry, limited to what gen_client uses.

    Depending on this protocol instead of the concrete ServerRegistry class
    lets gen_client avoid importing the full implementation.
    """

    @property
    def connection_manager(self) -> "ConnectionManagerProtocol":
        """Return the connection manager exposed by the registry."""
        ...

    @asynccontextmanager
    async def initialize_server(
        self,
        server_name: str,
        client_session_factory=None,
        init_hook=None,
    ) -> AsyncGenerator[ClientSession, None]:
        """
        Initialize the named server and yield a client session for it.

        Args:
            server_name: Name of the server to initialize.
            client_session_factory: Optional session factory; its exact
                contract is defined by the implementing class.
            init_hook: Optional hook; its exact contract is defined by the
                implementing class.
        """
        ...
34
+
35
+
36
class ConnectionManagerProtocol(Protocol):
    """
    Structural stand-in for ConnectionManager.

    Only the operations required by dependants are declared: obtaining a
    server connection and disconnecting one or all servers.
    """

    async def get_server(
        self,
        server_name: str,
        client_session_factory=None,
    ):
        """
        Return the connection for ``server_name``.

        Args:
            server_name: Name of the server whose connection is wanted.
            client_session_factory: Optional session factory; its exact
                contract is defined by the implementing class.
        """

    async def disconnect_server(self, server_name: str) -> None:
        """Disconnect the single server identified by ``server_name``."""

    async def disconnect_all_servers(self) -> None:
        """Disconnect every server this manager knows about."""
56
+
57
+
58
# Generic type parameters shared by the LLM-facing protocols
MessageParamT = TypeVar("MessageParamT")
"""Type of a message passed into an LLM."""

MessageT = TypeVar("MessageT")
"""Type of a message produced by an LLM."""

ModelT = TypeVar("ModelT")
"""Type of a structured output produced by an LLM."""
67
+
68
+
69
class RequestParams(CreateMessageRequestParams):
    """
    Parameters to configure the AugmentedLLM 'generate' requests.
    """

    # The inherited 'messages' field is pinned to None and excluded from
    # serialized output; callers pass the message to 'generate' instead.
    messages: None = Field(default=None, exclude=True)
    """
    Ignored. 'messages' are removed from CreateMessageRequestParams
    to avoid confusion with the 'message' parameter on 'generate' method.
    """

    # Overrides the inherited field with a concrete default of 2048.
    maxTokens: int = 2048
    """The maximum number of tokens to sample, as requested by the server."""

    # None means "no explicit model chosen".
    model: Optional[str] = None
    """
    The model to use for the LLM generation.
    If specified, this overrides the 'modelPreferences' selection criteria.
    """

    use_history: bool = True
    """
    Include the message history in the generate request.
    """

    max_iterations: int = 10
    """
    The maximum number of iterations to run the LLM for.
    """

    parallel_tool_calls: bool = True
    """
    Whether to allow multiple tool calls per iteration.
    Also known as multi-step tool use.
    """
104
+
105
+
106
class AugmentedLLMProtocol(Protocol, Generic[MessageParamT, MessageT]):
    """Structural interface exposed by augmented LLM implementations."""

    async def generate(
        self,
        message: str | MessageParamT | List[MessageParamT],
        request_params: RequestParams | None = None,
    ) -> List[MessageT]:
        """
        Run an LLM generation, possibly over several iterations.

        Args:
            message: A plain string, one input message, or a list of them.
            request_params: Optional per-request configuration.

        Returns:
            The list of output messages.
        """
        ...

    async def generate_str(
        self,
        message: str | MessageParamT | List[MessageParamT],
        request_params: RequestParams | None = None,
    ) -> str:
        """
        Run an LLM generation and return the result as a string.

        Args:
            message: A plain string, one input message, or a list of them.
            request_params: Optional per-request configuration.
        """
        ...

    async def generate_structured(
        self,
        message: str | MessageParamT | List[MessageParamT],
        response_model: Type[ModelT],
        request_params: RequestParams | None = None,
    ) -> ModelT:
        """
        Run a structured LLM generation.

        Args:
            message: A plain string, one input message, or a list of them.
            response_model: The class the result should be returned as
                (a Pydantic model, per the original contract).
            request_params: Optional per-request configuration.

        Returns:
            An instance of ``response_model``.
        """
        ...
130
+
131
+
132
class ModelFactoryClassProtocol(Protocol):
    """
    Structural stand-in for the ModelFactory class as used by sampling.

    sampling.py depends on this protocol rather than on the concrete
    ModelFactory class.
    """

    @classmethod
    def create_factory(
        cls, model_string: str, request_params: Optional[RequestParams] = None
    ) -> Callable[..., AugmentedLLMProtocol[Any, Any]]:
        """
        Build a callable that constructs an LLM instance.

        Args:
            model_string: The model specification string.
            request_params: Optional parameters that configure LLM behavior.

        Returns:
            A factory callable which, when invoked, yields an LLM instance.
        """
        ...
@@ -7,7 +7,6 @@ from typing import Optional
7
7
 
8
8
  from mcp import ClientSession
9
9
  from mcp.shared.session import (
10
- RequestResponder,
11
10
  ReceiveResultT,
12
11
  ReceiveNotificationT,
13
12
  RequestId,
@@ -16,26 +15,43 @@ from mcp.shared.session import (
16
15
  SendResultT,
17
16
  )
18
17
  from mcp.types import (
19
- ClientResult,
20
- CreateMessageRequest,
21
- CreateMessageResult,
22
18
  ErrorData,
23
- JSONRPCNotification,
24
- JSONRPCRequest,
25
- ServerRequest,
26
- TextContent,
27
- ListRootsRequest,
28
19
  ListRootsResult,
29
20
  Root,
30
21
  )
22
+ from pydantic import AnyUrl
31
23
 
32
24
  from mcp_agent.config import MCPServerSettings
33
25
  from mcp_agent.context_dependent import ContextDependent
34
26
  from mcp_agent.logging.logger import get_logger
27
+ from mcp_agent.mcp.sampling import sample
35
28
 
36
29
  logger = get_logger(__name__)
37
30
 
38
31
 
32
async def list_roots(ctx: ClientSession) -> ListRootsResult:
    """
    List-roots callback handed to the MCP library.

    The configured roots are read from ``ctx.session.server_config.roots``.
    Every attribute on that path is optional: if any link is missing or
    falsy, an empty roots list is returned.

    NOTE(review): the parameter is annotated as ClientSession, yet the body
    reads the session from ``ctx.session`` - confirm the intended type.

    Args:
        ctx: Object from which the session (and its server config) is read.

    Returns:
        ListRootsResult with one Root per configured root; each uses
        ``server_uri_alias`` when set, otherwise ``uri``.
    """
    session = getattr(ctx, "session", None)
    server_config = getattr(session, "server_config", None)

    # Only look for roots on a truthy config, mirroring the guarded lookup.
    configured_roots = getattr(server_config, "roots", None) if server_config else None
    if not configured_roots:
        return ListRootsResult(roots=[])

    return ListRootsResult(
        roots=[
            Root(
                uri=AnyUrl(entry.server_uri_alias or entry.uri),
                name=entry.name,
            )
            for entry in configured_roots
        ]
    )
53
+
54
+
39
55
  class MCPAgentClientSession(ClientSession, ContextDependent):
40
56
  """
41
57
  MCP Agent framework acts as a client to the servers providing tools/resources/prompts for the agent workloads.
@@ -48,36 +64,11 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
48
64
  """
49
65
 
50
66
  def __init__(self, *args, **kwargs):
51
- super().__init__(*args, **kwargs)
67
+ super().__init__(
68
+ *args, **kwargs, list_roots_callback=list_roots, sampling_callback=sample
69
+ )
52
70
  self.server_config: Optional[MCPServerSettings] = None
53
71
 
54
- async def _received_request(
55
- self, responder: RequestResponder[ServerRequest, ClientResult]
56
- ) -> None:
57
- logger.debug("Received request:", data=responder.request.model_dump())
58
- request = responder.request.root
59
-
60
- if isinstance(request, CreateMessageRequest):
61
- return await self.handle_sampling_request(request, responder)
62
- elif isinstance(request, ListRootsRequest):
63
- # Handle list_roots request by returning configured roots
64
- if hasattr(self, "server_config") and self.server_config.roots:
65
- roots = [
66
- Root(
67
- uri=root.server_uri_alias or root.uri,
68
- name=root.name,
69
- )
70
- for root in self.server_config.roots
71
- ]
72
-
73
- await responder.respond(ListRootsResult(roots=roots))
74
- else:
75
- await responder.respond(ListRootsResult(roots=[]))
76
- return
77
-
78
- # Handle other requests as usual
79
- await super()._received_request(responder)
80
-
81
72
  async def send_request(
82
73
  self,
83
74
  request: SendRequestT,
@@ -89,7 +80,7 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
89
80
  logger.debug("send_request: response=", data=result.model_dump())
90
81
  return result
91
82
  except Exception as e:
92
- logger.error(f"send_request failed: {e}")
83
+ logger.error(f"send_request failed: {str(e)}")
93
84
  raise
94
85
 
95
86
  async def send_notification(self, notification: SendNotificationT) -> None:
@@ -132,111 +123,4 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
132
123
  )
133
124
  return await super().send_progress_notification(
134
125
  progress_token=progress_token, progress=progress, total=total
135
- )
136
-
137
- async def _receive_loop(self) -> None:
138
- async with (
139
- self._read_stream,
140
- self._write_stream,
141
- self._incoming_message_stream_writer,
142
- ):
143
- async for message in self._read_stream:
144
- if isinstance(message, Exception):
145
- await self._incoming_message_stream_writer.send(message)
146
- elif isinstance(message.root, JSONRPCRequest):
147
- validated_request = self._receive_request_type.model_validate(
148
- message.root.model_dump(
149
- by_alias=True, mode="json", exclude_none=True
150
- )
151
- )
152
- responder = RequestResponder(
153
- request_id=message.root.id,
154
- request_meta=validated_request.root.params.meta
155
- if validated_request.root.params
156
- else None,
157
- request=validated_request,
158
- session=self,
159
- )
160
-
161
- await self._received_request(responder)
162
- if not responder._responded:
163
- await self._incoming_message_stream_writer.send(responder)
164
- elif isinstance(message.root, JSONRPCNotification):
165
- notification = self._receive_notification_type.model_validate(
166
- message.root.model_dump(
167
- by_alias=True, mode="json", exclude_none=True
168
- )
169
- )
170
-
171
- await self._received_notification(notification)
172
- await self._incoming_message_stream_writer.send(notification)
173
- else: # Response or error
174
- stream = self._response_streams.pop(message.root.id, None)
175
- if stream:
176
- await stream.send(message.root)
177
- else:
178
- await self._incoming_message_stream_writer.send(
179
- RuntimeError(
180
- "Received response with an unknown "
181
- f"request ID: {message}"
182
- )
183
- )
184
-
185
- async def handle_sampling_request(
186
- self,
187
- request: CreateMessageRequest,
188
- responder: RequestResponder[ServerRequest, ClientResult],
189
- ):
190
- logger.info("Handling sampling request: %s", request)
191
- config = self.context.config
192
- session = self.context.upstream_session
193
- if session is None:
194
- # TODO: saqadri - consider whether we should be handling the sampling request here as a client
195
- logger.warning(
196
- "Error: No upstream client available for sampling requests. Request:",
197
- data=request,
198
- )
199
- try:
200
- from anthropic import AsyncAnthropic
201
-
202
- client = AsyncAnthropic(api_key=config.anthropic.api_key)
203
-
204
- params = request.params
205
- response = await client.messages.create(
206
- model="claude-3-sonnet-20240229",
207
- max_tokens=params.maxTokens,
208
- messages=[
209
- {
210
- "role": m.role,
211
- "content": m.content.text
212
- if hasattr(m.content, "text")
213
- else m.content.data,
214
- }
215
- for m in params.messages
216
- ],
217
- system=getattr(params, "systemPrompt", None),
218
- temperature=getattr(params, "temperature", 0.7),
219
- stop_sequences=getattr(params, "stopSequences", None),
220
- )
221
-
222
- await responder.respond(
223
- CreateMessageResult(
224
- model="claude-3-sonnet-20240229",
225
- role="assistant",
226
- content=TextContent(type="text", text=response.content[0].text),
227
- )
228
- )
229
- except Exception as e:
230
- logger.error(f"Error handling sampling request: {e}")
231
- await responder.respond(ErrorData(code=-32603, message=str(e)))
232
- else:
233
- try:
234
- # If a session is available, we'll pass-through the sampling request to the upstream client
235
- result = await session.send_request(
236
- request=ServerRequest(request), result_type=CreateMessageResult
237
- )
238
-
239
- # Pass the result from the upstream client back to the server. We just act as a pass-through client here.
240
- await responder.send_result(result)
241
- except Exception as e:
242
- await responder.send_error(code=-32603, message=str(e))
126
+ )
@@ -8,14 +8,15 @@ from typing import (
8
8
  Callable,
9
9
  TypeVar,
10
10
  )
11
- from mcp import GetPromptResult
12
- from pydantic import BaseModel, ConfigDict
11
+ from mcp import GetPromptResult, ReadResourceResult
12
+ from pydantic import AnyUrl, BaseModel, ConfigDict
13
13
  from mcp.client.session import ClientSession
14
14
  from mcp.server.lowlevel.server import Server
15
15
  from mcp.server.stdio import stdio_server
16
16
  from mcp.types import (
17
17
  CallToolResult,
18
18
  ListToolsResult,
19
+ TextContent,
19
20
  Tool,
20
21
  Prompt,
21
22
  )
@@ -210,6 +211,7 @@ class MCPAggregator(ContextDependent):
210
211
  "agent_name": self.agent_name,
211
212
  },
212
213
  )
214
+
213
215
  await self._persistent_connection_manager.get_server(
214
216
  server_name, client_session_factory=MCPAgentClientSession
215
217
  )
@@ -458,7 +460,10 @@ class MCPAggregator(ContextDependent):
458
460
 
459
461
  if server_name is None or local_tool_name is None:
460
462
  logger.error(f"Error: Tool '{name}' not found")
461
- return CallToolResult(isError=True, message=f"Tool '{name}' not found")
463
+ return CallToolResult(
464
+ isError=True,
465
+ content=[TextContent(type="text", text=f"Tool '{name}' not found")]
466
+ )
462
467
 
463
468
  logger.info(
464
469
  "Requesting tool call",
@@ -476,7 +481,10 @@ class MCPAggregator(ContextDependent):
476
481
  operation_name=local_tool_name,
477
482
  method_name="call_tool",
478
483
  method_args={"name": local_tool_name, "arguments": arguments},
479
- error_factory=lambda msg: CallToolResult(isError=True, message=msg),
484
+ error_factory=lambda msg: CallToolResult(
485
+ isError=True,
486
+ content=[TextContent(type="text", text=msg)]
487
+ ),
480
488
  )
481
489
 
482
490
  async def get_prompt(
@@ -821,6 +829,53 @@ class MCPAggregator(ContextDependent):
821
829
  logger.debug(f"Available prompts across servers: {results}")
822
830
  return results
823
831
 
832
async def get_resource(
    self, server_name: str, resource_uri: str
) -> ReadResourceResult:
    """
    Get a resource directly from an MCP server by URI.

    Servers are loaded first if the aggregator is not yet initialized.

    Args:
        server_name: Name of the MCP server to retrieve the resource from
        resource_uri: URI of the resource to retrieve

    Returns:
        ReadResourceResult object containing the resource content

    Raises:
        ValueError: If the server is not one of ``self.server_names``, the
            URI cannot be parsed, or the resource could not be retrieved
    """
    if not self.initialized:
        await self.load_servers()

    if server_name not in self.server_names:
        raise ValueError(f"Server '{server_name}' not found")

    # Resource reads are reported under the CALLING_TOOL progress action.
    logger.info(
        "Requesting resource",
        data={
            "progress_action": ProgressAction.CALLING_TOOL,
            "resource_uri": resource_uri,
            "server_name": server_name,
            "agent_name": self.agent_name,
        },
    )

    try:
        uri = AnyUrl(resource_uri)
    except Exception as e:
        # Chain the cause so the underlying validation error stays visible.
        raise ValueError(
            f"Invalid resource URI: {resource_uri}. Error: {e}"
        ) from e

    # Use the _execute_on_server method to call read_resource on the server
    result = await self._execute_on_server(
        server_name=server_name,
        operation_type="resource",
        operation_name=resource_uri,
        method_name="read_resource",
        method_args={"uri": uri},
        error_factory=lambda msg: ValueError(f"Failed to retrieve resource: {msg}"),
    )

    # NOTE(review): call_tool passes an error_factory that builds a return
    # value, so _execute_on_server is assumed to return the factory's product
    # rather than raise it. Raise here so a failure is never handed back to
    # the caller as if it were a ReadResourceResult. If the helper already
    # raises, this check is simply never hit.
    if isinstance(result, ValueError):
        raise result
    return result
878
+
824
879
 
825
880
  class MCPCompoundServer(Server):
826
881
  """
@@ -850,7 +905,10 @@ class MCPCompoundServer(Server):
850
905
  result = await self.aggregator.call_tool(name=name, arguments=arguments)
851
906
  return result.content
852
907
  except Exception as e:
853
- return CallToolResult(isError=True, message=f"Error calling tool: {e}")
908
+ return CallToolResult(
909
+ isError=True,
910
+ content=[TextContent(type="text", text=f"Error calling tool: {e}")]
911
+ )
854
912
 
855
913
  async def _get_prompt(
856
914
  self, name: str = None, arguments: dict[str, str] = None
@@ -163,7 +163,6 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None:
163
163
  async with transport_context as (read_stream, write_stream):
164
164
  # try:
165
165
  server_conn.create_session(read_stream, write_stream)
166
- # except FileNotFoundError as e:
167
166
 
168
167
  async with server_conn.session:
169
168
  await server_conn.initialize_session()
@@ -12,6 +12,8 @@ import logging
12
12
  import sys
13
13
  from pathlib import Path
14
14
  from typing import List, Dict, Optional, Callable, Awaitable, Literal, Any
15
+ from mcp.server.fastmcp.resources import FileResource
16
+ from pydantic import AnyUrl
15
17
 
16
18
  from mcp_agent.mcp import mime_utils, resource_utils
17
19
 
@@ -185,19 +187,19 @@ def create_prompt_handler(
185
187
 
186
188
 
187
189
  # Type for resource handler
188
- ResourceHandler = Callable[[], Awaitable[str]]
190
+ ResourceHandler = Callable[[], Awaitable[str | bytes]]
189
191
 
190
192
 
191
193
  def create_resource_handler(resource_path: Path, mime_type: str) -> ResourceHandler:
192
194
  """Create a resource handler function for the given resource"""
193
195
 
194
- async def get_resource() -> str:
196
+ async def get_resource() -> str | bytes:
195
197
  is_binary = mime_utils.is_binary_content(mime_type)
196
198
 
197
199
  if is_binary:
198
200
  # For binary files, read in binary mode and base64 encode
199
201
  with open(resource_path, "rb") as f:
200
- return base64.b64encode(f.read()).decode("utf-8")
202
+ return f.read()
201
203
  else:
202
204
  # For text files, read as utf-8 text
203
205
  with open(resource_path, "r", encoding="utf-8") as f:
@@ -284,15 +286,14 @@ def register_prompt(file_path: Path):
284
286
  exposed_resources[resource_id] = resource_file
285
287
  mime_type = mime_utils.guess_mime_type(str(resource_file))
286
288
 
287
- # Register with the correct resource ID directly with MCP
288
- resource_handler = create_resource_handler(
289
- resource_file, mime_type
289
+ mcp.add_resource(
290
+ FileResource(
291
+ uri=AnyUrl(resource_id),
292
+ path=resource_file,
293
+ mime_type=mime_type,
294
+ is_binary=mime_utils.is_binary_content(mime_type),
295
+ )
290
296
  )
291
- mcp.resource(
292
- resource_id,
293
- description=f"Resource from {file_path.name}",
294
- mime_type=mime_type,
295
- )(resource_handler)
296
297
 
297
298
  logger.info(
298
299
  f"Registered resource: {resource_id} ({resource_file})"
@@ -25,44 +25,6 @@ def find_resource_file(resource_path: str, prompt_files: List[Path]) -> Optional
25
25
  return None
26
26
 
27
27
 
28
- # TODO -- decide how to deal with this. Both Anthropic and OpenAI allow sending URLs in
29
- # input message
30
- # TODO -- used?
31
- # async def fetch_remote_resource(
32
- # url: str, timeout: int = HTTP_TIMEOUT
33
- # ) -> ResourceContent:
34
- # """
35
- # Fetch a remote resource from a URL
36
-
37
- # Returns:
38
- # Tuple of (content, mime_type, is_binary)
39
- # - content: Text content or base64-encoded binary content
40
- # - mime_type: The MIME type of the resource
41
- # - is_binary: Whether the content is binary (and base64-encoded)
42
- # """
43
-
44
- # async with httpx.AsyncClient(timeout=timeout) as client:
45
- # response = await client.get(url)
46
- # response.raise_for_status()
47
-
48
- # # Get the content type or guess from URL
49
- # mime_type = response.headers.get("content-type", "").split(";")[0]
50
- # if not mime_type:
51
- # mime_type = mime_utils.guess_mime_type(url)
52
-
53
- # # Check if this is binary content
54
- # is_binary = mime_utils.is_binary_content(mime_type)
55
-
56
- # if is_binary:
57
- # # For binary responses, get the binary content and base64 encode it
58
- # content = base64.b64encode(response.content).decode("utf-8")
59
- # else:
60
- # # For text responses, just get the text
61
- # content = response.text
62
-
63
- # return content, mime_type, is_binary
64
-
65
-
66
28
  def load_resource_content(
67
29
  resource_path: str, prompt_files: List[Path]
68
30
  ) -> ResourceContent:
@@ -109,6 +71,36 @@ def create_resource_uri(path: str) -> str:
109
71
  return f"resource://fast-agent/{Path(path).name}"
110
72
 
111
73
 
74
+ # Add this to your resource_utils.py module
75
+
76
+
77
def create_resource_reference(uri: str, mime_type: str) -> "EmbeddedResource":
    """
    Build an EmbeddedResource that points at a resource instead of carrying it.

    The returned object wraps a TextResourceContents whose ``text`` is the
    empty string, so only the URI and MIME type are populated. The intent is
    presumably that the consumer fetches the content separately by URI;
    nothing in this function performs or enforces that.

    Args:
        uri: URI for the resource
        mime_type: MIME type of the resource

    Returns:
        An EmbeddedResource of type "resource" with empty text content
    """
    from mcp.types import EmbeddedResource, TextResourceContents

    return EmbeddedResource(
        type="resource",
        resource=TextResourceContents(
            uri=uri,
            mimeType=mime_type,
            text="",  # deliberately empty: this is a reference, not the content
        ),
    )
102
+
103
+
112
104
  def create_embedded_resource(
113
105
  resource_path: str, content: str, mime_type: str, is_binary: bool = False
114
106
  ) -> EmbeddedResource:
@@ -149,6 +141,34 @@ def create_image_content(data: str, mime_type: str) -> ImageContent:
149
141
  )
150
142
 
151
143
 
144
def create_blob_resource(
    resource_path: str, content: str, mime_type: str
) -> EmbeddedResource:
    """
    Wrap binary data in an EmbeddedResource.

    Args:
        resource_path: Value stored as the resource ``uri``.
        content: The blob payload; expected to be base64 encoded already,
            as no encoding is applied here.
        mime_type: MIME type recorded on the resource.

    Returns:
        An EmbeddedResource of type "resource" holding BlobResourceContents.
    """
    blob_contents = BlobResourceContents(
        uri=resource_path,
        mimeType=mime_type,
        blob=content,
    )
    return EmbeddedResource(type="resource", resource=blob_contents)
156
+
157
+
158
def create_text_resource(
    resource_path: str, content: str, mime_type: str
) -> EmbeddedResource:
    """
    Wrap text data in an EmbeddedResource.

    Args:
        resource_path: Value stored as the resource ``uri``.
        content: The text stored on the resource, passed through unchanged.
        mime_type: MIME type recorded on the resource.

    Returns:
        An EmbeddedResource of type "resource" holding TextResourceContents.
    """
    text_contents = TextResourceContents(
        uri=resource_path,
        mimeType=mime_type,
        text=content,
    )
    return EmbeddedResource(type="resource", resource=text_contents)
170
+
171
+
152
172
  def normalize_uri(uri_or_filename: str) -> str:
153
173
  """
154
174
  Normalize a URI or filename to ensure it's a valid URI.