pydantic-ai-slim 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim has been flagged as potentially problematic. See the registry's advisory page for this release for more details.

Files changed (57):
  1. pydantic_ai/_a2a.py +6 -4
  2. pydantic_ai/_agent_graph.py +25 -32
  3. pydantic_ai/_cli.py +3 -3
  4. pydantic_ai/_output.py +8 -0
  5. pydantic_ai/_tool_manager.py +3 -0
  6. pydantic_ai/ag_ui.py +25 -14
  7. pydantic_ai/{agent.py → agent/__init__.py} +209 -1027
  8. pydantic_ai/agent/abstract.py +942 -0
  9. pydantic_ai/agent/wrapper.py +227 -0
  10. pydantic_ai/direct.py +9 -9
  11. pydantic_ai/durable_exec/__init__.py +0 -0
  12. pydantic_ai/durable_exec/temporal/__init__.py +83 -0
  13. pydantic_ai/durable_exec/temporal/_agent.py +699 -0
  14. pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
  15. pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
  16. pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
  17. pydantic_ai/durable_exec/temporal/_model.py +168 -0
  18. pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
  19. pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
  20. pydantic_ai/ext/aci.py +10 -9
  21. pydantic_ai/ext/langchain.py +4 -2
  22. pydantic_ai/mcp.py +203 -75
  23. pydantic_ai/messages.py +2 -2
  24. pydantic_ai/models/__init__.py +65 -9
  25. pydantic_ai/models/anthropic.py +16 -7
  26. pydantic_ai/models/bedrock.py +8 -5
  27. pydantic_ai/models/cohere.py +1 -4
  28. pydantic_ai/models/fallback.py +4 -2
  29. pydantic_ai/models/function.py +9 -4
  30. pydantic_ai/models/gemini.py +15 -9
  31. pydantic_ai/models/google.py +18 -14
  32. pydantic_ai/models/groq.py +17 -14
  33. pydantic_ai/models/huggingface.py +18 -12
  34. pydantic_ai/models/instrumented.py +3 -1
  35. pydantic_ai/models/mcp_sampling.py +3 -1
  36. pydantic_ai/models/mistral.py +12 -18
  37. pydantic_ai/models/openai.py +29 -26
  38. pydantic_ai/models/test.py +3 -0
  39. pydantic_ai/models/wrapper.py +6 -2
  40. pydantic_ai/profiles/openai.py +1 -1
  41. pydantic_ai/providers/google.py +7 -7
  42. pydantic_ai/result.py +21 -55
  43. pydantic_ai/run.py +357 -0
  44. pydantic_ai/tools.py +0 -1
  45. pydantic_ai/toolsets/__init__.py +2 -0
  46. pydantic_ai/toolsets/_dynamic.py +87 -0
  47. pydantic_ai/toolsets/abstract.py +23 -3
  48. pydantic_ai/toolsets/combined.py +19 -4
  49. pydantic_ai/toolsets/deferred.py +10 -2
  50. pydantic_ai/toolsets/function.py +23 -8
  51. pydantic_ai/toolsets/prefixed.py +4 -0
  52. pydantic_ai/toolsets/wrapper.py +14 -1
  53. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/METADATA +6 -4
  54. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/RECORD +57 -44
  55. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/WHEEL +0 -0
  56. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/entry_points.txt +0 -0
  57. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/mcp.py CHANGED
@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
7
7
  from asyncio import Lock
8
8
  from collections.abc import AsyncIterator, Awaitable, Sequence
9
9
  from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
10
- from dataclasses import dataclass, field, replace
10
+ from dataclasses import field, replace
11
11
  from datetime import timedelta
12
12
  from pathlib import Path
13
13
  from typing import Any, Callable
@@ -56,17 +56,17 @@ class MCPServer(AbstractToolset[Any], ABC):
56
56
  See <https://modelcontextprotocol.io> for more information.
