pydantic-ai-slim 0.2.17__tar.gz → 0.2.19__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/PKG-INFO +5 -5
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_agent_graph.py +44 -14
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_function_schema.py +2 -3
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_output.py +1 -1
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_system_prompt.py +1 -1
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_utils.py +28 -3
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/agent.py +13 -3
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/mcp.py +145 -53
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/messages.py +4 -5
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/__init__.py +2 -2
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/anthropic.py +10 -6
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/gemini.py +1 -3
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/google.py +5 -3
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/openai.py +7 -1
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/__init__.py +23 -17
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/google.py +1 -1
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/tools.py +1 -2
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pyproject.toml +1 -1
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/.gitignore +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/LICENSE +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/README.md +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/__main__.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_a2a.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_cli.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/direct.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/ext/__init__.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/ext/langchain.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/format_prompt.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/bedrock.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/cohere.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/groq.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/mistral.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/__init__.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/_json_schema.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/amazon.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/anthropic.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/cohere.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/deepseek.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/google.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/grok.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/meta.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/mistral.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/openai.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/profiles/qwen.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/anthropic.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/azure.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/bedrock.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/cohere.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/deepseek.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/fireworks.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/google_gla.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/google_vertex.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/grok.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/groq.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/heroku.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/mistral.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/openai.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/openrouter.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/providers/together.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.2.17 → pydantic_ai_slim-0.2.19}/pydantic_ai/usage.py +0 -0
--- pydantic_ai_slim-0.2.17/PKG-INFO
+++ pydantic_ai_slim-0.2.19/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.2.17
+Version: 0.2.19
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>
 License-Expression: MIT
@@ -30,11 +30,11 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.2.17
+Requires-Dist: pydantic-graph==0.2.19
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: a2a
-Requires-Dist: fasta2a==0.2.17; extra == 'a2a'
+Requires-Dist: fasta2a==0.2.19; extra == 'a2a'
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.52.0; extra == 'anthropic'
 Provides-Extra: bedrock
@@ -48,7 +48,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.2.17; extra == 'evals'
+Requires-Dist: pydantic-evals==0.2.19; extra == 'evals'
 Provides-Extra: google
 Requires-Dist: google-genai>=1.15.0; extra == 'google'
 Provides-Extra: groq
@@ -56,7 +56,7 @@ Requires-Dist: groq>=0.15.0; extra == 'groq'
 Provides-Extra: logfire
 Requires-Dist: logfire>=3.11.0; extra == 'logfire'
 Provides-Extra: mcp
-Requires-Dist: mcp>=1.9.…; (python_version >= '3.10') and extra == 'mcp'
+Requires-Dist: mcp>=1.9.4; (python_version >= '3.10') and extra == 'mcp'
 Provides-Extra: mistral
 Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
--- pydantic_ai_slim-0.2.17/pydantic_ai/_agent_graph.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/_agent_graph.py
@@ -12,18 +12,11 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
 from opentelemetry.trace import Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never
 
+from pydantic_ai._utils import is_async_callable, run_in_executor
 from pydantic_graph import BaseNode, Graph, GraphRunContext
 from pydantic_graph.nodes import End, NodeRunEndT
 
-from . import (
-    _output,
-    _system_prompt,
-    exceptions,
-    messages as _messages,
-    models,
-    result,
-    usage as _usage,
-)
+from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
 from .result import OutputDataT
 from .settings import ModelSettings, merge_model_settings
 from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc
@@ -39,6 +32,7 @@ __all__ = (
     'CallToolsNode',
     'build_run_context',
     'capture_run_messages',
+    'HistoryProcessor',
 )
 
 
@@ -54,6 +48,11 @@ EndStrategy = Literal['early', 'exhaustive']
 DepsT = TypeVar('DepsT')
 OutputT = TypeVar('OutputT')
 
+_HistoryProcessorSync = Callable[[list[_messages.ModelMessage]], list[_messages.ModelMessage]]
+_HistoryProcessorAsync = Callable[[list[_messages.ModelMessage]], Awaitable[list[_messages.ModelMessage]]]
+HistoryProcessor = Union[_HistoryProcessorSync, _HistoryProcessorAsync]
+"""A function that processes a list of model messages and returns a list of model messages."""
+
 
 @dataclasses.dataclass
 class GraphAgentState:
@@ -93,6 +92,8 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     output_schema: _output.OutputSchema[OutputDataT] | None
     output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
 
+    history_processors: Sequence[HistoryProcessor]
+
     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
     mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
     default_retries: int
@@ -183,6 +184,16 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
         if user_prompt is not None:
             parts.append(_messages.UserPromptPart(user_prompt))
+        elif (
+            len(parts) == 0
+            and message_history
+            and (last_message := message_history[-1])
+            and isinstance(last_message, _messages.ModelRequest)
+        ):
+            # Drop last message that came from history and reuse its parts
+            messages.pop()
+            parts.extend(last_message.parts)
+
         return messages, _messages.ModelRequest(parts, instructions=instructions)
 
     async def _reevaluate_dynamic_prompts(
@@ -317,8 +328,9 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
         model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
+        message_history = await _process_message_history(ctx.state.message_history, ctx.deps.history_processors)
         async with ctx.deps.model.request_stream(
-            ctx.state.message_history, model_settings, model_request_parameters
+            message_history, model_settings, model_request_parameters
         ) as streamed_response:
             self._did_stream = True
             ctx.state.usage.requests += 1
@@ -340,9 +352,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
         model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
-        model_response = await ctx.deps.model.request(
-            ctx.state.message_history, model_settings, model_request_parameters
-        )
+        message_history = await _process_message_history(ctx.state.message_history, ctx.deps.history_processors)
+        model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
        ctx.state.usage.incr(_usage.Usage())
 
         return self._finish_handling(ctx, model_response)
@@ -637,6 +648,7 @@ async def process_function_tools(  # noqa C901
             # if tool_name is in output_schema, it means we found a output tool but an error occurred in
             # validation, we don't add another part here
             if output_tool_name is not None:
+                yield _messages.FunctionToolCallEvent(call)
                 if found_used_output_tool:
                     content = 'Output tool not used - a final result was already processed.'
                 else:
@@ -647,9 +659,14 @@ async def process_function_tools(  # noqa C901
                     content=content,
                     tool_call_id=call.tool_call_id,
                 )
+                yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id)
                 output_parts.append(part)
             else:
-                output_parts.append(_unknown_tool(call.tool_name, call.tool_call_id, ctx))
+                yield _messages.FunctionToolCallEvent(call)
+
+                part = _unknown_tool(call.tool_name, call.tool_call_id, ctx)
+                yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id)
+                output_parts.append(part)
 
     if not calls_to_run:
         return
@@ -855,3 +872,16 @@ def build_agent_graph(
         auto_instrument=False,
     )
     return graph
+
+
+async def _process_message_history(
+    messages: list[_messages.ModelMessage],
+    processors: Sequence[HistoryProcessor],
+) -> list[_messages.ModelMessage]:
+    """Process message history through a sequence of processors."""
+    for processor in processors:
+        if is_async_callable(processor):
+            messages = await processor(messages)
+        else:
+            messages = await run_in_executor(processor, messages)
+    return messages
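The hunks above add a history-processing hook: before each model request, `_process_message_history` applies every processor in order, awaiting async callables and pushing sync ones through `run_in_executor`. A minimal sketch of two conforming processors (the function names are illustrative, not part of the package):

```python
from pydantic_ai.messages import ModelMessage


def keep_recent(messages: list[ModelMessage]) -> list[ModelMessage]:
    # Sync processors are run in a thread executor.
    return messages[-10:]


async def audit(messages: list[ModelMessage]) -> list[ModelMessage]:
    # Async processors are awaited directly.
    print(f'{len(messages)} messages about to be sent to the model')
    return messages
```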
--- pydantic_ai_slim-0.2.17/pydantic_ai/_function_schema.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/_function_schema.py
@@ -5,7 +5,6 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
 
 from __future__ import annotations as _annotations
 
-import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
 from inspect import Parameter, signature
@@ -23,7 +22,7 @@ from typing_extensions import get_origin
 from pydantic_ai.tools import RunContext
 
 from ._griffe import doc_descriptions
-from ._utils import check_object_json_schema, is_model_like, run_in_executor
+from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor
 
 if TYPE_CHECKING:
     from .tools import DocstringFormat, ObjectJsonSchema
@@ -214,7 +213,7 @@ def function_schema(  # noqa: C901
         positional_fields=positional_fields,
         var_positional_field=var_positional_field,
         takes_ctx=takes_ctx,
-        is_async=inspect.iscoroutinefunction(function),
+        is_async=is_async_callable(function),
         function=function,
     )
--- pydantic_ai_slim-0.2.17/pydantic_ai/_output.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/_output.py
@@ -60,7 +60,7 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
 
     def __post_init__(self):
         self._takes_ctx = len(inspect.signature(self.function).parameters) > 1
-        self._is_async = inspect.iscoroutinefunction(self.function)
+        self._is_async = _utils.is_async_callable(self.function)
 
     async def validate(
         self,
--- pydantic_ai_slim-0.2.17/pydantic_ai/_system_prompt.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/_system_prompt.py
@@ -18,7 +18,7 @@ class SystemPromptRunner(Generic[AgentDepsT]):
 
     def __post_init__(self):
         self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
-        self._is_async = inspect.iscoroutinefunction(self.function)
+        self._is_async = _utils.is_async_callable(self.function)
 
     async def run(self, run_context: RunContext[AgentDepsT]) -> str:
         if self._takes_ctx:
--- pydantic_ai_slim-0.2.17/pydantic_ai/_utils.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/_utils.py
@@ -1,20 +1,22 @@
 from __future__ import annotations as _annotations
 
 import asyncio
+import functools
+import inspect
 import time
 import uuid
-from collections.abc import AsyncIterable, AsyncIterator, Iterator
+from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
 from contextlib import asynccontextmanager, suppress
 from dataclasses import dataclass, fields, is_dataclass
 from datetime import datetime, timezone
 from functools import partial
 from types import GenericAlias
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload
 
 from anyio.to_thread import run_sync
 from pydantic import BaseModel, TypeAdapter
 from pydantic.json_schema import JsonSchemaValue
-from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
+from typing_extensions import ParamSpec, TypeAlias, TypeGuard, TypeIs, is_typeddict
 
 from pydantic_graph._utils import AbstractSpan
 
@@ -302,3 +304,26 @@ def dataclasses_no_defaults_repr(self: Any) -> str:
 
 def number_to_datetime(x: int | float) -> datetime:
     return TypeAdapter(datetime).validate_python(x)
+
+
+AwaitableCallable = Callable[..., Awaitable[T]]
+
+
+@overload
+def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
+
+
+@overload
+def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...
+
+
+def is_async_callable(obj: Any) -> Any:
+    """Correctly check if a callable is async.
+
+    This function was copied from Starlette:
+    https://github.com/encode/starlette/blob/78da9b9e218ab289117df7d62aee200ed4c59617/starlette/_utils.py#L36-L40
+    """
+    while isinstance(obj, functools.partial):
+        obj = obj.func
+
+    return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__))  # type: ignore
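The new `is_async_callable` helper (copied from Starlette) unwraps `functools.partial` and also detects callable instances with an async `__call__`, cases that a bare `inspect.iscoroutinefunction` check misses. A quick illustrative check (importing from the private `_utils` module purely for demonstration):

```python
import functools
import inspect

from pydantic_ai._utils import is_async_callable


async def fetch() -> str:
    return 'data'


class AsyncCallable:
    async def __call__(self) -> str:
        return 'data'


assert is_async_callable(fetch)
assert is_async_callable(functools.partial(fetch))  # partial-wrapped coroutine function
assert is_async_callable(AsyncCallable())  # instance with an async __call__
assert not inspect.iscoroutinefunction(AsyncCallable())  # what the old checks missed
```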
--- pydantic_ai_slim-0.2.17/pydantic_ai/agent.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/agent.py
@@ -28,6 +28,7 @@ from . import (
     result,
     usage as _usage,
 )
+from ._agent_graph import HistoryProcessor
 from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
 from .result import FinalResult, OutputDataT, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
@@ -179,6 +180,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
         instrument: InstrumentationSettings | bool | None = None,
+        history_processors: Sequence[HistoryProcessor] | None = None,
     ) -> None: ...
 
     @overload
@@ -208,6 +210,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
         instrument: InstrumentationSettings | bool | None = None,
+        history_processors: Sequence[HistoryProcessor] | None = None,
     ) -> None: ...
 
     def __init__(
@@ -232,6 +235,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
         instrument: InstrumentationSettings | bool | None = None,
+        history_processors: Sequence[HistoryProcessor] | None = None,
         **_deprecated_kwargs: Any,
     ):
         """Create an agent.
@@ -275,6 +279,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                 [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all]
                 will be used, which defaults to False.
                 See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
+            history_processors: Optional list of callables to process the message history before sending it to the model.
+                Each processor takes a list of messages and returns a modified list of messages.
+                Processors can be sync or async and are applied in sequence.
         """
         if model is None or defer_model_check:
             self.model = model
@@ -343,6 +350,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self._max_result_retries = output_retries if output_retries is not None else retries
         self._mcp_servers = mcp_servers
         self._prepare_tools = prepare_tools
+        self.history_processors = history_processors or []
         for tool in tools:
             if isinstance(tool, Tool):
                 self._register_tool(tool)
@@ -669,10 +677,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         if self._instructions is None and not self._instructions_functions:
             return None
 
-        instructions = self._instructions
+        instructions = [self._instructions] if self._instructions else []
         for instructions_runner in self._instructions_functions:
-            instructions += await instructions_runner.run(run_context)
-        return instructions
+            instructions.append(await instructions_runner.run(run_context))
+        concatenated_instructions = '\n'.join(instruction for instruction in instructions if instruction)
+        return concatenated_instructions.strip() if concatenated_instructions else None
 
         # Copy the function tools so that retry state is agent-run-specific
         # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
@@ -689,6 +698,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             end_strategy=self.end_strategy,
             output_schema=output_schema,
             output_validators=output_validators,
+            history_processors=self.history_processors,
             function_tools=run_function_tools,
             mcp_servers=self._mcp_servers,
             default_retries=self._default_retries,
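On the public API, `history_processors` can now be passed to `Agent(...)`; per the new docstring, processors may be sync or async and are applied in sequence before the history is sent to the model. A usage sketch (the trimming function is illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage


def keep_recent(messages: list[ModelMessage]) -> list[ModelMessage]:
    # Trim the history to the last ten messages before every model request.
    return messages[-10:]


agent = Agent('openai:gpt-4o', history_processors=[keep_recent])
```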
--- pydantic_ai_slim-0.2.17/pydantic_ai/mcp.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/mcp.py
@@ -5,25 +5,28 @@ import functools
 import json
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Sequence
-from contextlib import AsyncExitStack, asynccontextmanager
+from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
 from dataclasses import dataclass
 from pathlib import Path
 from types import TracebackType
-from typing import Any
+from typing import Any, Callable
 
 import anyio
 import httpx
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
+from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
 from mcp.shared.message import SessionMessage
 from mcp.types import (
+    AudioContent,
     BlobResourceContents,
+    Content,
     EmbeddedResource,
     ImageContent,
     LoggingLevel,
     TextContent,
     TextResourceContents,
 )
-from typing_extensions import Self, assert_never
+from typing_extensions import Self, assert_never, deprecated
 
 from pydantic_ai.exceptions import ModelRetry
 from pydantic_ai.messages import BinaryContent
@@ -39,7 +42,7 @@ except ImportError as _import_error:
         'you can use the `mcp` optional group — `pip install "pydantic-ai-slim[mcp]"`'
     ) from _import_error
 
-__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP'
+__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
 
 
 class MCPServer(ABC):
@@ -160,9 +163,7 @@ class MCPServer(ABC):
         await self._exit_stack.aclose()
         self.is_running = False
 
-    def _map_tool_result_part(
-        self, part: TextContent | ImageContent | EmbeddedResource
-    ) -> str | BinaryContent | dict[str, Any] | list[Any]:
+    def _map_tool_result_part(self, part: Content) -> str | BinaryContent | dict[str, Any] | list[Any]:
         # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values
 
         if isinstance(part, TextContent):
@@ -175,6 +176,10 @@ class MCPServer(ABC):
             return text
         elif isinstance(part, ImageContent):
             return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
+        elif isinstance(part, AudioContent):
+            # NOTE: The FastMCP server doesn't support audio content.
+            # See <https://github.com/modelcontextprotocol/python-sdk/issues/952> for more details.
+            return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)  # pragma: no cover
         elif isinstance(part, EmbeddedResource):
             resource = part.resource
             if isinstance(resource, TextResourceContents):
@@ -287,44 +292,12 @@ class MCPServerStdio(MCPServer):
 
 
 @dataclass
-class MCPServerHTTP(MCPServer):
-    """An MCP server that connects over streamable HTTP connections.
-
-    This class implements the SSE transport from the MCP specification.
-    See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse> for more information.
-
-    The name "HTTP" is used since this implemented will be adapted in future to use the new
-    [Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development.
-
-    !!! note
-        Using this class as an async context manager will create a new pool of HTTP connections to connect
-        to a server which should already be running.
-
-    Example:
-    ```python {py="3.10"}
-    from pydantic_ai import Agent
-    from pydantic_ai.mcp import MCPServerHTTP
-
-    server = MCPServerHTTP('http://localhost:3001/sse')  # (1)!
-    agent = Agent('openai:gpt-4o', mcp_servers=[server])
-
-    async def main():
-        async with agent.run_mcp_servers():  # (2)!
-            ...
-    ```
-
-    1. E.g. you might be connecting to a server run with [`mcp-run-python`](../mcp/run-python.md).
-    2. This will connect to a server running on `localhost:3001`.
-    """
-
+class _MCPServerHTTP(MCPServer):
     url: str
-    """The URL of the SSE endpoint on the MCP server.
-
-    For example for a server running locally, this might be `http://localhost:3001/sse`.
-    """
+    """The URL of the endpoint on the MCP server."""
 
     headers: dict[str, Any] | None = None
-    """Optional HTTP headers to be sent with each request to the SSE endpoint.
+    """Optional HTTP headers to be sent with each request to the endpoint.
 
     These headers will be passed directly to the underlying `httpx.AsyncClient`.
     Useful for authentication, custom headers, or other HTTP-specific configurations.
@@ -336,22 +309,22 @@ class MCPServerHTTP(MCPServer):
     """
 
     http_client: httpx.AsyncClient | None = None
-    """An `httpx.AsyncClient` to use with the SSE endpoint.
+    """An `httpx.AsyncClient` to use with the endpoint.
 
     This client may be configured to use customized connection parameters like self-signed certificates.
 
     !!! note
         You can either pass `headers` or `http_client`, but not both.
 
-        If you want to use both, you can pass the headers to the `http_client` instead
+        If you want to use both, you can pass the headers to the `http_client` instead.
 
-        ```python {py="3.10"}
+        ```python {py="3.10" test="skip"}
         import httpx
 
-        from pydantic_ai.mcp import MCPServerHTTP
+        from pydantic_ai.mcp import MCPServerSSE
 
         http_client = httpx.AsyncClient(headers={'Authorization': 'Bearer ...'})
-        server = MCPServerHTTP('http://localhost:3001/sse', http_client=http_client)
+        server = MCPServerSSE('http://localhost:3001/sse', http_client=http_client)
         ```
     """
 
@@ -369,10 +342,11 @@ class MCPServerHTTP(MCPServer):
     If no new messages are received within this time, the connection will be considered stale
     and may be closed. Defaults to 5 minutes (300 seconds).
     """
+
     log_level: LoggingLevel | None = None
     """The log level to set when connecting to the server, if any.
 
-    See <https://modelcontextprotocol.io/…> for more details.
+    See <https://modelcontextprotocol.io/introduction#logging> for more details.
 
     If `None`, no log level will be set.
     """
@@ -385,6 +359,27 @@ class MCPServerHTTP(MCPServer):
    For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
    """
 
+    @property
+    @abstractmethod
+    def _transport_client(
+        self,
+    ) -> Callable[
+        ...,
+        AbstractAsyncContextManager[
+            tuple[
+                MemoryObjectReceiveStream[SessionMessage | Exception],
+                MemoryObjectSendStream[SessionMessage],
+                GetSessionIdCallback,
+            ],
+        ]
+        | AbstractAsyncContextManager[
+            tuple[
+                MemoryObjectReceiveStream[SessionMessage | Exception],
+                MemoryObjectSendStream[SessionMessage],
+            ]
+        ],
+    ]: ...
+
     @asynccontextmanager
     async def client_streams(
         self,
@@ -394,8 +389,8 @@ class MCPServerHTTP(MCPServer):
         if self.http_client and self.headers:
             raise ValueError('`http_client` is mutually exclusive with `headers`.')
 
-        sse_client_partial = functools.partial(
-            sse_client,
+        transport_client_partial = functools.partial(
+            self._transport_client,
             url=self.url,
             timeout=self.timeout,
             sse_read_timeout=self.sse_read_timeout,
@@ -411,17 +406,114 @@ class MCPServerHTTP(MCPServer):
                 assert self.http_client is not None
                 return self.http_client
 
-            async with sse_client_partial(httpx_client_factory=httpx_client_factory) as (read_stream, write_stream):
+            async with transport_client_partial(httpx_client_factory=httpx_client_factory) as (
+                read_stream,
+                write_stream,
+                *_,
+            ):
                 yield read_stream, write_stream
         else:
-            async with sse_client_partial(headers=self.headers) as (read_stream, write_stream):
+            async with transport_client_partial(headers=self.headers) as (read_stream, write_stream, *_):
                 yield read_stream, write_stream
 
     def _get_log_level(self) -> LoggingLevel | None:
         return self.log_level
 
     def __repr__(self) -> str:  # pragma: no cover
-        return f'MCPServerHTTP(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
+        return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
 
     def _get_client_initialize_timeout(self) -> float:  # pragma: no cover
         return self.timeout
+
+
+@dataclass
+class MCPServerSSE(_MCPServerHTTP):
+    """An MCP server that connects over streamable HTTP connections.
+
+    This class implements the SSE transport from the MCP specification.
+    See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse> for more information.
+
+    !!! note
+        Using this class as an async context manager will create a new pool of HTTP connections to connect
+        to a server which should already be running.
+
+    Example:
+    ```python {py="3.10"}
+    from pydantic_ai import Agent
+    from pydantic_ai.mcp import MCPServerSSE
+
+    server = MCPServerSSE('http://localhost:3001/sse')  # (1)!
+    agent = Agent('openai:gpt-4o', mcp_servers=[server])
+
+    async def main():
+        async with agent.run_mcp_servers():  # (2)!
+            ...
+    ```
+
+    1. E.g. you might be connecting to a server run with [`mcp-run-python`](../mcp/run-python.md).
+    2. This will connect to a server running on `localhost:3001`.
+    """
+
+    @property
+    def _transport_client(self):
+        return sse_client  # pragma: no cover
+
+
+@deprecated('The `MCPServerHTTP` class is deprecated, use `MCPServerSSE` instead.')
+@dataclass
+class MCPServerHTTP(MCPServerSSE):
+    """An MCP server that connects over HTTP using the old SSE transport.
+
+    This class implements the SSE transport from the MCP specification.
+    See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse> for more information.
+
+    !!! note
+        Using this class as an async context manager will create a new pool of HTTP connections to connect
+        to a server which should already be running.
+
+    Example:
+    ```python {py="3.10" test="skip"}
+    from pydantic_ai import Agent
+    from pydantic_ai.mcp import MCPServerHTTP
+
+    server = MCPServerHTTP('http://localhost:3001/sse')  # (1)!
+    agent = Agent('openai:gpt-4o', mcp_servers=[server])
+
+    async def main():
+        async with agent.run_mcp_servers():  # (2)!
+            ...
+    ```
+
+    1. E.g. you might be connecting to a server run with [`mcp-run-python`](../mcp/run-python.md).
+    2. This will connect to a server running on `localhost:3001`.
+    """
+
+
+@dataclass
+class MCPServerStreamableHTTP(_MCPServerHTTP):
+    """An MCP server that connects over HTTP using the Streamable HTTP transport.
+
+    This class implements the Streamable HTTP transport from the MCP specification.
+    See <https://modelcontextprotocol.io/introduction#streamable-http> for more information.
+
+    !!! note
+        Using this class as an async context manager will create a new pool of HTTP connections to connect
+        to a server which should already be running.
+
+    Example:
+    ```python {py="3.10"}
+    from pydantic_ai import Agent
+    from pydantic_ai.mcp import MCPServerStreamableHTTP
+
+    server = MCPServerStreamableHTTP('http://localhost:8000/mcp')  # (1)!
+    agent = Agent('openai:gpt-4o', mcp_servers=[server])
+
+    async def main():
+        async with agent.run_mcp_servers():  # (2)!
+            ...
+    ```
+    """
+
+    @property
+    def _transport_client(self):
+        return streamablehttp_client  # pragma: no cover
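Net effect of the mcp.py changes: the transport-agnostic plumbing moves into a private `_MCPServerHTTP` base with a pluggable `_transport_client`, `MCPServerHTTP` becomes a deprecated alias of the new `MCPServerSSE`, and `MCPServerStreamableHTTP` adds the Streamable HTTP transport. A migration sketch (the URLs are illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerSSE, MCPServerStreamableHTTP

# Before (now deprecated): MCPServerHTTP('http://localhost:3001/sse')
sse_server = MCPServerSSE('http://localhost:3001/sse')

# New in this release: the Streamable HTTP transport.
streamable_server = MCPServerStreamableHTTP('http://localhost:8000/mcp')

agent = Agent('openai:gpt-4o', mcp_servers=[sse_server, streamable_server])


async def main():
    async with agent.run_mcp_servers():
        ...
```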
--- pydantic_ai_slim-0.2.17/pydantic_ai/messages.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/messages.py
@@ -1,7 +1,6 @@
 from __future__ import annotations as _annotations
 
 import base64
-import uuid
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from dataclasses import dataclass, field, replace
@@ -888,13 +887,13 @@ class FunctionToolCallEvent:
 
     part: ToolCallPart
     """The (function) tool call to make."""
-    call_id: str = field(init=False)
-    """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
     event_kind: Literal['function_tool_call'] = 'function_tool_call'
     """Event type identifier, used as a discriminator."""
 
-    def __post_init__(self):
-        self.call_id = self.part.tool_call_id or str(uuid.uuid4())
+    @property
+    def call_id(self) -> str:
+        """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
+        return self.part.tool_call_id
 
     __repr__ = _utils.dataclasses_no_defaults_repr
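`FunctionToolCallEvent.call_id` is now a read-only property mirroring the part's `tool_call_id`, rather than an `init=False` field populated in `__post_init__`, which is also why the `uuid` import could go. An illustrative check:

```python
from pydantic_ai.messages import FunctionToolCallEvent, ToolCallPart

part = ToolCallPart(tool_name='get_weather', args='{"city": "London"}', tool_call_id='call_123')
event = FunctionToolCallEvent(part)
assert event.call_id == 'call_123'  # always derived from the part now
```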
--- pydantic_ai_slim-0.2.17/pydantic_ai/models/__init__.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/models/__init__.py
@@ -555,9 +555,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model:
 
         return OpenAIModel(model_name, provider=provider)
     elif provider in ('google-gla', 'google-vertex'):
-        from .gemini import GeminiModel
+        from .google import GoogleModel
 
-        return GeminiModel(model_name, provider=provider)
+        return GoogleModel(model_name, provider=provider)
     elif provider == 'groq':
         from .groq import GroqModel
 
--- pydantic_ai_slim-0.2.17/pydantic_ai/models/anthropic.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/models/anthropic.py
@@ -220,7 +220,7 @@ class AnthropicModel(Model):
         extra_headers = model_settings.get('extra_headers', {})
         extra_headers.setdefault('User-Agent', get_user_agent())
         return await self.client.beta.messages.create(
-            max_tokens=model_settings.get('max_tokens', 1024),
+            max_tokens=model_settings.get('max_tokens', 4096),
             system=system_prompt or NOT_GIVEN,
             messages=anthropic_messages,
             model=self._model_name,
@@ -276,7 +276,7 @@ class AnthropicModel(Model):
             tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools
 
-    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:
+    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:  # noqa: C901
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
         system_prompt_parts: list[str] = []
         anthropic_messages: list[BetaMessageParam] = []
@@ -315,7 +315,8 @@ class AnthropicModel(Model):
             assistant_content_params: list[BetaTextBlockParam | BetaToolUseBlockParam] = []
             for response_part in m.parts:
                 if isinstance(response_part, TextPart):
-                    assistant_content_params.append(BetaTextBlockParam(text=response_part.content, type='text'))
+                    if response_part.content:  # Only add non-empty text
+                        assistant_content_params.append(BetaTextBlockParam(text=response_part.content, type='text'))
                 else:
                     tool_use_block_param = BetaToolUseBlockParam(
                         id=_guard_tool_call_id(t=response_part),
@@ -324,7 +325,8 @@ class AnthropicModel(Model):
                         input=response_part.args_as_dict(),
                     )
                     assistant_content_params.append(tool_use_block_param)
-            anthropic_messages.append(BetaMessageParam(role='assistant', content=assistant_content_params))
+            if len(assistant_content_params) > 0:
+                anthropic_messages.append(BetaMessageParam(role='assistant', content=assistant_content_params))
         else:
             assert_never(m)
     system_prompt = '\n\n'.join(system_prompt_parts)
@@ -337,11 +339,13 @@ class AnthropicModel(Model):
         part: UserPromptPart,
     ) -> AsyncGenerator[BetaContentBlockParam]:
         if isinstance(part.content, str):
-            yield BetaTextBlockParam(text=part.content, type='text')
+            if part.content:  # Only yield non-empty text
+                yield BetaTextBlockParam(text=part.content, type='text')
         else:
             for item in part.content:
                 if isinstance(item, str):
-                    yield BetaTextBlockParam(text=item, type='text')
+                    if item:  # Only yield non-empty text
+                        yield BetaTextBlockParam(text=item, type='text')
                 elif isinstance(item, BinaryContent):
                     if item.is_image:
                         yield BetaImageBlockParam(
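The Anthropic default `max_tokens` is raised to 4096, and empty text blocks are now filtered out before being sent. The default can still be overridden through model settings; a sketch, with the model name purely illustrative:

```python
from pydantic_ai import Agent

# Override the new 4096-token default for a specific agent.
agent = Agent('anthropic:claude-3-5-sonnet-latest', model_settings={'max_tokens': 1024})
```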
--- pydantic_ai_slim-0.2.17/pydantic_ai/models/gemini.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/models/gemini.py
@@ -723,9 +723,7 @@ class _GeminiFunction(TypedDict):
 
 def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
     json_schema = tool.parameters_json_schema
-    f = _GeminiFunction(name=tool.name, description=tool.description)
-    if json_schema.get('properties'):
-        f['parameters'] = json_schema
+    f = _GeminiFunction(name=tool.name, description=tool.description, parameters=json_schema)
     return f
 
 
--- pydantic_ai_slim-0.2.17/pydantic_ai/models/google.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/models/google.py
@@ -469,9 +469,11 @@ def _process_response_from_parts(
 
 def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:
     json_schema = tool.parameters_json_schema
-    f = FunctionDeclarationDict(
-        name=tool.name, description=tool.description
-    )
+    f = FunctionDeclarationDict(
+        name=tool.name,
+        description=tool.description,
+        parameters=json_schema,  # type: ignore
+    )
     return f
 
 
--- pydantic_ai_slim-0.2.17/pydantic_ai/models/openai.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/models/openai.py
@@ -613,7 +613,13 @@ class OpenAIResponsesModel(Model):
         for item in response.output:
             if item.type == 'function_call':
                 items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
-        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
+        return ModelResponse(
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            vendor_id=response.id,
+            timestamp=timestamp,
+        )
 
     async def _process_streamed_response(
         self, response: AsyncStream[responses.ResponseStreamEvent]
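`OpenAIResponsesModel` now records the provider response ID as `vendor_id` on the returned `ModelResponse`, so it can be read back from a run's message history. A sketch, assuming an OpenAI-backed agent:

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelResponse

agent = Agent('openai:gpt-4o')


async def main():
    result = await agent.run('Hello!')
    for message in result.all_messages():
        if isinstance(message, ModelResponse):
            print(message.vendor_id)  # the provider's response ID, when available
```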
--- pydantic_ai_slim-0.2.17/pydantic_ai/providers/__init__.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/providers/__init__.py
@@ -48,68 +48,74 @@ class Provider(ABC, Generic[InterfaceClient]):
         return None  # pragma: no cover
 
 
-def infer_provider(provider: str) -> Provider[Any]:
-    """Infer the provider from the provider name."""
+def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
+    """Infers the provider class from the provider name."""
     if provider == 'openai':
         from .openai import OpenAIProvider
 
-        return OpenAIProvider()
+        return OpenAIProvider
     elif provider == 'deepseek':
         from .deepseek import DeepSeekProvider
 
-        return DeepSeekProvider()
+        return DeepSeekProvider
     elif provider == 'openrouter':
         from .openrouter import OpenRouterProvider
 
-        return OpenRouterProvider()
+        return OpenRouterProvider
     elif provider == 'azure':
         from .azure import AzureProvider
 
-        return AzureProvider()
+        return AzureProvider
     elif provider == 'google-vertex':
         from .google_vertex import GoogleVertexProvider
 
-        return GoogleVertexProvider()
+        return GoogleVertexProvider
     elif provider == 'google-gla':
         from .google_gla import GoogleGLAProvider
 
-        return GoogleGLAProvider()
+        return GoogleGLAProvider
     # NOTE: We don't test because there are many ways the `boto3.client` can retrieve the credentials.
     elif provider == 'bedrock':
         from .bedrock import BedrockProvider
 
-        return BedrockProvider()
+        return BedrockProvider
     elif provider == 'groq':
         from .groq import GroqProvider
 
-        return GroqProvider()
+        return GroqProvider
     elif provider == 'anthropic':
         from .anthropic import AnthropicProvider
 
-        return AnthropicProvider()
+        return AnthropicProvider
     elif provider == 'mistral':
         from .mistral import MistralProvider
 
-        return MistralProvider()
+        return MistralProvider
     elif provider == 'cohere':
         from .cohere import CohereProvider
 
-        return CohereProvider()
+        return CohereProvider
     elif provider == 'grok':
         from .grok import GrokProvider
 
-        return GrokProvider()
+        return GrokProvider
     elif provider == 'fireworks':
         from .fireworks import FireworksProvider
 
-        return FireworksProvider()
+        return FireworksProvider
     elif provider == 'together':
         from .together import TogetherProvider
 
-        return TogetherProvider()
+        return TogetherProvider
     elif provider == 'heroku':
         from .heroku import HerokuProvider
 
-        return HerokuProvider()
+        return HerokuProvider
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')
+
+
+def infer_provider(provider: str) -> Provider[Any]:
+    """Infer the provider from the provider name."""
+    provider_class = infer_provider_class(provider)
+    return provider_class()
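`infer_provider` is split so the provider class can be resolved without instantiating it (instantiation is what typically requires credentials), while `infer_provider` keeps its old behaviour by delegating. A sketch:

```python
from pydantic_ai.providers import infer_provider, infer_provider_class

# New: look up the provider class only; no API key is needed at this point.
provider_cls = infer_provider_class('openai')

# Unchanged: build an instance (assumes e.g. OPENAI_API_KEY is configured).
provider = infer_provider('openai')
assert isinstance(provider, provider_cls)
```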
--- pydantic_ai_slim-0.2.17/pydantic_ai/providers/google.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/providers/google.py
@@ -84,7 +84,7 @@ class GoogleProvider(Provider[genai.Client]):
         """
         if client is None:
             # NOTE: We are keeping GEMINI_API_KEY for backwards compatibility.
-            api_key = api_key or os.getenv('GEMINI_API_KEY')
+            api_key = api_key or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY')
 
         if vertexai is None:  # pragma: lax no cover
             vertexai = bool(location or project or credentials)
--- pydantic_ai_slim-0.2.17/pydantic_ai/tools.py
+++ pydantic_ai_slim-0.2.19/pydantic_ai/tools.py
@@ -1,6 +1,5 @@
 from __future__ import annotations as _annotations
 
-import asyncio
 import dataclasses
 import json
 from collections.abc import Awaitable, Sequence
@@ -337,7 +336,7 @@ class Tool(Generic[AgentDepsT]):
             validator=SchemaValidator(schema=core_schema.any_schema()),
             json_schema=json_schema,
             takes_ctx=False,
-            is_async=asyncio.iscoroutinefunction(function),
+            is_async=_utils.is_async_callable(function),
         )
 
         return cls(
--- pydantic_ai_slim-0.2.17/pyproject.toml
+++ pydantic_ai_slim-0.2.19/pyproject.toml
@@ -75,7 +75,7 @@ tavily = ["tavily-python>=0.5.0"]
 # CLI
 cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
 # MCP
-mcp = ["mcp>=1.9.…; python_version >= '3.10'"]
+mcp = ["mcp>=1.9.4; python_version >= '3.10'"]
 # Evals
 evals = ["pydantic-evals=={{ version }}"]
 # A2A