fastmcp 1.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. fastmcp/__init__.py +15 -4
  2. fastmcp/cli/__init__.py +0 -1
  3. fastmcp/cli/claude.py +13 -11
  4. fastmcp/cli/cli.py +59 -39
  5. fastmcp/client/__init__.py +25 -0
  6. fastmcp/client/base.py +1 -0
  7. fastmcp/client/client.py +226 -0
  8. fastmcp/client/roots.py +75 -0
  9. fastmcp/client/sampling.py +50 -0
  10. fastmcp/client/transports.py +411 -0
  11. fastmcp/prompts/__init__.py +2 -2
  12. fastmcp/prompts/{base.py → prompt.py} +47 -26
  13. fastmcp/prompts/prompt_manager.py +69 -15
  14. fastmcp/resources/__init__.py +6 -6
  15. fastmcp/resources/{base.py → resource.py} +21 -2
  16. fastmcp/resources/resource_manager.py +116 -17
  17. fastmcp/resources/{templates.py → template.py} +36 -11
  18. fastmcp/resources/types.py +18 -13
  19. fastmcp/server/__init__.py +5 -0
  20. fastmcp/server/context.py +222 -0
  21. fastmcp/server/openapi.py +637 -0
  22. fastmcp/server/proxy.py +223 -0
  23. fastmcp/{server.py → server/server.py} +323 -267
  24. fastmcp/settings.py +81 -0
  25. fastmcp/tools/__init__.py +1 -1
  26. fastmcp/tools/{base.py → tool.py} +47 -18
  27. fastmcp/tools/tool_manager.py +57 -16
  28. fastmcp/utilities/func_metadata.py +33 -19
  29. fastmcp/utilities/openapi.py +797 -0
  30. fastmcp/utilities/types.py +15 -4
  31. fastmcp-2.1.0.dist-info/METADATA +770 -0
  32. fastmcp-2.1.0.dist-info/RECORD +39 -0
  33. fastmcp-2.1.0.dist-info/licenses/LICENSE +201 -0
  34. fastmcp/prompts/manager.py +0 -50
  35. fastmcp-1.0.dist-info/METADATA +0 -604
  36. fastmcp-1.0.dist-info/RECORD +0 -28
  37. fastmcp-1.0.dist-info/licenses/LICENSE +0 -21
  38. {fastmcp-1.0.dist-info → fastmcp-2.1.0.dist-info}/WHEEL +0 -0
  39. {fastmcp-1.0.dist-info → fastmcp-2.1.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,50 @@
1
+ import inspect
2
+ from collections.abc import Awaitable, Callable
3
+ from typing import TypeAlias
4
+
5
+ import mcp.types
6
+ from mcp import ClientSession, CreateMessageResult
7
+ from mcp.client.session import SamplingFnT
8
+ from mcp.shared.context import LifespanContextT, RequestContext
9
+ from mcp.types import CreateMessageRequestParams as SamplingParams
10
+ from mcp.types import SamplingMessage
11
+
12
+
13
class MessageResult(CreateMessageResult):
    """A ``CreateMessageResult`` whose ``role`` and ``model`` have defaults.

    Only ``content`` has to be supplied; this is what lets
    ``create_sampling_callback`` wrap a bare string reply in a full result.
    """

    # Defaults to "assistant" when the caller does not set a role.
    role: mcp.types.Role = "assistant"
    # The payload itself: either text or image content (no default).
    content: mcp.types.TextContent | mcp.types.ImageContent
    # Placeholder model name used when the caller does not set one.
    model: str = "client-model"
17
+
18
+
19
# What a handler may produce: plain text or a complete result object.
_SamplingOutput: TypeAlias = str | CreateMessageResult

# A sampling handler is called with (messages, params, context) and may be
# either synchronous or asynchronous.
SamplingHandler: TypeAlias = Callable[
    [
        list[SamplingMessage],
        SamplingParams,
        RequestContext[ClientSession, LifespanContextT],
    ],
    _SamplingOutput | Awaitable[_SamplingOutput],
]
27
+
28
+
29
def create_sampling_callback(sampling_handler: SamplingHandler) -> SamplingFnT:
    """Adapt a user sampling handler to the ``(context, params)`` callback shape.

    The returned coroutine function calls ``sampling_handler`` with
    ``(params.messages, params, context)``, awaits the outcome if it is
    awaitable, and wraps a plain ``str`` outcome in a ``MessageResult`` holding
    text content. Any exception raised along the way is converted into an
    ``ErrorData`` with code ``INTERNAL_ERROR`` instead of propagating.

    Args:
        sampling_handler: Sync or async callable producing a string or a
            ``CreateMessageResult``.

    Returns:
        An async callback that yields either a result or an ``ErrorData``.
    """

    async def _sampling_handler(
        context: RequestContext[ClientSession, LifespanContextT],
        params: SamplingParams,
    ) -> CreateMessageResult | mcp.types.ErrorData:
        try:
            outcome = sampling_handler(params.messages, params, context)
            if inspect.isawaitable(outcome):
                outcome = await outcome

            # Non-string outcomes are passed through untouched.
            if not isinstance(outcome, str):
                return outcome

            text_content = mcp.types.TextContent(type="text", text=outcome)
            return MessageResult(content=text_content)
        except Exception as exc:
            # Boundary handler: report the failure to the peer as error data.
            return mcp.types.ErrorData(
                code=mcp.types.INTERNAL_ERROR,
                message=str(exc),
            )

    return _sampling_handler
@@ -0,0 +1,411 @@
1
+ import abc
2
+ import contextlib
3
+ import datetime
4
+ import os
5
+ from collections.abc import AsyncIterator
6
+ from pathlib import Path
7
+ from typing import (
8
+ TypedDict,
9
+ )
10
+
11
+ from mcp import ClientSession, StdioServerParameters
12
+ from mcp.client.session import (
13
+ ListRootsFnT,
14
+ LoggingFnT,
15
+ MessageHandlerFnT,
16
+ SamplingFnT,
17
+ )
18
+ from mcp.client.sse import sse_client
19
+ from mcp.client.stdio import stdio_client
20
+ from mcp.client.websocket import websocket_client
21
+ from mcp.shared.memory import create_connected_server_and_client_session
22
+ from pydantic import AnyUrl
23
+ from typing_extensions import Unpack
24
+
25
+ from fastmcp.server import FastMCP as FastMCPServer
26
+
27
+
28
class SessionKwargs(TypedDict, total=False):
    """Keyword arguments for the MCP ClientSession constructor.

    Declared with ``total=False``, so every key is optional; transports
    forward whatever subset they receive straight to ``ClientSession``.
    """

    # Optional callbacks; each may also be passed explicitly as None.
    sampling_callback: SamplingFnT | None
    list_roots_callback: ListRootsFnT | None
    logging_callback: LoggingFnT | None
    message_handler: MessageHandlerFnT | None
    # Read timeout expressed as a timedelta, or None.
    read_timeout_seconds: datetime.timedelta | None
36
+
37
+
38
class ClientTransport(abc.ABC):
    """
    Abstract base for the ways a client can reach an MCP server.

    Concrete transports implement ``connect_session`` to open a connection
    and expose a ``ClientSession`` for the duration of an async context.
    """

    @abc.abstractmethod
    @contextlib.asynccontextmanager
    async def connect_session(
        self, **session_kwargs: Unpack[SessionKwargs]
    ) -> AsyncIterator[ClientSession]:
        """
        Open a connection and yield an active, initialized ClientSession.

        The session is only guaranteed to be usable inside the async context
        manager; setup and teardown of the connection happen within it.

        Args:
            **session_kwargs: Keyword arguments forwarded to the ClientSession
                constructor (e.g., callbacks, timeouts).

        Yields:
            An initialized mcp.ClientSession instance.
        """
        raise NotImplementedError
        # Unreachable, but the presence of ``yield`` makes this function an
        # async generator, which is the shape asynccontextmanager wraps.
        yield None  # type: ignore

    def __repr__(self) -> str:
        # Generic fallback representation for subclasses: just the class name.
        return f"<{type(self).__name__}>"
71
+
72
+
73
class WSTransport(ClientTransport):
    """Transport implementation that connects to an MCP server via WebSockets."""

    def __init__(self, url: str | AnyUrl):
        """
        Initialize a WebSocket transport.

        Args:
            url: The WebSocket endpoint, as a string or ``AnyUrl``.

        Raises:
            ValueError: If ``url`` is not a string (after ``AnyUrl``
                conversion) or does not use the ``ws://`` / ``wss://`` scheme.
        """
        if isinstance(url, AnyUrl):
            url = str(url)
        # Require a real WebSocket scheme; a bare "ws" prefix check would also
        # accept arbitrary strings such as host names that start with "ws".
        if not isinstance(url, str) or not url.startswith(("ws://", "wss://")):
            raise ValueError("Invalid WebSocket URL provided.")
        self.url = url

    @contextlib.asynccontextmanager
    async def connect_session(
        self, **session_kwargs: Unpack[SessionKwargs]
    ) -> AsyncIterator[ClientSession]:
        """
        Connect over WebSocket and yield an initialized ClientSession.

        Args:
            **session_kwargs: Keyword arguments forwarded to ClientSession.

        Yields:
            The session, after ``initialize()`` has been awaited.
        """
        async with websocket_client(self.url) as (read_stream, write_stream):
            async with ClientSession(
                read_stream, write_stream, **session_kwargs
            ) as session:
                await session.initialize()  # Initialize after session creation
                yield session

    def __repr__(self) -> str:
        return f"<WebSocket(url='{self.url}')>"
97
+
98
+
99
class SSETransport(ClientTransport):
    """Transport implementation that connects to an MCP server via Server-Sent Events."""

    def __init__(self, url: str | AnyUrl, headers: dict[str, str] | None = None):
        """
        Initialize an SSE transport.

        Args:
            url: The HTTP(S) endpoint, as a string or ``AnyUrl``.
            headers: Optional headers passed to the SSE client; defaults to an
                empty dict.

        Raises:
            ValueError: If ``url`` is not a string (after ``AnyUrl``
                conversion) or does not use the ``http://`` / ``https://``
                scheme.
        """
        if isinstance(url, AnyUrl):
            url = str(url)
        # Require a real HTTP(S) scheme; a bare "http" prefix check would also
        # accept strings like "httpfoo" that are not URLs at all.
        if not isinstance(url, str) or not url.startswith(("http://", "https://")):
            raise ValueError("Invalid HTTP/S URL provided for SSE.")
        self.url = url
        self.headers = headers or {}

    @contextlib.asynccontextmanager
    async def connect_session(
        self, **session_kwargs: Unpack[SessionKwargs]
    ) -> AsyncIterator[ClientSession]:
        """
        Connect over SSE and yield an initialized ClientSession.

        Args:
            **session_kwargs: Keyword arguments forwarded to ClientSession.

        Yields:
            The session, after ``initialize()`` has been awaited.
        """
        async with sse_client(self.url, headers=self.headers) as (
            read_stream,
            write_stream,
        ):
            async with ClientSession(
                read_stream, write_stream, **session_kwargs
            ) as session:
                await session.initialize()
                yield session

    def __repr__(self) -> str:
        return f"<SSE(url='{self.url}')>"
124
+
125
+
126
class StdioTransport(ClientTransport):
    """
    Generic transport that talks to an MCP server subprocess over stdio.

    Serves as the base for command-specific transports (Python, Node, Uvx,
    etc.), which only need to work out the command line.
    """

    def __init__(
        self,
        command: str,
        args: list[str],
        env: dict[str, str] | None = None,
        cwd: str | None = None,
    ):
        """
        Initialize a Stdio transport.

        Args:
            command: The command to run (e.g., "python", "node", "uvx")
            args: The arguments to pass to the command
            env: Environment variables to set for the subprocess
            cwd: Current working directory for the subprocess
        """
        self.command, self.args = command, args
        self.env, self.cwd = env, cwd

    @contextlib.asynccontextmanager
    async def connect_session(
        self, **session_kwargs: Unpack[SessionKwargs]
    ) -> AsyncIterator[ClientSession]:
        """
        Start the stdio client for the configured command and yield a session.

        Args:
            **session_kwargs: Keyword arguments forwarded to ClientSession.

        Yields:
            The session, after ``initialize()`` has been awaited.
        """
        params = StdioServerParameters(
            command=self.command,
            args=self.args,
            env=self.env,
            cwd=self.cwd,
        )
        async with stdio_client(params) as (read_stream, write_stream):
            async with ClientSession(
                read_stream, write_stream, **session_kwargs
            ) as session:
                await session.initialize()
                yield session

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"<{cls_name}(command='{self.command}', args={self.args})>"
174
+
175
+
176
class PythonStdioTransport(StdioTransport):
    """Stdio transport that launches a Python script."""

    def __init__(
        self,
        script_path: str | Path,
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
        cwd: str | None = None,
        python_cmd: str = "python",
    ):
        """
        Initialize a Python transport.

        Args:
            script_path: Path to the Python script to run
            args: Additional arguments to pass to the script
            env: Environment variables to set for the subprocess
            cwd: Current working directory for the subprocess
            python_cmd: Python command to use (default: "python")

        Raises:
            FileNotFoundError: If the resolved path is not an existing file.
            ValueError: If the resolved path does not end with ".py".
        """
        resolved = Path(script_path).resolve()
        if not resolved.is_file():
            raise FileNotFoundError(f"Script not found: {resolved}")
        if not str(resolved).endswith(".py"):
            raise ValueError(f"Not a Python script: {resolved}")

        # The script comes first on the command line, then any extra arguments.
        super().__init__(
            command=python_cmd,
            args=[str(resolved), *(args or [])],
            env=env,
            cwd=cwd,
        )
        self.script_path = resolved
209
+
210
+
211
class NodeStdioTransport(StdioTransport):
    """Stdio transport that launches a Node.js script."""

    def __init__(
        self,
        script_path: str | Path,
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
        cwd: str | None = None,
        node_cmd: str = "node",
    ):
        """
        Initialize a Node transport.

        Args:
            script_path: Path to the Node.js script to run
            args: Additional arguments to pass to the script
            env: Environment variables to set for the subprocess
            cwd: Current working directory for the subprocess
            node_cmd: Node.js command to use (default: "node")

        Raises:
            FileNotFoundError: If the resolved path is not an existing file.
            ValueError: If the resolved path does not end with ".js".
        """
        resolved = Path(script_path).resolve()
        if not resolved.is_file():
            raise FileNotFoundError(f"Script not found: {resolved}")
        if not str(resolved).endswith(".js"):
            raise ValueError(f"Not a JavaScript script: {resolved}")

        # The script comes first on the command line, then any extra arguments.
        super().__init__(
            command=node_cmd,
            args=[str(resolved), *(args or [])],
            env=env,
            cwd=cwd,
        )
        self.script_path = resolved
244
+
245
+
246
class UvxStdioTransport(StdioTransport):
    """Transport for running commands via the uvx tool."""

    def __init__(
        self,
        tool_name: str,
        tool_args: list[str] | None = None,
        project_directory: str | None = None,
        python_version: str | None = None,
        with_packages: list[str] | None = None,
        from_package: str | None = None,
        env_vars: dict[str, str] | None = None,
    ):
        """
        Initialize a Uvx transport.

        Args:
            tool_name: Name of the tool to run via uvx
            tool_args: Arguments to pass to the tool
            project_directory: Project directory (for package resolution);
                used as the subprocess working directory
            python_version: Python version to use
            with_packages: Additional packages to include
            from_package: Package to install the tool from
            env_vars: Additional environment variables, layered on top of a
                copy of the current process environment

        Raises:
            NotADirectoryError: If ``project_directory`` is given but is not
                an existing directory.
        """
        # The value is used as the subprocess cwd, so it must be a directory;
        # a plain existence check would let a regular file slip through and
        # fail later when the process is spawned.
        if project_directory and not Path(project_directory).is_dir():
            raise NotADirectoryError(
                f"Project directory not found: {project_directory}"
            )

        # uvx options go first, then the tool name and the tool's own args.
        uvx_args: list[str] = []
        if python_version:
            uvx_args.extend(["--python", python_version])
        if from_package:
            uvx_args.extend(["--from", from_package])
        for pkg in with_packages or []:
            uvx_args.extend(["--with", pkg])
        uvx_args.append(tool_name)
        if tool_args:
            uvx_args.extend(tool_args)

        # Only build an explicit environment when extra variables were given;
        # otherwise pass None through to the base transport.
        env = None
        if env_vars:
            env = os.environ.copy()
            env.update(env_vars)

        super().__init__(command="uvx", args=uvx_args, env=env, cwd=project_directory)
        self.tool_name = tool_name
299
+
300
+
301
class NpxStdioTransport(StdioTransport):
    """Transport for running commands via the npx tool."""

    def __init__(
        self,
        package: str,
        args: list[str] | None = None,
        project_directory: str | None = None,
        env_vars: dict[str, str] | None = None,
        use_package_lock: bool = True,
    ):
        """
        Initialize an Npx transport.

        Args:
            package: Name of the npm package to run
            args: Arguments to pass to the package command
            project_directory: Project directory with package.json; used as
                the subprocess working directory
            env_vars: Additional environment variables, layered on top of a
                copy of the current process environment
            use_package_lock: Whether to use package-lock.json (--prefer-offline)

        Raises:
            NotADirectoryError: If ``project_directory`` is given but is not
                an existing directory.
        """
        # The value is used as the subprocess cwd, so it must be a directory;
        # a plain existence check would let a regular file slip through and
        # fail later when the process is spawned.
        if project_directory and not Path(project_directory).is_dir():
            raise NotADirectoryError(
                f"Project directory not found: {project_directory}"
            )

        # npx options go first, then the package name and its own args.
        npx_args: list[str] = []
        if use_package_lock:
            npx_args.append("--prefer-offline")
        npx_args.append(package)
        if args:
            npx_args.extend(args)

        # Only build an explicit environment when extra variables were given;
        # otherwise pass None through to the base transport.
        env = None
        if env_vars:
            env = os.environ.copy()
            env.update(env_vars)

        super().__init__(command="npx", args=npx_args, env=env, cwd=project_directory)
        self.package = package
346
+
347
+
348
class FastMCPTransport(ClientTransport):
    """
    In-memory transport that connects directly to a server object.

    Handy for tests, or whenever the client and the server live in the
    same process.
    """

    def __init__(self, mcp: FastMCPServer):
        self._fastmcp = mcp  # Can be FastMCP or MCPServer

    @contextlib.asynccontextmanager
    async def connect_session(
        self, **session_kwargs: Unpack[SessionKwargs]
    ) -> AsyncIterator[ClientSession]:
        """
        Yield a client session wired to the wrapped server in memory.

        Args:
            **session_kwargs: Keyword arguments forwarded to the session helper.

        Yields:
            The session produced by the in-memory helper.
        """
        # The helper owns the session lifecycle, so nothing else is set up
        # or torn down here.
        low_level_server = self._fastmcp._mcp_server
        async with create_connected_server_and_client_session(
            server=low_level_server,
            **session_kwargs,
        ) as session:
            yield session

    def __repr__(self) -> str:
        server_name = self._fastmcp.name
        return f"<FastMCP(server='{server_name}')>"
372
+
373
+
374
def infer_transport(
    transport: ClientTransport | FastMCPServer | AnyUrl | Path | str,
) -> ClientTransport:
    """
    Work out which ClientTransport to use for the given argument.

    Checks are applied in order:

    1. An existing ``ClientTransport`` is returned unchanged.
    2. A FastMCP server is wrapped in a ``FastMCPTransport``.
    3. A ``Path``/``str`` naming something that exists on disk becomes a
       Python (".py") or Node (".js") stdio transport.
    4. An ``AnyUrl``/``str`` starting with "http" becomes an ``SSETransport``.
    5. An ``AnyUrl``/``str`` starting with "ws" becomes a ``WSTransport``.

    Raises:
        ValueError: If an existing path has an unsupported extension, or if
            no rule matches.
    """
    # Already a transport: nothing to infer.
    if isinstance(transport, ClientTransport):
        return transport

    # A server object: connect to it in memory.
    if isinstance(transport, FastMCPServer):
        return FastMCPTransport(mcp=transport)

    # Something that exists on disk: choose the runner by file extension.
    if isinstance(transport, Path | str) and Path(transport).exists():
        script_name = str(transport)
        if script_name.endswith(".py"):
            return PythonStdioTransport(script_path=transport)
        if script_name.endswith(".js"):
            return NodeStdioTransport(script_path=transport)
        raise ValueError(f"Unsupported script type: {transport}")

    # A URL: choose the transport by scheme prefix.
    if isinstance(transport, AnyUrl | str):
        location = str(transport)
        if location.startswith("http"):
            return SSETransport(url=transport)
        if location.startswith("ws"):
            return WSTransport(url=transport)

    # Nothing matched.
    raise ValueError(f"Could not infer a valid transport from: {transport}")
@@ -1,4 +1,4 @@
1
- from .base import Prompt
2
- from .manager import PromptManager
1
+ from .prompt import Prompt
2
+ from .prompt_manager import PromptManager
3
3
 
4
4
  __all__ = ["Prompt", "PromptManager"]
@@ -1,12 +1,16 @@
1
1
  """Base classes for FastMCP prompts."""
2
2
 
3
- import json
4
- from typing import Any, Callable, Dict, Literal, Optional, Sequence, Awaitable
5
3
  import inspect
4
+ import json
5
+ from collections.abc import Awaitable, Callable, Sequence
6
+ from typing import Annotated, Any, Literal
6
7
 
7
- from pydantic import BaseModel, Field, TypeAdapter, validate_call
8
- from mcp.types import TextContent, ImageContent, EmbeddedResource
9
8
  import pydantic_core
9
+ from mcp.types import EmbeddedResource, ImageContent, TextContent
10
+ from pydantic import BaseModel, BeforeValidator, Field, TypeAdapter, validate_call
11
+ from typing_extensions import Self
12
+
13
+ from fastmcp.utilities.types import _convert_set_defaults
10
14
 
11
15
  CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource
12
16
 
@@ -17,7 +21,7 @@ class Message(BaseModel):
17
21
  role: Literal["user", "assistant"]
18
22
  content: CONTENT_TYPES
19
23
 
20
- def __init__(self, content: str | CONTENT_TYPES, **kwargs):
24
+ def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
21
25
  if isinstance(content, str):
22
26
  content = TextContent(type="text", text=content)
23
27
  super().__init__(content=content, **kwargs)
@@ -26,22 +30,24 @@ class Message(BaseModel):
26
30
  class UserMessage(Message):
27
31
  """A message from the user."""
28
32
 
29
- role: Literal["user"] = "user"
33
+ role: Literal["user", "assistant"] = "user"
30
34
 
31
- def __init__(self, content: str | CONTENT_TYPES, **kwargs):
35
+ def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
32
36
  super().__init__(content=content, **kwargs)
33
37
 
34
38
 
35
39
  class AssistantMessage(Message):
36
40
  """A message from the assistant."""
37
41
 
38
- role: Literal["assistant"] = "assistant"
42
+ role: Literal["user", "assistant"] = "assistant"
39
43
 
40
- def __init__(self, content: str | CONTENT_TYPES, **kwargs):
44
+ def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
41
45
  super().__init__(content=content, **kwargs)
42
46
 
43
47
 
44
- message_validator = TypeAdapter(UserMessage | AssistantMessage)
48
+ message_validator = TypeAdapter[UserMessage | AssistantMessage](
49
+ UserMessage | AssistantMessage
50
+ )
45
51
 
46
52
  SyncPromptResult = (
47
53
  str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
@@ -68,17 +74,21 @@ class Prompt(BaseModel):
68
74
  description: str | None = Field(
69
75
  None, description="Description of what the prompt does"
70
76
  )
77
+ tags: Annotated[set[str], BeforeValidator(_convert_set_defaults)] = Field(
78
+ default_factory=set, description="Tags for the prompt"
79
+ )
71
80
  arguments: list[PromptArgument] | None = Field(
72
81
  None, description="Arguments that can be passed to the prompt"
73
82
  )
74
- fn: Callable = Field(exclude=True)
83
+ fn: Callable[..., PromptResult | Awaitable[PromptResult]]
75
84
 
76
85
  @classmethod
77
86
  def from_function(
78
87
  cls,
79
- fn: Callable[..., PromptResult],
80
- name: Optional[str] = None,
81
- description: Optional[str] = None,
88
+ fn: Callable[..., PromptResult | Awaitable[PromptResult]],
89
+ name: str | None = None,
90
+ description: str | None = None,
91
+ tags: set[str] | None = None,
82
92
  ) -> "Prompt":
83
93
  """Create a Prompt from a function.
84
94
 
@@ -97,7 +107,7 @@ class Prompt(BaseModel):
97
107
  parameters = TypeAdapter(fn).json_schema()
98
108
 
99
109
  # Convert parameters to PromptArguments
100
- arguments = []
110
+ arguments: list[PromptArgument] = []
101
111
  if "properties" in parameters:
102
112
  for param_name, param in parameters["properties"].items():
103
113
  required = param_name in parameters.get("required", [])
@@ -117,9 +127,10 @@ class Prompt(BaseModel):
117
127
  description=description or fn.__doc__ or "",
118
128
  arguments=arguments,
119
129
  fn=fn,
130
+ tags=tags or set(),
120
131
  )
121
132
 
122
- async def render(self, arguments: Optional[Dict[str, Any]] = None) -> list[Message]:
133
+ async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]:
123
134
  """Render the prompt with arguments."""
124
135
  # Validate required arguments
125
136
  if self.arguments:
@@ -136,25 +147,23 @@ class Prompt(BaseModel):
136
147
  result = await result
137
148
 
138
149
  # Validate messages
139
- if not isinstance(result, (list, tuple)):
150
+ if not isinstance(result, list | tuple):
140
151
  result = [result]
141
152
 
142
153
  # Convert result to messages
143
- messages = []
144
- for msg in result:
154
+ messages: list[Message] = []
155
+ for msg in result: # type: ignore[reportUnknownVariableType]
145
156
  try:
146
157
  if isinstance(msg, Message):
147
158
  messages.append(msg)
148
159
  elif isinstance(msg, dict):
149
- msg = message_validator.validate_python(msg)
150
- messages.append(msg)
160
+ messages.append(message_validator.validate_python(msg))
151
161
  elif isinstance(msg, str):
152
- messages.append(
153
- UserMessage(content=TextContent(type="text", text=msg))
154
- )
162
+ content = TextContent(type="text", text=msg)
163
+ messages.append(UserMessage(content=content))
155
164
  else:
156
- msg = json.dumps(pydantic_core.to_jsonable_python(msg))
157
- messages.append(Message(role="user", content=msg))
165
+ content = json.dumps(pydantic_core.to_jsonable_python(msg))
166
+ messages.append(Message(role="user", content=content))
158
167
  except Exception:
159
168
  raise ValueError(
160
169
  f"Could not convert prompt result to message: {msg}"
@@ -163,3 +172,15 @@ class Prompt(BaseModel):
163
172
  return messages
164
173
  except Exception as e:
165
174
  raise ValueError(f"Error rendering prompt {self.name}: {e}")
175
+
176
def copy(self, updates: dict[str, Any] | None = None) -> Self:
    """Return a new instance built from this prompt's dumped fields.

    Args:
        updates: Optional field overrides applied on top of the dumped data.

    Returns:
        A fresh instance of the same class as ``self``.
    """
    field_values = self.model_dump()
    field_values.update(updates or {})
    return type(self)(**field_values)
182
+
183
def __eq__(self, other: object) -> bool:
    """Compare two prompts by their dumped field values.

    Returns ``NotImplemented`` for objects that are not ``Prompt`` instances,
    so Python can try the other operand's comparison before falling back to
    ``False`` (this is what lets wildcard matchers compare equal).
    """
    if not isinstance(other, Prompt):
        return NotImplemented
    return self.model_dump() == other.model_dump()