57
57
  """
58
58
 
59
- # these fields should be re-defined by dataclass subclasses so they appear as fields {
60
- tool_prefix: str | None = None
61
- log_level: mcp_types.LoggingLevel | None = None
62
- log_handler: LoggingFnT | None = None
63
- timeout: float = 5
64
- read_timeout: float = 5 * 60
65
- process_tool_call: ProcessToolCallback | None = None
66
- allow_sampling: bool = True
67
- max_retries: int = 1
68
- sampling_model: models.Model | None = None
69
- # } end of "abstract fields"
59
+ tool_prefix: str | None
60
+ log_level: mcp_types.LoggingLevel | None
61
+ log_handler: LoggingFnT | None
62
+ timeout: float
63
+ read_timeout: float
64
+ process_tool_call: ProcessToolCallback | None
65
+ allow_sampling: bool
66
+ sampling_model: models.Model | None
67
+ max_retries: int
68
+
69
+ _id: str | None
70
70
 
71
71
  _enter_lock: Lock = field(compare=False)
72
72
  _running_count: int
@@ -76,6 +76,34 @@ class MCPServer(AbstractToolset[Any], ABC):
76
76
  _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
77
77
  _write_stream: MemoryObjectSendStream[SessionMessage]
78
78
 
79
+ def __init__(
80
+ self,
81
+ tool_prefix: str | None = None,
82
+ log_level: mcp_types.LoggingLevel | None = None,
83
+ log_handler: LoggingFnT | None = None,
84
+ timeout: float = 5,
85
+ read_timeout: float = 5 * 60,
86
+ process_tool_call: ProcessToolCallback | None = None,
87
+ allow_sampling: bool = True,
88
+ sampling_model: models.Model | None = None,
89
+ max_retries: int = 1,
90
+ *,
91
+ id: str | None = None,
92
+ ):
93
+ self.tool_prefix = tool_prefix
94
+ self.log_level = log_level
95
+ self.log_handler = log_handler
96
+ self.timeout = timeout
97
+ self.read_timeout = read_timeout
98
+ self.process_tool_call = process_tool_call
99
+ self.allow_sampling = allow_sampling
100
+ self.sampling_model = sampling_model
101
+ self.max_retries = max_retries
102
+
103
+ self._id = id or tool_prefix
104
+
105
+ self.__post_init__()
106
+
79
107
  def __post_init__(self):
80
108
  self._enter_lock = Lock()
81
109
  self._running_count = 0
@@ -96,12 +124,19 @@ class MCPServer(AbstractToolset[Any], ABC):
96
124
  yield
97
125
 
98
126
  @property
99
- def name(self) -> str:
100
- return repr(self)
127
+ def id(self) -> str | None:
128
+ return self._id
129
+
130
+ @property
131
+ def label(self) -> str:
132
+ if self.id:
133
+ return super().label # pragma: no cover
134
+ else:
135
+ return repr(self)
101
136
 
102
137
  @property
103
138
  def tool_name_conflict_hint(self) -> str:
104
- return 'Consider setting `tool_prefix` to avoid name conflicts.'
139
+ return 'Set the `tool_prefix` attribute to avoid name conflicts.'
105
140
 
106
141
  async def list_tools(self) -> list[mcp_types.Tool]:
107
142
  """Retrieve tools that are currently active on the server.
@@ -177,20 +212,25 @@ class MCPServer(AbstractToolset[Any], ABC):
177
212
 
178
213
  async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
