pydantic-ai-slim 0.2.17__tar.gz → 0.2.19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (76)
  1. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/PKG-INFO +5 -5
  2. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_agent_graph.py +44 -14
  3. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_function_schema.py +2 -3
  4. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_output.py +1 -1
  5. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_system_prompt.py +1 -1
  6. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_utils.py +28 -3
  7. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/agent.py +13 -3
  8. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/mcp.py +145 -53
  9. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/messages.py +4 -5
  10. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/__init__.py +2 -2
  11. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/anthropic.py +10 -6
  12. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/gemini.py +1 -3
  13. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/google.py +5 -3
  14. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/openai.py +7 -1
  15. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/__init__.py +23 -17
  16. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/google.py +1 -1
  17. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/tools.py +1 -2
  18. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pyproject.toml +1 -1
  19. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/.gitignore +0 -0
  20. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/LICENSE +0 -0
  21. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/README.md +0 -0
  22. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/__init__.py +0 -0
  23. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/__main__.py +0 -0
  24. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_a2a.py +0 -0
  25. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_cli.py +0 -0
  26. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_griffe.py +0 -0
  27. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_parts_manager.py +0 -0
  28. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/common_tools/__init__.py +0 -0
  29. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  30. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/common_tools/tavily.py +0 -0
  31. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/direct.py +0 -0
  32. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/exceptions.py +0 -0
  33. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/ext/__init__.py +0 -0
  34. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/ext/langchain.py +0 -0
  35. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/format_as_xml.py +0 -0
  36. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/format_prompt.py +0 -0
  37. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/bedrock.py +0 -0
  38. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/cohere.py +0 -0
  39. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/fallback.py +0 -0
  40. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/function.py +0 -0
  41. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/groq.py +0 -0
  42. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/instrumented.py +0 -0
  43. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/mistral.py +0 -0
  44. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/test.py +0 -0
  45. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/wrapper.py +0 -0
  46. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/__init__.py +0 -0
  47. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/_json_schema.py +0 -0
  48. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/amazon.py +0 -0
  49. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/anthropic.py +0 -0
  50. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/cohere.py +0 -0
  51. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/deepseek.py +0 -0
  52. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/google.py +0 -0
  53. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/grok.py +0 -0
  54. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/meta.py +0 -0
  55. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/mistral.py +0 -0
  56. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/openai.py +0 -0
  57. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/qwen.py +0 -0
  58. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/anthropic.py +0 -0
  59. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/azure.py +0 -0
  60. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/bedrock.py +0 -0
  61. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/cohere.py +0 -0
  62. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/deepseek.py +0 -0
  63. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/fireworks.py +0 -0
  64. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/google_gla.py +0 -0
  65. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/google_vertex.py +0 -0
  66. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/grok.py +0 -0
  67. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/groq.py +0 -0
  68. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/heroku.py +0 -0
  69. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/mistral.py +0 -0
  70. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/openai.py +0 -0
  71. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/openrouter.py +0 -0
  72. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/together.py +0 -0
  73. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/py.typed +0 -0
  74. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/result.py +0 -0
  75. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/settings.py +0 -0
  76. {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/usage.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pydantic-ai-slim
- Version: 0.2.17
+ Version: 0.2.19
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
  Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>
  License-Expression: MIT
@@ -30,11 +30,11 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
  Requires-Dist: griffe>=1.3.2
  Requires-Dist: httpx>=0.27
  Requires-Dist: opentelemetry-api>=1.28.0
- Requires-Dist: pydantic-graph==0.2.17
+ Requires-Dist: pydantic-graph==0.2.19
  Requires-Dist: pydantic>=2.10
  Requires-Dist: typing-inspection>=0.4.0
  Provides-Extra: a2a
- Requires-Dist: fasta2a==0.2.17; extra == 'a2a'
+ Requires-Dist: fasta2a==0.2.19; extra == 'a2a'
  Provides-Extra: anthropic
  Requires-Dist: anthropic>=0.52.0; extra == 'anthropic'
  Provides-Extra: bedrock
@@ -48,7 +48,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
  Provides-Extra: duckduckgo
  Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
  Provides-Extra: evals
- Requires-Dist: pydantic-evals==0.2.17; extra == 'evals'
+ Requires-Dist: pydantic-evals==0.2.19; extra == 'evals'
  Provides-Extra: google
  Requires-Dist: google-genai>=1.15.0; extra == 'google'
  Provides-Extra: groq
@@ -56,7 +56,7 @@ Requires-Dist: groq>=0.15.0; extra == 'groq'
  Provides-Extra: logfire
  Requires-Dist: logfire>=3.11.0; extra == 'logfire'
  Provides-Extra: mcp
- Requires-Dist: mcp>=1.9.2; (python_version >= '3.10') and extra == 'mcp'
+ Requires-Dist: mcp>=1.9.4; (python_version >= '3.10') and extra == 'mcp'
  Provides-Extra: mistral
  Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
  Provides-Extra: openai

pydantic_ai/_agent_graph.py
@@ -12,18 +12,11 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
  from opentelemetry.trace import Tracer
  from typing_extensions import TypeGuard, TypeVar, assert_never

+ from pydantic_ai._utils import is_async_callable, run_in_executor
  from pydantic_graph import BaseNode, Graph, GraphRunContext
  from pydantic_graph.nodes import End, NodeRunEndT

- from . import (
- _output,
- _system_prompt,
- exceptions,
- messages as _messages,
- models,
- result,
- usage as _usage,
- )
+ from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
  from .result import OutputDataT
  from .settings import ModelSettings, merge_model_settings
  from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc
@@ -39,6 +32,7 @@ __all__ = (
  'CallToolsNode',
  'build_run_context',
  'capture_run_messages',
+ 'HistoryProcessor',
  )


@@ -54,6 +48,11 @@ EndStrategy = Literal['early', 'exhaustive']
  DepsT = TypeVar('DepsT')
  OutputT = TypeVar('OutputT')

+ _HistoryProcessorSync = Callable[[list[_messages.ModelMessage]], list[_messages.ModelMessage]]
+ _HistoryProcessorAsync = Callable[[list[_messages.ModelMessage]], Awaitable[list[_messages.ModelMessage]]]
+ HistoryProcessor = Union[_HistoryProcessorSync, _HistoryProcessorAsync]
+ """A function that processes a list of model messages and returns a list of model messages."""
+

  @dataclasses.dataclass
  class GraphAgentState:
@@ -93,6 +92,8 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
  output_schema: _output.OutputSchema[OutputDataT] | None
  output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]

+ history_processors: Sequence[HistoryProcessor]
+
  function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
  mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
  default_retries: int
@@ -183,6 +184,16 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):

  if user_prompt is not None:
  parts.append(_messages.UserPromptPart(user_prompt))
+ elif (
+ len(parts) == 0
+ and message_history
+ and (last_message := message_history[-1])
+ and isinstance(last_message, _messages.ModelRequest)
+ ):
+ # Drop last message that came from history and reuse its parts
+ messages.pop()
+ parts.extend(last_message.parts)
+
  return messages, _messages.ModelRequest(parts, instructions=instructions)

  async def _reevaluate_dynamic_prompts(
@@ -317,8 +328,9 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):

  model_settings, model_request_parameters = await self._prepare_request(ctx)
  model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
+ message_history = await _process_message_history(ctx.state.message_history, ctx.deps.history_processors)
  async with ctx.deps.model.request_stream(
- ctx.state.message_history, model_settings, model_request_parameters
+ message_history, model_settings, model_request_parameters
  ) as streamed_response:
  self._did_stream = True
  ctx.state.usage.requests += 1
@@ -340,9 +352,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):

  model_settings, model_request_parameters = await self._prepare_request(ctx)
  model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
- model_response = await ctx.deps.model.request(
- ctx.state.message_history, model_settings, model_request_parameters
- )
+ message_history = await _process_message_history(ctx.state.message_history, ctx.deps.history_processors)
+ model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
  ctx.state.usage.incr(_usage.Usage())

  return self._finish_handling(ctx, model_response)
@@ -637,6 +648,7 @@ async def process_function_tools( # noqa C901
  # if tool_name is in output_schema, it means we found a output tool but an error occurred in
  # validation, we don't add another part here
  if output_tool_name is not None:
+ yield _messages.FunctionToolCallEvent(call)
  if found_used_output_tool:
  content = 'Output tool not used - a final result was already processed.'
  else:
@@ -647,9 +659,14 @@ async def process_function_tools( # noqa C901
  content=content,
  tool_call_id=call.tool_call_id,
  )
+ yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id)
  output_parts.append(part)
  else:
- output_parts.append(_unknown_tool(call.tool_name, call.tool_call_id, ctx))
+ yield _messages.FunctionToolCallEvent(call)
+
+ part = _unknown_tool(call.tool_name, call.tool_call_id, ctx)
+ yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id)
+ output_parts.append(part)

  if not calls_to_run:
  return
@@ -855,3 +872,16 @@ def build_agent_graph(
  auto_instrument=False,
  )
  return graph
+
+
+ async def _process_message_history(
+ messages: list[_messages.ModelMessage],
+ processors: Sequence[HistoryProcessor],
+ ) -> list[_messages.ModelMessage]:
+ """Process message history through a sequence of processors."""
+ for processor in processors:
+ if is_async_callable(processor):
+ messages = await processor(messages)
+ else:
+ messages = await run_in_executor(processor, messages)
+ return messages
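
The `HistoryProcessor` alias above accepts either shape: a plain callable or an async callable over `list[ModelMessage]`. As a minimal sketch (the trimming rule is purely illustrative), both of the following conform; per `_process_message_history`, sync processors are dispatched through `run_in_executor` while async ones are awaited, in the order given:

```python
from pydantic_ai.messages import ModelMessage


def keep_recent(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Sync processor: keep only the five most recent messages (illustrative rule)."""
    return messages[-5:]


async def passthrough(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Async processor: return the history unchanged."""
    return messages
```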

pydantic_ai/_function_schema.py
@@ -5,7 +5,6 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle

  from __future__ import annotations as _annotations

- import inspect
  from collections.abc import Awaitable
  from dataclasses import dataclass, field
  from inspect import Parameter, signature
@@ -23,7 +22,7 @@ from typing_extensions import get_origin
  from pydantic_ai.tools import RunContext

  from ._griffe import doc_descriptions
- from ._utils import check_object_json_schema, is_model_like, run_in_executor
+ from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor

  if TYPE_CHECKING:
  from .tools import DocstringFormat, ObjectJsonSchema
@@ -214,7 +213,7 @@ def function_schema( # noqa: C901
  positional_fields=positional_fields,
  var_positional_field=var_positional_field,
  takes_ctx=takes_ctx,
- is_async=inspect.iscoroutinefunction(function),
+ is_async=is_async_callable(function),
  function=function,
  )

pydantic_ai/_output.py
@@ -60,7 +60,7 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):

  def __post_init__(self):
  self._takes_ctx = len(inspect.signature(self.function).parameters) > 1
- self._is_async = inspect.iscoroutinefunction(self.function)
+ self._is_async = _utils.is_async_callable(self.function)

  async def validate(
  self,

pydantic_ai/_system_prompt.py
@@ -18,7 +18,7 @@ class SystemPromptRunner(Generic[AgentDepsT]):

  def __post_init__(self):
  self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
- self._is_async = inspect.iscoroutinefunction(self.function)
+ self._is_async = _utils.is_async_callable(self.function)

  async def run(self, run_context: RunContext[AgentDepsT]) -> str:
  if self._takes_ctx:

pydantic_ai/_utils.py
@@ -1,20 +1,22 @@
  from __future__ import annotations as _annotations

  import asyncio
+ import functools
+ import inspect
  import time
  import uuid
- from collections.abc import AsyncIterable, AsyncIterator, Iterator
+ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
  from contextlib import asynccontextmanager, suppress
  from dataclasses import dataclass, fields, is_dataclass
  from datetime import datetime, timezone
  from functools import partial
  from types import GenericAlias
- from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
+ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload

  from anyio.to_thread import run_sync
  from pydantic import BaseModel, TypeAdapter
  from pydantic.json_schema import JsonSchemaValue
- from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
+ from typing_extensions import ParamSpec, TypeAlias, TypeGuard, TypeIs, is_typeddict

  from pydantic_graph._utils import AbstractSpan

@@ -302,3 +304,26 @@ def dataclasses_no_defaults_repr(self: Any) -> str:

  def number_to_datetime(x: int | float) -> datetime:
  return TypeAdapter(datetime).validate_python(x)
+
+
+ AwaitableCallable = Callable[..., Awaitable[T]]
+
+
+ @overload
+ def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
+
+
+ @overload
+ def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...
+
+
+ def is_async_callable(obj: Any) -> Any:
+ """Correctly check if a callable is async.
+
+ This function was copied from Starlette:
+ https://github.com/encode/starlette/blob/78da9b9e218ab289117df7d62aee200ed4c59617/starlette/_utils.py#L36-L40
+ """
+ while isinstance(obj, functools.partial):
+ obj = obj.func
+
+ return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore
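
A rough sketch of why the new helper is used throughout this release (note that `pydantic_ai._utils` is an internal module, not public API): unlike a bare `inspect.iscoroutinefunction` check, `is_async_callable` unwraps `functools.partial` and also recognises objects whose `__call__` is a coroutine function:

```python
import functools
import inspect

from pydantic_ai._utils import is_async_callable  # internal helper, not public API


async def fetch(url: str) -> str:
    return url


class AsyncCallable:
    async def __call__(self) -> None: ...


wrapped = functools.partial(fetch, 'https://example.com')

assert is_async_callable(fetch)
assert is_async_callable(wrapped)          # functools.partial is unwrapped first
assert is_async_callable(AsyncCallable())  # an async __call__ is detected...
assert not inspect.iscoroutinefunction(AsyncCallable())  # ...which a plain inspect check misses
assert not is_async_callable(len)
```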

pydantic_ai/agent.py
@@ -28,6 +28,7 @@ from . import (
  result,
  usage as _usage,
  )
+ from ._agent_graph import HistoryProcessor
  from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
  from .result import FinalResult, OutputDataT, StreamedRunResult
  from .settings import ModelSettings, merge_model_settings
@@ -179,6 +180,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  defer_model_check: bool = False,
  end_strategy: EndStrategy = 'early',
  instrument: InstrumentationSettings | bool | None = None,
+ history_processors: Sequence[HistoryProcessor] | None = None,
  ) -> None: ...

  @overload
@@ -208,6 +210,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  defer_model_check: bool = False,
  end_strategy: EndStrategy = 'early',
  instrument: InstrumentationSettings | bool | None = None,
+ history_processors: Sequence[HistoryProcessor] | None = None,
  ) -> None: ...

  def __init__(
@@ -232,6 +235,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  defer_model_check: bool = False,
  end_strategy: EndStrategy = 'early',
  instrument: InstrumentationSettings | bool | None = None,
+ history_processors: Sequence[HistoryProcessor] | None = None,
  **_deprecated_kwargs: Any,
  ):
  """Create an agent.
@@ -275,6 +279,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all]
  will be used, which defaults to False.
  See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
+ history_processors: Optional list of callables to process the message history before sending it to the model.
+ Each processor takes a list of messages and returns a modified list of messages.
+ Processors can be sync or async and are applied in sequence.
  """
  if model is None or defer_model_check:
  self.model = model
@@ -343,6 +350,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  self._max_result_retries = output_retries if output_retries is not None else retries
  self._mcp_servers = mcp_servers
  self._prepare_tools = prepare_tools
+ self.history_processors = history_processors or []
  for tool in tools:
  if isinstance(tool, Tool):
  self._register_tool(tool)
@@ -669,10 +677,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  if self._instructions is None and not self._instructions_functions:
  return None

- instructions = self._instructions or ''
+ instructions = [self._instructions] if self._instructions else []
  for instructions_runner in self._instructions_functions:
- instructions += '\n' + await instructions_runner.run(run_context)
- return instructions.strip()
+ instructions.append(await instructions_runner.run(run_context))
+ concatenated_instructions = '\n'.join(instruction for instruction in instructions if instruction)
+ return concatenated_instructions.strip() if concatenated_instructions else None

  # Copy the function tools so that retry state is agent-run-specific
  # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
@@ -689,6 +698,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  end_strategy=self.end_strategy,
  output_schema=output_schema,
  output_validators=output_validators,
+ history_processors=self.history_processors,
  function_tools=run_function_tools,
  mcp_servers=self._mcp_servers,
  default_retries=self._default_retries,
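
Putting the new parameter together with the docstring above, a minimal usage sketch (the trimming logic is illustrative, not part of the library):

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage


def keep_recent(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Illustrative processor: only send the ten most recent messages to the model."""
    return messages[-10:]


agent = Agent('openai:gpt-4o', history_processors=[keep_recent])


async def main():
    result = await agent.run('Summarise our conversation so far.')
    print(result.output)
```

Each processor receives the current message history and returns the list that is actually sent to the model; several processors can be chained by passing more than one.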

pydantic_ai/mcp.py
@@ -5,25 +5,28 @@ import functools
  import json
  from abc import ABC, abstractmethod
  from collections.abc import AsyncIterator, Sequence
- from contextlib import AsyncExitStack, asynccontextmanager
+ from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
  from dataclasses import dataclass
  from pathlib import Path
  from types import TracebackType
- from typing import Any
+ from typing import Any, Callable

  import anyio
  import httpx
  from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
+ from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
  from mcp.shared.message import SessionMessage
  from mcp.types import (
+ AudioContent,
  BlobResourceContents,
+ Content,
  EmbeddedResource,
  ImageContent,
  LoggingLevel,
  TextContent,
  TextResourceContents,
  )
- from typing_extensions import Self, assert_never
+ from typing_extensions import Self, assert_never, deprecated

  from pydantic_ai.exceptions import ModelRetry
  from pydantic_ai.messages import BinaryContent
@@ -39,7 +42,7 @@ except ImportError as _import_error:
  'you can use the `mcp` optional group — `pip install "pydantic-ai-slim[mcp]"`'
  ) from _import_error

- __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP'
+ __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'


  class MCPServer(ABC):
@@ -160,9 +163,7 @@ class MCPServer(ABC):
  await self._exit_stack.aclose()
  self.is_running = False

- def _map_tool_result_part(
- self, part: TextContent | ImageContent | EmbeddedResource
- ) -> str | BinaryContent | dict[str, Any] | list[Any]:
+ def _map_tool_result_part(self, part: Content) -> str | BinaryContent | dict[str, Any] | list[Any]:
  # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values

  if isinstance(part, TextContent):
@@ -175,6 +176,10 @@ class MCPServer(ABC):
  return text
  elif isinstance(part, ImageContent):
  return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
+ elif isinstance(part, AudioContent):
+ # NOTE: The FastMCP server doesn't support audio content.
+ # See <https://github.com/modelcontextprotocol/python-sdk/issues/952> for more details.
+ return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType) # pragma: no cover
  elif isinstance(part, EmbeddedResource):
  resource = part.resource
  if isinstance(resource, TextResourceContents):
@@ -287,44 +292,12 @@ class MCPServerStdio(MCPServer):


  @dataclass
- class MCPServerHTTP(MCPServer):
- """An MCP server that connects over streamable HTTP connections.
-
- This class implements the SSE transport from the MCP specification.
- See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse> for more information.
-
- The name "HTTP" is used since this implemented will be adapted in future to use the new
- [Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development.
-
- !!! note
- Using this class as an async context manager will create a new pool of HTTP connections to connect
- to a server which should already be running.
-
- Example:
- ```python {py="3.10"}
- from pydantic_ai import Agent
- from pydantic_ai.mcp import MCPServerHTTP
-
- server = MCPServerHTTP('http://localhost:3001/sse') # (1)!
- agent = Agent('openai:gpt-4o', mcp_servers=[server])
-
- async def main():
- async with agent.run_mcp_servers(): # (2)!
- ...
- ```
-
- 1. E.g. you might be connecting to a server run with [`mcp-run-python`](../mcp/run-python.md).
- 2. This will connect to a server running on `localhost:3001`.
- """
-
+ class _MCPServerHTTP(MCPServer):
  url: str
- """The URL of the SSE endpoint on the MCP server.
-
- For example for a server running locally, this might be `http://localhost:3001/sse`.
- """
+ """The URL of the endpoint on the MCP server."""

  headers: dict[str, Any] | None = None
- """Optional HTTP headers to be sent with each request to the SSE endpoint.
+ """Optional HTTP headers to be sent with each request to the endpoint.

  These headers will be passed directly to the underlying `httpx.AsyncClient`.
  Useful for authentication, custom headers, or other HTTP-specific configurations.
@@ -336,22 +309,22 @@ class MCPServerHTTP(MCPServer):
  """

  http_client: httpx.AsyncClient | None = None
- """An `httpx.AsyncClient` to use with the SSE endpoint.
+ """An `httpx.AsyncClient` to use with the endpoint.

  This client may be configured to use customized connection parameters like self-signed certificates.

  !!! note
  You can either pass `headers` or `http_client`, but not both.

- If you want to use both, you can pass the headers to the `http_client` instead:
+ If you want to use both, you can pass the headers to the `http_client` instead.

- ```python {py="3.10"}
+ ```python {py="3.10" test="skip"}
  import httpx

- from pydantic_ai.mcp import MCPServerHTTP
+ from pydantic_ai.mcp import MCPServerSSE

  http_client = httpx.AsyncClient(headers={'Authorization': 'Bearer ...'})
- server = MCPServerHTTP('http://localhost:3001/sse', http_client=http_client)
+ server = MCPServerSSE('http://localhost:3001/sse', http_client=http_client)
  ```
  """

@@ -369,10 +342,11 @@ class MCPServerHTTP(MCPServer):
  If no new messages are received within this time, the connection will be considered stale
  and may be closed. Defaults to 5 minutes (300 seconds).
  """
+
  log_level: LoggingLevel | None = None
  """The log level to set when connecting to the server, if any.

- See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
+ See <https://modelcontextprotocol.io/introduction#logging> for more details.

  If `None`, no log level will be set.
  """
@@ -385,6 +359,27 @@ class MCPServerHTTP(MCPServer):
  For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
  """

+ @property
+ @abstractmethod
+ def _transport_client(
+ self,
+ ) -> Callable[
+ ...,
+ AbstractAsyncContextManager[
+ tuple[
+ MemoryObjectReceiveStream[SessionMessage | Exception],
+ MemoryObjectSendStream[SessionMessage],
+ GetSessionIdCallback,
+ ],
+ ]
+ | AbstractAsyncContextManager[
+ tuple[
+ MemoryObjectReceiveStream[SessionMessage | Exception],
+ MemoryObjectSendStream[SessionMessage],
+ ]
+ ],
+ ]: ...
+
  @asynccontextmanager
  async def client_streams(
  self,
@@ -394,8 +389,8 @@ class MCPServerHTTP(MCPServer):
  if self.http_client and self.headers:
  raise ValueError('`http_client` is mutually exclusive with `headers`.')

- sse_client_partial = functools.partial(
- sse_client,
+ transport_client_partial = functools.partial(
+ self._transport_client,
  url=self.url,
  timeout=self.timeout,
  sse_read_timeout=self.sse_read_timeout,
@@ -411,17 +406,114 @@ class MCPServerHTTP(MCPServer):
  assert self.http_client is not None
  return self.http_client

- async with sse_client_partial(httpx_client_factory=httpx_client_factory) as (read_stream, write_stream):
+ async with transport_client_partial(httpx_client_factory=httpx_client_factory) as (
+ read_stream,
+ write_stream,
+ *_,
+ ):
  yield read_stream, write_stream
  else:
- async with sse_client_partial(headers=self.headers) as (read_stream, write_stream):
+ async with transport_client_partial(headers=self.headers) as (read_stream, write_stream, *_):
  yield read_stream, write_stream

  def _get_log_level(self) -> LoggingLevel | None:
  return self.log_level

  def __repr__(self) -> str: # pragma: no cover
- return f'MCPServerHTTP(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
+ return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'

  def _get_client_initialize_timeout(self) -> float: # pragma: no cover
  return self.timeout
+
+
+ @dataclass
+ class MCPServerSSE(_MCPServerHTTP):
+ """An MCP server that connects over streamable HTTP connections.
+
+ This class implements the SSE transport from the MCP specification.
+ See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse> for more information.
+
+ !!! note
+ Using this class as an async context manager will create a new pool of HTTP connections to connect
+ to a server which should already be running.
+
+ Example:
+ ```python {py="3.10"}
+ from pydantic_ai import Agent
+ from pydantic_ai.mcp import MCPServerSSE
+
+ server = MCPServerSSE('http://localhost:3001/sse') # (1)!
+ agent = Agent('openai:gpt-4o', mcp_servers=[server])
+
+ async def main():
+ async with agent.run_mcp_servers(): # (2)!
+ ...
+ ```
+
+ 1. E.g. you might be connecting to a server run with [`mcp-run-python`](../mcp/run-python.md).
+ 2. This will connect to a server running on `localhost:3001`.
+ """
+
+ @property
+ def _transport_client(self):
+ return sse_client # pragma: no cover
+
+
+ @deprecated('The `MCPServerHTTP` class is deprecated, use `MCPServerSSE` instead.')
+ @dataclass
+ class MCPServerHTTP(MCPServerSSE):
+ """An MCP server that connects over HTTP using the old SSE transport.
+
+ This class implements the SSE transport from the MCP specification.
+ See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse> for more information.
+
+ !!! note
+ Using this class as an async context manager will create a new pool of HTTP connections to connect
+ to a server which should already be running.
+
+ Example:
+ ```python {py="3.10" test="skip"}
+ from pydantic_ai import Agent
+ from pydantic_ai.mcp import MCPServerHTTP
+
+ server = MCPServerHTTP('http://localhost:3001/sse') # (1)!
+ agent = Agent('openai:gpt-4o', mcp_servers=[server])
+
+ async def main():
+ async with agent.run_mcp_servers(): # (2)!
+ ...
+ ```
+
+ 1. E.g. you might be connecting to a server run with [`mcp-run-python`](../mcp/run-python.md).
+ 2. This will connect to a server running on `localhost:3001`.
+ """
+
+
+ @dataclass
+ class MCPServerStreamableHTTP(_MCPServerHTTP):
+ """An MCP server that connects over HTTP using the Streamable HTTP transport.
+
+ This class implements the Streamable HTTP transport from the MCP specification.
+ See <https://modelcontextprotocol.io/introduction#streamable-http> for more information.
+
+ !!! note
+ Using this class as an async context manager will create a new pool of HTTP connections to connect
+ to a server which should already be running.
+
+ Example:
+ ```python {py="3.10"}
+ from pydantic_ai import Agent
+ from pydantic_ai.mcp import MCPServerStreamableHTTP
+
+ server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)!
+ agent = Agent('openai:gpt-4o', mcp_servers=[server])
+
+ async def main():
+ async with agent.run_mcp_servers(): # (2)!
+ ...
+ ```
+ """
+
+ @property
+ def _transport_client(self):
+ return streamablehttp_client # pragma: no cover

pydantic_ai/messages.py
@@ -1,7 +1,6 @@
  from __future__ import annotations as _annotations

  import base64
- import uuid
  from abc import ABC, abstractmethod
  from collections.abc import Sequence
  from dataclasses import dataclass, field, replace
@@ -888,13 +887,13 @@ class FunctionToolCallEvent:

  part: ToolCallPart
  """The (function) tool call to make."""
- call_id: str = field(init=False)
- """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
  event_kind: Literal['function_tool_call'] = 'function_tool_call'
  """Event type identifier, used as a discriminator."""

- def __post_init__(self):
- self.call_id = self.part.tool_call_id or str(uuid.uuid4())
+ @property
+ def call_id(self) -> str:
+ """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
+ return self.part.tool_call_id

  __repr__ = _utils.dataclasses_no_defaults_repr
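
With this change, `call_id` is derived from the part instead of being stored at construction time, and the `uuid` fallback is gone. A small sketch, assuming the usual `ToolCallPart` constructor:

```python
from pydantic_ai.messages import FunctionToolCallEvent, ToolCallPart

part = ToolCallPart(tool_name='get_weather', args={'city': 'London'}, tool_call_id='call_123')
event = FunctionToolCallEvent(part=part)

# call_id now simply mirrors the part's tool_call_id
assert event.call_id == 'call_123'
```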

pydantic_ai/models/__init__.py
@@ -555,9 +555,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model:

  return OpenAIModel(model_name, provider=provider)
  elif provider in ('google-gla', 'google-vertex'):
- from .gemini import GeminiModel
+ from .google import GoogleModel

- return GeminiModel(model_name, provider=provider)
+ return GoogleModel(model_name, provider=provider)
  elif provider == 'groq':
  from .groq import GroqModel
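
In practice this means `'google-gla:...'` and `'google-vertex:...'` model strings now resolve to the `google-genai`-based `GoogleModel` rather than `GeminiModel`. A hedged sketch (assumes the `google` optional group is installed and a Gemini API key is configured in the environment; the model name is only an example):

```python
from pydantic_ai.models import infer_model
from pydantic_ai.models.google import GoogleModel

model = infer_model('google-gla:gemini-2.0-flash')
assert isinstance(model, GoogleModel)
```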

pydantic_ai/models/anthropic.py
@@ -220,7 +220,7 @@ class AnthropicModel(Model):
  extra_headers = model_settings.get('extra_headers', {})
  extra_headers.setdefault('User-Agent', get_user_agent())
  return await self.client.beta.messages.create(
- max_tokens=model_settings.get('max_tokens', 1024),
+ max_tokens=model_settings.get('max_tokens', 4096),
  system=system_prompt or NOT_GIVEN,
  messages=anthropic_messages,
  model=self._model_name,
@@ -276,7 +276,7 @@ class AnthropicModel(Model):
  tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
  return tools

- async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:
+ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]: # noqa: C901
  """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
  system_prompt_parts: list[str] = []
  anthropic_messages: list[BetaMessageParam] = []
@@ -315,7 +315,8 @@ class AnthropicModel(Model):
  assistant_content_params: list[BetaTextBlockParam | BetaToolUseBlockParam] = []
  for response_part in m.parts:
  if isinstance(response_part, TextPart):
- assistant_content_params.append(BetaTextBlockParam(text=response_part.content, type='text'))
+ if response_part.content: # Only add non-empty text
+ assistant_content_params.append(BetaTextBlockParam(text=response_part.content, type='text'))
  else:
  tool_use_block_param = BetaToolUseBlockParam(
  id=_guard_tool_call_id(t=response_part),
@@ -324,7 +325,8 @@ class AnthropicModel(Model):
  input=response_part.args_as_dict(),
  )
  assistant_content_params.append(tool_use_block_param)
- anthropic_messages.append(BetaMessageParam(role='assistant', content=assistant_content_params))
+ if len(assistant_content_params) > 0:
+ anthropic_messages.append(BetaMessageParam(role='assistant', content=assistant_content_params))
  else:
  assert_never(m)
  system_prompt = '\n\n'.join(system_prompt_parts)
@@ -337,11 +339,13 @@ class AnthropicModel(Model):
  part: UserPromptPart,
  ) -> AsyncGenerator[BetaContentBlockParam]:
  if isinstance(part.content, str):
- yield BetaTextBlockParam(text=part.content, type='text')
+ if part.content: # Only yield non-empty text
+ yield BetaTextBlockParam(text=part.content, type='text')
  else:
  for item in part.content:
  if isinstance(item, str):
- yield BetaTextBlockParam(text=item, type='text')
+ if item: # Only yield non-empty text
+ yield BetaTextBlockParam(text=item, type='text')
  elif isinstance(item, BinaryContent):
  if item.is_image:
  yield BetaImageBlockParam(

pydantic_ai/models/gemini.py
@@ -723,9 +723,7 @@ class _GeminiFunction(TypedDict):

  def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
  json_schema = tool.parameters_json_schema
- f = _GeminiFunction(name=tool.name, description=tool.description)
- if json_schema.get('properties'):
- f['parameters'] = json_schema
+ f = _GeminiFunction(name=tool.name, description=tool.description, parameters=json_schema)
  return f

pydantic_ai/models/google.py
@@ -469,9 +469,11 @@ def _process_response_from_parts(

  def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:
  json_schema = tool.parameters_json_schema
- f = FunctionDeclarationDict(name=tool.name, description=tool.description)
- if json_schema.get('properties'): # pragma: no branch
- f['parameters'] = json_schema # type: ignore
+ f = FunctionDeclarationDict(
+ name=tool.name,
+ description=tool.description,
+ parameters=json_schema, # type: ignore
+ )
  return f

pydantic_ai/models/openai.py
@@ -613,7 +613,13 @@ class OpenAIResponsesModel(Model):
  for item in response.output:
  if item.type == 'function_call':
  items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
- return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
+ return ModelResponse(
+ items,
+ usage=_map_usage(response),
+ model_name=response.model,
+ vendor_id=response.id,
+ timestamp=timestamp,
+ )

  async def _process_streamed_response(
  self, response: AsyncStream[responses.ResponseStreamEvent]

pydantic_ai/providers/__init__.py
@@ -48,68 +48,74 @@ class Provider(ABC, Generic[InterfaceClient]):
  return None # pragma: no cover


- def infer_provider(provider: str) -> Provider[Any]: # noqa: C901
- """Infer the provider from the provider name."""
+ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
+ """Infers the provider class from the provider name."""
  if provider == 'openai':
  from .openai import OpenAIProvider

- return OpenAIProvider()
+ return OpenAIProvider
  elif provider == 'deepseek':
  from .deepseek import DeepSeekProvider

- return DeepSeekProvider()
+ return DeepSeekProvider
  elif provider == 'openrouter':
  from .openrouter import OpenRouterProvider

- return OpenRouterProvider()
+ return OpenRouterProvider
  elif provider == 'azure':
  from .azure import AzureProvider

- return AzureProvider()
+ return AzureProvider
  elif provider == 'google-vertex':
  from .google_vertex import GoogleVertexProvider

- return GoogleVertexProvider()
+ return GoogleVertexProvider
  elif provider == 'google-gla':
  from .google_gla import GoogleGLAProvider

- return GoogleGLAProvider()
+ return GoogleGLAProvider
  # NOTE: We don't test because there are many ways the `boto3.client` can retrieve the credentials.
  elif provider == 'bedrock':
  from .bedrock import BedrockProvider

- return BedrockProvider()
+ return BedrockProvider
  elif provider == 'groq':
  from .groq import GroqProvider

- return GroqProvider()
+ return GroqProvider
  elif provider == 'anthropic':
  from .anthropic import AnthropicProvider

- return AnthropicProvider()
+ return AnthropicProvider
  elif provider == 'mistral':
  from .mistral import MistralProvider

- return MistralProvider()
+ return MistralProvider
  elif provider == 'cohere':
  from .cohere import CohereProvider

- return CohereProvider()
+ return CohereProvider
  elif provider == 'grok':
  from .grok import GrokProvider

- return GrokProvider()
+ return GrokProvider
  elif provider == 'fireworks':
  from .fireworks import FireworksProvider

- return FireworksProvider()
+ return FireworksProvider
  elif provider == 'together':
  from .together import TogetherProvider

- return TogetherProvider()
+ return TogetherProvider
  elif provider == 'heroku':
  from .heroku import HerokuProvider

- return HerokuProvider()
+ return HerokuProvider
  else: # pragma: no cover
  raise ValueError(f'Unknown provider: {provider}')
+
+
+ def infer_provider(provider: str) -> Provider[Any]:
+ """Infer the provider from the provider name."""
+ provider_class = infer_provider_class(provider)
+ return provider_class()
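
The refactor keeps `infer_provider` behaviour unchanged while exposing the class lookup separately, which is useful when a provider should be constructed with explicit arguments rather than from environment variables. A sketch, with the API key value as a placeholder:

```python
from pydantic_ai.providers import infer_provider, infer_provider_class

provider = infer_provider('openai')              # instantiated with defaults (reads OPENAI_API_KEY)

provider_cls = infer_provider_class('openai')    # just the class, no instantiation
explicit = provider_cls(api_key='your-api-key')  # OpenAIProvider accepts an explicit api_key
```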

pydantic_ai/providers/google.py
@@ -84,7 +84,7 @@ class GoogleProvider(Provider[genai.Client]):
  """
  if client is None:
  # NOTE: We are keeping GEMINI_API_KEY for backwards compatibility.
- api_key = api_key or os.environ.get('GOOGLE_API_KEY')
+ api_key = api_key or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY')

  if vertexai is None: # pragma: lax no cover
  vertexai = bool(location or project or credentials)

pydantic_ai/tools.py
@@ -1,6 +1,5 @@
  from __future__ import annotations as _annotations

- import asyncio
  import dataclasses
  import json
  from collections.abc import Awaitable, Sequence
@@ -337,7 +336,7 @@ class Tool(Generic[AgentDepsT]):
  validator=SchemaValidator(schema=core_schema.any_schema()),
  json_schema=json_schema,
  takes_ctx=False,
- is_async=asyncio.iscoroutinefunction(function),
+ is_async=_utils.is_async_callable(function),
  )

  return cls(

pyproject.toml
@@ -75,7 +75,7 @@ tavily = ["tavily-python>=0.5.0"]
  # CLI
  cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
  # MCP
- mcp = ["mcp>=1.9.2; python_version >= '3.10'"]
+ mcp = ["mcp>=1.9.4; python_version >= '3.10'"]
  # Evals
  evals = ["pydantic-evals=={{ version }}"]
  # A2A