179
214
  return {
180
- name: ToolsetTool(
181
- toolset=self,
182
- tool_def=ToolDefinition(
215
+ name: self.tool_for_tool_def(
216
+ ToolDefinition(
183
217
  name=name,
184
218
  description=mcp_tool.description,
185
219
  parameters_json_schema=mcp_tool.inputSchema,
186
220
  ),
187
- max_retries=self.max_retries,
188
- args_validator=TOOL_SCHEMA_VALIDATOR,
189
221
  )
190
222
  for mcp_tool in await self.list_tools()
191
223
  if (name := f'{self.tool_prefix}_{mcp_tool.name}' if self.tool_prefix else mcp_tool.name)
192
224
  }
193
225
 
226
+ def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]:
227
+ return ToolsetTool(
228
+ toolset=self,
229
+ tool_def=tool_def,
230
+ max_retries=self.max_retries,
231
+ args_validator=TOOL_SCHEMA_VALIDATOR,
232
+ )
233
+
194
234
  async def __aenter__(self) -> Self:
195
235
  """Enter the MCP server context.
196
236
 
@@ -308,7 +348,6 @@ class MCPServer(AbstractToolset[Any], ABC):
308
348
  assert_never(resource)
309
349
 
310
350
 
311
- @dataclass
312
351
  class MCPServerStdio(MCPServer):
313
352
  """Runs an MCP server in a subprocess and communicates with it over stdin/stdout.
314
353
 
@@ -353,18 +392,18 @@ class MCPServerStdio(MCPServer):
353
392
  args: Sequence[str]
354
393
  """The arguments to pass to the command."""
355
394
 
356
- env: dict[str, str] | None = None
395
+ env: dict[str, str] | None
357
396
  """The environment variables the CLI server will have access to.
358
397
 
359
398
  By default the subprocess will not inherit any environment variables from the parent process.
360
399
  If you want to inherit the environment variables from the parent process, use `env=os.environ`.
361
400
  """
362
401
 
363
- cwd: str | Path | None = None
402
+ cwd: str | Path | None
364
403
  """The working directory to use when spawning the process."""
365
404
 
366
405
  # last fields are re-defined from the parent class so they appear as fields
367
- tool_prefix: str | None = None
406
+ tool_prefix: str | None
368
407
  """A prefix to add to all tools that are registered with the server.
369
408
 
370
409
  If not empty, will include a trailing underscore(`_`).
@@ -372,7 +411,7 @@ class MCPServerStdio(MCPServer):
372
411
  e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
373
412
  """
374
413
 
375
- log_level: mcp_types.LoggingLevel | None = None
414
+ log_level: mcp_types.LoggingLevel | None
376
415
  """The log level to set when connecting to the server, if any.
377
416
 
378
417
  See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
@@ -380,23 +419,85 @@ class MCPServerStdio(MCPServer):
380
419
  If `None`, no log level will be set.
381
420
  """
382
421
 
383
- log_handler: LoggingFnT | None = None
422
+ log_handler: LoggingFnT | None
384
423
  """A handler for logging messages from the server."""
385
424
 
386
- timeout: float = 5
425
+ timeout: float
387
426
  """The timeout in seconds to wait for the client to initialize."""
388
427
 
389
- process_tool_call: ProcessToolCallback | None = None
428
+ read_timeout: float
429
+ """Maximum time in seconds to wait for new messages before timing out.
430
+
431
+ This timeout applies to the long-lived connection after it's established.
432
+ If no new messages are received within this time, the connection will be considered stale
433
+ and may be closed. Defaults to 5 minutes (300 seconds).
434
+ """
435
+
436
+ process_tool_call: ProcessToolCallback | None
390
437
  """Hook to customize tool calling and optionally pass extra metadata."""
391
438
 
392
- allow_sampling: bool = True
439
+ allow_sampling: bool
393
440
  """Whether to allow MCP sampling through this client."""
394
441
 
395
- max_retries: int = 1
442
+ sampling_model: models.Model | None
443
+ """The model to use for sampling."""
444
+
445
+ max_retries: int
396
446
  """The maximum number of times to retry a tool call."""
397
447
 
398
- sampling_model: models.Model | None = None
399
- """The model to use for sampling."""
448
+ def __init__(
449
+ self,
450
+ command: str,
451
+ args: Sequence[str],
452
+ env: dict[str, str] | None = None,
453
+ cwd: str | Path | None = None,
454
+ tool_prefix: str | None = None,
455
+ log_level: mcp_types.LoggingLevel | None = None,
456
+ log_handler: LoggingFnT | None = None,
457
+ timeout: float = 5,
458
+ read_timeout: float = 5 * 60,
459
+ process_tool_call: ProcessToolCallback | None = None,
460
+ allow_sampling: bool = True,
461
+ sampling_model: models.Model | None = None,
462
+ max_retries: int = 1,
463
+ *,
464
+ id: str | None = None,
465
+ ):
466
+ """Build a new MCP server.
467
+
468
+ Args:
469
+ command: The command to run.
470
+ args: The arguments to pass to the command.
471
+ env: The environment variables to set in the subprocess.
472
+ cwd: The working directory to use when spawning the process.
473
+ tool_prefix: A prefix to add to all tools that are registered with the server.
474
+ log_level: The log level to set when connecting to the server, if any.
475
+ log_handler: A handler for logging messages from the server.
476
+ timeout: The timeout in seconds to wait for the client to initialize.
477
+ read_timeout: Maximum time in seconds to wait for new messages before timing out.
478
+ process_tool_call: Hook to customize tool calling and optionally pass extra metadata.
479
+ allow_sampling: Whether to allow MCP sampling through this client.
480
+ sampling_model: The model to use for sampling.
481
+ max_retries: The maximum number of times to retry a tool call.
482
+ id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
483
+ """
484
+ self.command = command
485
+ self.args = args
486
+ self.env = env
487
+ self.cwd = cwd
488
+
489
+ super().__init__(
490
+ tool_prefix,
491
+ log_level,
492
+ log_handler,
493
+ timeout,
494
+ read_timeout,
495
+ process_tool_call,
496
+ allow_sampling,
497
+ sampling_model,
498
+ max_retries,
499
+ id=id,
500
+ )
400
501
 
401
502
  @asynccontextmanager
402
503
  async def client_streams(
@@ -412,15 +513,20 @@ class MCPServerStdio(MCPServer):
412
513
  yield read_stream, write_stream
413
514
 
414
515
  def __repr__(self) -> str:
415
- return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'
516
+ repr_args = [
517
+ f'command={self.command!r}',
518
+ f'args={self.args!r}',
519
+ ]
520
+ if self.id:
521
+ repr_args.append(f'id={self.id!r}') # pragma: no cover
522
+ return f'{self.__class__.__name__}({", ".join(repr_args)})'
416
523
 
417
524
 
418
- @dataclass(init=False)
419
525
  class _MCPServerHTTP(MCPServer):
420
526
  url: str
421
527
  """The URL of the endpoint on the MCP server."""
422
528
 
423
- headers: dict[str, Any] | None = None
529
+ headers: dict[str, Any] | None
424
530
  """Optional HTTP headers to be sent with each request to the endpoint.
425
531
 
426
532
  These headers will be passed directly to the underlying `httpx.AsyncClient`.
@@ -432,7 +538,7 @@ class _MCPServerHTTP(MCPServer):
432
538
  See [`MCPServerHTTP.http_client`][pydantic_ai.mcp.MCPServerHTTP.http_client] for more information.
433
539
  """
434
540
 
435
- http_client: httpx.AsyncClient | None = None
541
+ http_client: httpx.AsyncClient | None
436
542
  """An `httpx.AsyncClient` to use with the endpoint.
437
543
 
438
544
  This client may be configured to use customized connection parameters like self-signed certificates.
@@ -452,16 +558,8 @@ class _MCPServerHTTP(MCPServer):
452
558
  ```
453
559
  """
454
560
 
455
- read_timeout: float = 5 * 60
456
- """Maximum time in seconds to wait for new messages before timing out.
457
-
458
- This timeout applies to the long-lived connection after it's established.
459
- If no new messages are received within this time, the connection will be considered stale
460
- and may be closed. Defaults to 5 minutes (300 seconds).
461
- """
462
-
463
561
  # last fields are re-defined from the parent class so they appear as fields
464
- tool_prefix: str | None = None
562
+ tool_prefix: str | None
465
563
  """A prefix to add to all tools that are registered with the server.
466
564
 
467
565
  If not empty, will include a trailing underscore (`_`).
@@ -469,7 +567,7 @@ class _MCPServerHTTP(MCPServer):
469
567
  For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
470
568
  """
471
569
 
472
- log_level: mcp_types.LoggingLevel | None = None
570
+ log_level: mcp_types.LoggingLevel | None
473
571
  """The log level to set when connecting to the server, if any.
474
572
 
475
573
  See <https://modelcontextprotocol.io/introduction#logging> for more details.
@@ -477,56 +575,81 @@ class _MCPServerHTTP(MCPServer):
477
575
  If `None`, no log level will be set.
478
576
  """
479
577
 
480
- log_handler: LoggingFnT | None = None
578
+ log_handler: LoggingFnT | None
481
579
  """A handler for logging messages from the server."""
482
580
 
483
- timeout: float = 5
581
+ timeout: float
484
582
  """Initial connection timeout in seconds for establishing the connection.
485
583
 
486
584
  This timeout applies to the initial connection setup and handshake.
487
585
  If the connection cannot be established within this time, the operation will fail.
488
586
  """
489
587
 
490
- process_tool_call: ProcessToolCallback | None = None
588
+ read_timeout: float
589
+ """Maximum time in seconds to wait for new messages before timing out.
590
+
591
+ This timeout applies to the long-lived connection after it's established.
592
+ If no new messages are received within this time, the connection will be considered stale
593
+ and may be closed. Defaults to 5 minutes (300 seconds).
594
+ """
595
+
596
+ process_tool_call: ProcessToolCallback | None
491
597
  """Hook to customize tool calling and optionally pass extra metadata."""
492
598
 
493
- allow_sampling: bool = True
599
+ allow_sampling: bool
494
600
  """Whether to allow MCP sampling through this client."""
495
601
 
496
- max_retries: int = 1
497
- """The maximum number of times to retry a tool call."""
498
-
499
- sampling_model: models.Model | None = None
602
+ sampling_model: models.Model | None
500
603
  """The model to use for sampling."""
501
604
 
605
+ max_retries: int
606
+ """The maximum number of times to retry a tool call."""
607
+
502
608
  def __init__(
503
609
  self,
504
610
  *,
505
611
  url: str,
506
612
  headers: dict[str, str] | None = None,
507
613
  http_client: httpx.AsyncClient | None = None,
508
- read_timeout: float | None = None,
614
+ id: str | None = None,
509
615
  tool_prefix: str | None = None,
510
616
  log_level: mcp_types.LoggingLevel | None = None,
511
617
  log_handler: LoggingFnT | None = None,
512
618
  timeout: float = 5,
619
+ read_timeout: float | None = None,
513
620
  process_tool_call: ProcessToolCallback | None = None,
514
621
  allow_sampling: bool = True,
515
- max_retries: int = 1,
516
622
  sampling_model: models.Model | None = None,
517
- **kwargs: Any,
623
+ max_retries: int = 1,
624
+ **_deprecated_kwargs: Any,
518
625
  ):
519
- # Handle deprecated sse_read_timeout parameter
520
- if 'sse_read_timeout' in kwargs:
626
+ """Build a new MCP server.
627
+
628
+ Args:
629
+ url: The URL of the endpoint on the MCP server.
630
+ headers: Optional HTTP headers to be sent with each request to the endpoint.
631
+ http_client: An `httpx.AsyncClient` to use with the endpoint.
632
+ id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
633
+ tool_prefix: A prefix to add to all tools that are registered with the server.
634
+ log_level: The log level to set when connecting to the server, if any.
635
+ log_handler: A handler for logging messages from the server.
636
+ timeout: The timeout in seconds to wait for the client to initialize.
637
+ read_timeout: Maximum time in seconds to wait for new messages before timing out.
638
+ process_tool_call: Hook to customize tool calling and optionally pass extra metadata.
639
+ allow_sampling: Whether to allow MCP sampling through this client.
640
+ sampling_model: The model to use for sampling.
641
+ max_retries: The maximum number of times to retry a tool call.
642
+ """
643
+ if 'sse_read_timeout' in _deprecated_kwargs:
521
644
  if read_timeout is not None:
522
645
  raise TypeError("'read_timeout' and 'sse_read_timeout' cannot be set at the same time.")
523
646
 
524
647
  warnings.warn(
525
648
  "'sse_read_timeout' is deprecated, use 'read_timeout' instead.", DeprecationWarning, stacklevel=2
526
649
  )
527
- read_timeout = kwargs.pop('sse_read_timeout')
650
+ read_timeout = _deprecated_kwargs.pop('sse_read_timeout')
528
651
 
529
- _utils.validate_empty_kwargs(kwargs)
652
+ _utils.validate_empty_kwargs(_deprecated_kwargs)
530
653
 
531
654
  if read_timeout is None:
532
655
  read_timeout = 5 * 60
@@ -534,16 +657,19 @@ class _MCPServerHTTP(MCPServer):
534
657
  self.url = url
535
658
  self.headers = headers
536
659
  self.http_client = http_client
537
- self.tool_prefix = tool_prefix
538
- self.log_level = log_level
539
- self.log_handler = log_handler
540
- self.timeout = timeout
541
- self.process_tool_call = process_tool_call
542
- self.allow_sampling = allow_sampling
543
- self.max_retries = max_retries
544
- self.sampling_model = sampling_model
545
- self.read_timeout = read_timeout
546
- self.__post_init__()
660
+
661
+ super().__init__(
662
+ tool_prefix,
663
+ log_level,
664
+ log_handler,
665
+ timeout,
666
+ read_timeout,
667
+ process_tool_call,
668
+ allow_sampling,
669
+ sampling_model,
670
+ max_retries,
671
+ id=id,
672
+ )
547
673
 
548
674
  @property
549
675
  @abstractmethod
@@ -606,10 +732,14 @@ class _MCPServerHTTP(MCPServer):
606
732
  yield read_stream, write_stream
607
733
 
608
734
  def __repr__(self) -> str: # pragma: no cover
609
- return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
735
+ repr_args = [
736
+ f'url={self.url!r}',
737
+ ]
738
+ if self.id:
739
+ repr_args.append(f'id={self.id!r}')
740
+ return f'{self.__class__.__name__}({", ".join(repr_args)})'
610
741
 
611
742
 
612
- @dataclass(init=False)
613
743
  class MCPServerSSE(_MCPServerHTTP):
614
744
  """An MCP server that connects over streamable HTTP connections.
615
745
 
@@ -643,7 +773,6 @@ class MCPServerSSE(_MCPServerHTTP):
643
773
 
644
774
 
645
775
  @deprecated('The `MCPServerHTTP` class is deprecated, use `MCPServerSSE` instead.')
646
- @dataclass
647
776
  class MCPServerHTTP(MCPServerSSE):
648
777
  """An MCP server that connects over HTTP using the old SSE transport.
649
778
 
@@ -672,7 +801,6 @@ class MCPServerHTTP(MCPServerSSE):
672
801
  """
673
802
 
674
803
 
675
- @dataclass
676
804
  class MCPServerStreamableHTTP(_MCPServerHTTP):
677
805
  """An MCP server that connects over HTTP using the Streamable HTTP transport.
678
806
 
pydantic_ai/messages.py CHANGED
@@ -490,8 +490,8 @@ _video_format_lookup: dict[str, VideoFormat] = {
490
490
  class UserPromptPart:
491
491
  """A user prompt, generally written by the end user.
492
492
 
493
- Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.Agent.run],
494
- [`Agent.run_sync`][pydantic_ai.Agent.run_sync], and [`Agent.run_stream`][pydantic_ai.Agent.run_stream].
493
+ Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.agent.AbstractAgent.run],
494
+ [`Agent.run_sync`][pydantic_ai.agent.AbstractAgent.run_sync], and [`Agent.run_stream`][pydantic_ai.agent.AbstractAgent.run_stream].
495
495
  """
496
496
 
497
497
  content: str | Sequence[UserContent]
@@ -13,20 +13,32 @@ from contextlib import asynccontextmanager, contextmanager
13
13
  from dataclasses import dataclass, field, replace
14
14
  from datetime import datetime
15
15
  from functools import cache, cached_property
16
- from typing import Generic, TypeVar, overload
16
+ from typing import Any, Generic, TypeVar, overload
17
17
 
18
18
  import httpx
19
19
  from typing_extensions import Literal, TypeAliasType, TypedDict
20
20
 
21
- from pydantic_ai.builtin_tools import AbstractBuiltinTool
22
- from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
23
-
24
21
  from .. import _utils
25
22
  from .._output import OutputObjectDefinition
26
23
  from .._parts_manager import ModelResponsePartsManager
24
+ from .._run_context import RunContext
25
+ from ..builtin_tools import AbstractBuiltinTool
27
26
  from ..exceptions import UserError
28
- from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl
27
+ from ..messages import (
28
+ AgentStreamEvent,
29
+ FileUrl,
30
+ FinalResultEvent,
31
+ ModelMessage,
32
+ ModelRequest,
33
+ ModelResponse,
34
+ ModelResponseStreamEvent,
35
+ PartStartEvent,
36
+ TextPart,
37
+ ToolCallPart,
38
+ VideoUrl,
39
+ )
29
40
  from ..output import OutputMode
41
+ from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
30
42
  from ..profiles._json_schema import JsonSchemaTransformer
31
43
  from ..settings import ModelSettings
32
44
  from ..tools import ToolDefinition
@@ -344,6 +356,10 @@ class ModelRequestParameters:
344
356
  output_tools: list[ToolDefinition] = field(default_factory=list)
345
357
  allow_text_output: bool = True
346
358
 
359
+ @cached_property
360
+ def tool_defs(self) -> dict[str, ToolDefinition]:
361
+ return {tool_def.name: tool_def for tool_def in [*self.function_tools, *self.output_tools]}
362
+
347
363
  __repr__ = _utils.dataclasses_no_defaults_repr
348
364
 
349
365
 
@@ -389,6 +405,7 @@ class Model(ABC):
389
405
  messages: list[ModelMessage],
390
406
  model_settings: ModelSettings | None,
391
407
  model_request_parameters: ModelRequestParameters,
408
+ run_context: RunContext[Any] | None = None,
392
409
  ) -> AsyncIterator[StreamedResponse]:
393
410
  """Make a request to the model and return a streaming response."""
394
411
  # This method is not required, but you need to implement it if you want to support streamed responses
@@ -501,14 +518,40 @@ class Model(ABC):
501
518
  class StreamedResponse(ABC):
502
519
  """Streamed response from an LLM when calling a tool."""
503
520
 
521
+ model_request_parameters: ModelRequestParameters
522
+ final_result_event: FinalResultEvent | None = field(default=None, init=False)
523
+
504
524
  _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
505
- _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
525
+ _event_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
506
526
  _usage: Usage = field(default_factory=Usage, init=False)
507
527
 
508
- def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
509
- """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
528
+ def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
529
+ """Stream the response as an async iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
530
+
531
+ This proxies the `_event_iterator()` and emits all events, while also checking for matches
532
+ on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the
533
+ first match is found.
534
+ """
510
535
  if self._event_iterator is None:
511
- self._event_iterator = self._get_event_iterator()
536
+
537
+ async def iterator_with_final_event(
538
+ iterator: AsyncIterator[ModelResponseStreamEvent],
539
+ ) -> AsyncIterator[AgentStreamEvent]:
540
+ async for event in iterator:
541
+ yield event
542
+ if (
543
+ final_result_event := _get_final_result_event(event, self.model_request_parameters)
544
+ ) is not None:
545
+ self.final_result_event = final_result_event
546
+ yield final_result_event
547
+ break
548
+
549
+ # If we broke out of the above loop, we need to yield the rest of the events
550
+ # If we didn't, this will just be a no-op
551
+ async for event in iterator:
552
+ yield event
553
+
554
+ self._event_iterator = iterator_with_final_event(self._get_event_iterator())
512
555
  return self._event_iterator
513
556
 
514
557
  @abstractmethod
@@ -810,3 +853,16 @@ def _customize_output_object(transformer: type[JsonSchemaTransformer], o: Output
810
853
  json_schema=json_schema,
811
854
  strict=schema_transformer.is_strict_compatible if o.strict is None else o.strict,
812
855
  )
856
+
857
+
858
+ def _get_final_result_event(e: ModelResponseStreamEvent, params: ModelRequestParameters) -> FinalResultEvent | None:
859
+ """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
860
+ if isinstance(e, PartStartEvent):
861
+ new_part = e.part
862
+ if isinstance(new_part, TextPart) and params.allow_text_output: # pragma: no branch
863
+ return FinalResultEvent(tool_name=None, tool_call_id=None)
864
+ elif isinstance(new_part, ToolCallPart) and (tool_def := params.tool_defs.get(new_part.tool_name)):
865
+ if tool_def.kind == 'output':
866
+ return FinalResultEvent(tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id)
867
+ elif tool_def.kind == 'deferred':
868
+ return FinalResultEvent(tool_name=None, tool_call_id=None)
@@ -21,7 +21,9 @@ from typing_extensions import assert_never
21
21
  from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool
22
22
 
23
23
  from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
24
+ from .._run_context import RunContext
24
25
  from .._utils import guard_tool_call_id as _guard_tool_call_id
26
+ from ..exceptions import UserError
25
27
  from ..messages import (
26
28
  BinaryContent,
27
29
  BuiltinToolCallPart,
@@ -196,13 +198,14 @@ class AnthropicModel(Model):
196
198
  messages: list[ModelMessage],
197
199
  model_settings: ModelSettings | None,
198
200
  model_request_parameters: ModelRequestParameters,
201
+ run_context: RunContext[Any] | None = None,
199
202
  ) -> AsyncIterator[StreamedResponse]:
200
203
  check_allow_model_requests()
201
204
  response = await self._messages_create(
202
205
  messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
203
206
  )
204
207
  async with response:
205
- yield await self._process_streamed_response(response)
208
+ yield await self._process_streamed_response(response, model_request_parameters)
206
209
 
207
210
  @property
208
211
  def model_name(self) -> AnthropicModelName:
@@ -329,7 +332,9 @@ class AnthropicModel(Model):
329
332
 
330
333
  return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id)
331
334
 
332
- async def _process_streamed_response(self, response: AsyncStream[BetaRawMessageStreamEvent]) -> StreamedResponse:
335
+ async def _process_streamed_response(
336
+ self, response: AsyncStream[BetaRawMessageStreamEvent], model_request_parameters: ModelRequestParameters
337
+ ) -> StreamedResponse:
333
338
  peekable_response = _utils.PeekableAsyncStream(response)
334
339
  first_chunk = await peekable_response.peek()
335
340
  if isinstance(first_chunk, _utils.Unset):
@@ -338,14 +343,14 @@ class AnthropicModel(Model):
338
343
  # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
339
344
  timestamp = datetime.now(tz=timezone.utc)
340
345
  return AnthropicStreamedResponse(
341
- _model_name=self._model_name, _response=peekable_response, _timestamp=timestamp
346
+ model_request_parameters=model_request_parameters,
347
+ _model_name=self._model_name,
348
+ _response=peekable_response,
349
+ _timestamp=timestamp,
342
350
  )
343
351
 
344
352
  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]:
345
- tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
346
- if model_request_parameters.output_tools:
347
- tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
348
- return tools
353
+ return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
349
354
 
350
355
  def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
351
356
  tools: list[BetaToolUnionParam] = []
@@ -363,6 +368,10 @@ class AnthropicModel(Model):
363
368
  )
364
369
  elif isinstance(tool, CodeExecutionTool): # pragma: no branch
365
370
  tools.append(BetaCodeExecutionTool20250522Param(name='code_execution', type='code_execution_20250522'))
371
+ else: # pragma: no cover
372
+ raise UserError(
373
+ f'`{tool.__class__.__name__}` is not supported by `AnthropicModel`. If it should be, please file an issue.'
374
+ )
366
375
  return tools
367
376
 
368
377
  async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]: # noqa: C901