pydantic-ai-slim 1.0.18__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/agent/__init__.py +43 -12
- pydantic_ai/durable_exec/prefect/__init__.py +15 -0
- pydantic_ai/durable_exec/prefect/_agent.py +833 -0
- pydantic_ai/durable_exec/prefect/_cache_policies.py +102 -0
- pydantic_ai/durable_exec/prefect/_function_toolset.py +58 -0
- pydantic_ai/durable_exec/prefect/_mcp_server.py +60 -0
- pydantic_ai/durable_exec/prefect/_model.py +152 -0
- pydantic_ai/durable_exec/prefect/_toolset.py +67 -0
- pydantic_ai/durable_exec/prefect/_types.py +44 -0
- pydantic_ai/models/__init__.py +1 -0
- pydantic_ai/models/google.py +6 -0
- pydantic_ai/providers/gateway.py +16 -7
- pydantic_ai/toolsets/function.py +19 -12
- {pydantic_ai_slim-1.0.18.dist-info → pydantic_ai_slim-1.1.0.dist-info}/METADATA +5 -3
- {pydantic_ai_slim-1.0.18.dist-info → pydantic_ai_slim-1.1.0.dist-info}/RECORD +18 -10
- {pydantic_ai_slim-1.0.18.dist-info → pydantic_ai_slim-1.1.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.18.dist-info → pydantic_ai_slim-1.1.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.18.dist-info → pydantic_ai_slim-1.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from dataclasses import fields, is_dataclass
|
|
2
|
+
from typing import Any, TypeGuard
|
|
3
|
+
|
|
4
|
+
from prefect.cache_policies import INPUTS, RUN_ID, TASK_SOURCE, CachePolicy
|
|
5
|
+
from prefect.context import TaskRunContext
|
|
6
|
+
|
|
7
|
+
from pydantic_ai import ToolsetTool
|
|
8
|
+
from pydantic_ai.tools import RunContext
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _is_dict(obj: Any) -> TypeGuard[dict[str, Any]]:
|
|
12
|
+
return isinstance(obj, dict)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _is_list(obj: Any) -> TypeGuard[list[Any]]:
|
|
16
|
+
return isinstance(obj, list)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _is_tuple(obj: Any) -> TypeGuard[tuple[Any, ...]]:
|
|
20
|
+
return isinstance(obj, tuple)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _is_toolset_tool(obj: Any) -> TypeGuard[ToolsetTool]:
    """Type guard: report whether ``obj`` is a ``ToolsetTool`` instance."""
    matches = isinstance(obj, ToolsetTool)
    return matches
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _replace_run_context(
|
|
28
|
+
inputs: dict[str, Any],
|
|
29
|
+
) -> Any:
|
|
30
|
+
"""Replace RunContext objects with a dict containing only hashable fields."""
|
|
31
|
+
for key, value in inputs.items():
|
|
32
|
+
if isinstance(value, RunContext):
|
|
33
|
+
inputs[key] = {
|
|
34
|
+
'retries': value.retries,
|
|
35
|
+
'tool_call_id': value.tool_call_id,
|
|
36
|
+
'tool_name': value.tool_name,
|
|
37
|
+
'tool_call_approved': value.tool_call_approved,
|
|
38
|
+
'retry': value.retry,
|
|
39
|
+
'max_retries': value.max_retries,
|
|
40
|
+
'run_step': value.run_step,
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
return inputs
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _strip_timestamps(
|
|
47
|
+
obj: Any | dict[str, Any] | list[Any] | tuple[Any, ...],
|
|
48
|
+
) -> Any:
|
|
49
|
+
"""Recursively convert dataclasses to dicts, excluding timestamp fields."""
|
|
50
|
+
if is_dataclass(obj) and not isinstance(obj, type):
|
|
51
|
+
result: dict[str, Any] = {}
|
|
52
|
+
for f in fields(obj):
|
|
53
|
+
if f.name != 'timestamp':
|
|
54
|
+
value = getattr(obj, f.name)
|
|
55
|
+
result[f.name] = _strip_timestamps(value)
|
|
56
|
+
return result
|
|
57
|
+
elif _is_dict(obj):
|
|
58
|
+
return {k: _strip_timestamps(v) for k, v in obj.items() if k != 'timestamp'}
|
|
59
|
+
elif _is_list(obj):
|
|
60
|
+
return [_strip_timestamps(item) for item in obj]
|
|
61
|
+
elif _is_tuple(obj):
|
|
62
|
+
return tuple(_strip_timestamps(item) for item in obj)
|
|
63
|
+
return obj
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _replace_toolsets(
|
|
67
|
+
inputs: dict[str, Any],
|
|
68
|
+
) -> Any:
|
|
69
|
+
"""Replace Toolset objects with a dict containing only hashable fields."""
|
|
70
|
+
inputs = inputs.copy()
|
|
71
|
+
for key, value in inputs.items():
|
|
72
|
+
if _is_toolset_tool(value):
|
|
73
|
+
inputs[key] = {field.name: getattr(value, field.name) for field in fields(value) if field.name != 'toolset'}
|
|
74
|
+
return inputs
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class PrefectAgentInputs(CachePolicy):
    """Prefect cache policy that derives cache keys from PrefectAgent task inputs.

    Before the inputs reach Prefect's built-in ``INPUTS`` policy, they are sanitized
    in three steps:

    - ``ToolsetTool`` values lose their ``toolset`` reference.
    - ``RunContext`` values are reduced to a dict of selected fields.
    - Every nested ``timestamp`` field is dropped.
    """

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: dict[str, Any],
        flow_parameters: dict[str, Any],
        **kwargs: Any,
    ) -> str | None:
        """Return the ``INPUTS`` cache key of the sanitized inputs, or ``None`` when there are no inputs."""
        if not inputs:
            return None

        sanitized = _strip_timestamps(_replace_run_context(_replace_toolsets(inputs)))
        return INPUTS.compute_key(task_ctx, sanitized, flow_parameters, **kwargs)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Default cache policy for Prefect-wrapped pydantic-ai tasks. It composes three policies:
# - the sanitized task inputs (PrefectAgentInputs),
# - Prefect's TASK_SOURCE policy,
# - Prefect's RUN_ID policy.
# NOTE(review): RUN_ID presumably scopes cache hits to the current flow run; confirm against
# Prefect's cache-policy docs.
DEFAULT_PYDANTIC_AI_CACHE_POLICY = PrefectAgentInputs() + TASK_SOURCE + RUN_ID
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from prefect import task
|
|
6
|
+
|
|
7
|
+
from pydantic_ai import FunctionToolset, ToolsetTool
|
|
8
|
+
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
9
|
+
|
|
10
|
+
from ._toolset import PrefectWrapperToolset
|
|
11
|
+
from ._types import TaskConfig, default_task_config
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class PrefectFunctionToolset(PrefectWrapperToolset[AgentDepsT]):
    """A wrapper for FunctionToolset that integrates with Prefect, turning tool calls into Prefect tasks."""

    def __init__(
        self,
        wrapped: FunctionToolset[AgentDepsT],
        *,
        task_config: TaskConfig,
        tool_task_config: dict[str, TaskConfig | None],
    ):
        """Wrap a function toolset so each tool call runs as a Prefect task.

        Args:
            wrapped: The function toolset to wrap.
            task_config: Toolset-wide Prefect task options, layered over
                `default_task_config` (a falsy value falls back to the defaults).
            tool_task_config: Per-tool overrides keyed by tool name. A `TaskConfig`
                is merged over the toolset-wide options; `None` runs that tool
                without a Prefect task.
        """
        super().__init__(wrapped)
        self._task_config = default_task_config | (task_config or {})
        self._tool_task_config = tool_task_config or {}

        # Task that forwards to the wrapped toolset's `call_tool`. The explicit
        # `super(PrefectFunctionToolset, self)` is needed because zero-argument
        # `super()` would not bind to `self` inside this nested function.
        @task
        async def _call_tool_task(
            tool_name: str,
            tool_args: dict[str, Any],
            ctx: RunContext[AgentDepsT],
            tool: ToolsetTool[AgentDepsT],
        ) -> Any:
            return await super(PrefectFunctionToolset, self).call_tool(tool_name, tool_args, ctx, tool)

        self._call_tool_task = _call_tool_task

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> Any:
        """Call a tool, wrapped as a Prefect task with a descriptive name.

        If the tool is mapped to `None` in the per-tool config, it is called directly
        without a task. Otherwise the task runs with the toolset-wide config, and any
        per-tool entries take precedence over it.
        """
        # A missing entry means "no per-tool overrides". Falling back to
        # `default_task_config` here would layer the package defaults back over
        # `self._task_config` and silently discard the user's toolset-wide options.
        tool_specific_config = self._tool_task_config.get(name, TaskConfig())
        if tool_specific_config is None:
            # None means this tool should not be wrapped as a task
            return await super().call_tool(name, tool_args, ctx, tool)

        # Per-tool options override the toolset-wide ones.
        merged_config = self._task_config | tool_specific_config

        return await self._call_tool_task.with_options(name=f'Call Tool: {name}', **merged_config)(
            name, tool_args, ctx, tool
        )
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from prefect import task
|
|
7
|
+
from typing_extensions import Self
|
|
8
|
+
|
|
9
|
+
from pydantic_ai import ToolsetTool
|
|
10
|
+
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
11
|
+
|
|
12
|
+
from ._toolset import PrefectWrapperToolset
|
|
13
|
+
from ._types import TaskConfig, default_task_config
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from pydantic_ai.mcp import MCPServer, ToolResult
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PrefectMCPServer(PrefectWrapperToolset[AgentDepsT], ABC):
    """A wrapper for MCPServer that integrates with Prefect, turning `call_tool` into a Prefect task.

    Only `call_tool` is wrapped as a task here. Other toolset methods, such as
    `get_tools`, are not overridden by this class. NOTE(review): they are presumably
    delegated to the wrapped server by the base wrapper; confirm.
    """

    def __init__(
        self,
        wrapped: MCPServer,
        *,
        task_config: TaskConfig,
    ):
        """Wrap an MCP server so its tool calls run as Prefect tasks.

        Args:
            wrapped: The MCP server to wrap.
            task_config: Prefect task options for tool-call tasks, layered over
                `default_task_config` (a falsy value falls back to the defaults).
        """
        super().__init__(wrapped)
        self._task_config = default_task_config | (task_config or {})
        # Not read anywhere in this class. NOTE(review): confirm whether other modules rely on it.
        self._mcp_id = wrapped.id

        # Task that forwards to the wrapped server's `call_tool`. The explicit
        # `super(PrefectMCPServer, self)` is needed because zero-argument `super()`
        # would not bind to `self` inside this nested function.
        @task
        async def _call_tool_task(
            tool_name: str,
            tool_args: dict[str, Any],
            ctx: RunContext[AgentDepsT],
            tool: ToolsetTool[AgentDepsT],
        ) -> ToolResult:
            return await super(PrefectMCPServer, self).call_tool(tool_name, tool_args, ctx, tool)

        self._call_tool_task = _call_tool_task

    async def __aenter__(self) -> Self:
        # Enter the wrapped server's context, but return this wrapper so callers keep
        # using the Prefect-aware `call_tool`.
        await self.wrapped.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        # Delegate exit, including its exception-suppression return value, to the wrapped server.
        return await self.wrapped.__aexit__(*args)

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> ToolResult:
        """Call an MCP tool, wrapped as a Prefect task with a descriptive name.

        Every call runs as a task named ``Call MCP Tool: <name>`` using this wrapper's task config.
        """
        return await self._call_tool_task.with_options(name=f'Call MCP Tool: {name}', **self._task_config)(
            name, tool_args, ctx, tool
        )
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncIterator
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from prefect import task
|
|
9
|
+
from prefect.context import FlowRunContext
|
|
10
|
+
|
|
11
|
+
from pydantic_ai import (
|
|
12
|
+
ModelMessage,
|
|
13
|
+
ModelResponse,
|
|
14
|
+
ModelResponseStreamEvent,
|
|
15
|
+
)
|
|
16
|
+
from pydantic_ai.agent import EventStreamHandler
|
|
17
|
+
from pydantic_ai.models import ModelRequestParameters, StreamedResponse
|
|
18
|
+
from pydantic_ai.models.wrapper import WrapperModel
|
|
19
|
+
from pydantic_ai.settings import ModelSettings
|
|
20
|
+
from pydantic_ai.tools import RunContext
|
|
21
|
+
from pydantic_ai.usage import RequestUsage
|
|
22
|
+
|
|
23
|
+
from ._types import TaskConfig, default_task_config
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PrefectStreamedResponse(StreamedResponse):
    """Already-completed stand-in for a streamed model response.

    When a streaming request runs inside a Prefect flow, the real stream is drained
    within a task. This object then carries that task's final `ModelResponse` and
    exposes it through the `StreamedResponse` interface without producing any
    further events.
    """

    def __init__(self, model_request_parameters: ModelRequestParameters, response: ModelResponse):
        super().__init__(model_request_parameters)
        self.response = response

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Yield nothing: every event was consumed before this object was created."""
        already_consumed: tuple[ModelResponseStreamEvent, ...] = ()
        for event in already_consumed:
            yield event  # pragma: no cover

    def get(self) -> ModelResponse:
        """Return the final response captured by the Prefect task."""
        final_response = self.response
        return final_response

    def usage(self) -> RequestUsage:
        """Return the usage recorded on the captured response."""
        return self.response.usage  # pragma: no cover

    @property
    def model_name(self) -> str:
        """Model name from the captured response, or an empty string when it is unset."""
        return self.response.model_name or ''  # pragma: no cover

    @property
    def provider_name(self) -> str:
        """Provider name from the captured response, or an empty string when it is unset."""
        return self.response.provider_name or ''  # pragma: no cover

    @property
    def timestamp(self) -> datetime:
        """Timestamp of the captured response."""
        return self.response.timestamp  # pragma: no cover
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class PrefectModel(WrapperModel):
    """A wrapper for Model that integrates with Prefect, turning request and request_stream into Prefect tasks.

    `request` always runs through a Prefect task.

    `request_stream` uses a task only when a Prefect flow run is active. In that
    case the stream is fully consumed inside the task, and a `PrefectStreamedResponse`
    holding the final `ModelResponse` is yielded instead of a live stream. Outside a
    flow, it streams normally from the wrapped model.
    """

    def __init__(
        self,
        model: Any,
        *,
        task_config: TaskConfig,
        event_stream_handler: EventStreamHandler[Any] | None = None,
    ):
        """Wrap `model` for Prefect execution.

        Args:
            model: The model to wrap; passed straight to `WrapperModel.__init__`.
            task_config: Prefect task options for model-request tasks, layered over
                `default_task_config` (a falsy value falls back to the defaults).
            event_stream_handler: Optional handler called with the run context and the
                live stream inside the streaming task, before the stream is drained.
        """
        super().__init__(model)
        self.task_config = default_task_config | (task_config or {})
        self.event_stream_handler = event_stream_handler

        # Task wrapping the wrapped model's `request`. The explicit
        # `super(PrefectModel, self)` is needed because zero-argument `super()`
        # would not bind to `self` inside this nested function.
        @task
        async def wrapped_request(
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
        ) -> ModelResponse:
            response = await super(PrefectModel, self).request(messages, model_settings, model_request_parameters)
            return response

        self._wrapped_request = wrapped_request

        # Streaming task: opens the wrapped stream, hands it to the event stream
        # handler if one is set (which then requires a run context), drains it, and
        # returns the final `ModelResponse`.
        @task
        async def request_stream_task(
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
            ctx: RunContext[Any] | None,
        ) -> ModelResponse:
            async with super(PrefectModel, self).request_stream(
                messages, model_settings, model_request_parameters, ctx
            ) as streamed_response:
                if self.event_stream_handler is not None:
                    assert ctx is not None, (
                        'A Prefect model cannot be used with `pydantic_ai.direct.model_request_stream()` as it requires a `run_context`. '
                        'Set an `event_stream_handler` on the agent and use `agent.run()` instead.'
                    )
                    await self.event_stream_handler(ctx, streamed_response)

                # Consume the entire stream
                async for _ in streamed_response:
                    pass
            response = streamed_response.get()
            return response

        self._wrapped_request_stream = request_stream_task

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Make a model request, always executed as a Prefect task.

        Unlike `request_stream`, no flow-context check is made here. The task is named
        ``Model Request: <model_name>`` and uses `self.task_config`.
        """
        return await self._wrapped_request.with_options(
            name=f'Model Request: {self.wrapped.model_name}', **self.task_config
        )(messages, model_settings, model_request_parameters)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[Any] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming model request.

        When inside a Prefect flow, the stream is consumed within a task and
        a non-streaming response is returned. When not in a flow, behaves normally.
        """
        # Check if we're in a flow context
        flow_run_context = FlowRunContext.get()

        # If not in a flow, just call the wrapped request_stream method
        if flow_run_context is None:
            async with super().request_stream(
                messages, model_settings, model_request_parameters, run_context
            ) as streamed_response:
                yield streamed_response
                return

        # If in a flow, consume the stream in a task and return the final response
        response = await self._wrapped_request_stream.with_options(
            name=f'Model Request (Streaming): {self.wrapped.model_name}', **self.task_config
        )(messages, model_settings, model_request_parameters, run_context)
        yield PrefectStreamedResponse(model_request_parameters, response)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset
|
|
8
|
+
from pydantic_ai.tools import AgentDepsT
|
|
9
|
+
|
|
10
|
+
from ._types import TaskConfig
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
pass
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PrefectWrapperToolset(WrapperToolset[AgentDepsT], ABC):
    """Common base for toolsets that have been wrapped for Prefect execution."""

    @property
    def id(self) -> str | None:
        """ID of the wrapped toolset, surfaced so Prefect task naming can use it."""
        wrapped_id = self.wrapped.id
        return wrapped_id

    def visit_and_replace(
        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
    ) -> AbstractToolset[AgentDepsT]:
        """Return this toolset unchanged; a Prefect-wrapped toolset cannot be swapped out afterwards.

        The visitor is intentionally never invoked.
        """
        return self
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def prefectify_toolset(
    toolset: AbstractToolset[AgentDepsT],
    mcp_task_config: TaskConfig,
    tool_task_config: TaskConfig,
    tool_task_config_by_name: dict[str, TaskConfig | None],
) -> AbstractToolset[AgentDepsT]:
    """Return a Prefect-aware wrapper for ``toolset`` when one exists.

    A ``FunctionToolset`` becomes a ``PrefectFunctionToolset``. An ``MCPServer``
    becomes a ``PrefectMCPServer``, provided ``pydantic_ai.mcp`` can be imported.
    Any other toolset is returned as-is.

    Args:
        toolset: The toolset to wrap.
        mcp_task_config: The Prefect task config to use for MCP server tasks.
        tool_task_config: The default Prefect task config to use for tool calls.
        tool_task_config_by_name: Per-tool task configuration. Keys are tool names, values are TaskConfig or None.
    """
    if isinstance(toolset, FunctionToolset):
        from ._function_toolset import PrefectFunctionToolset

        return PrefectFunctionToolset(
            wrapped=toolset,
            task_config=tool_task_config,
            tool_task_config=tool_task_config_by_name,
        )

    try:
        from pydantic_ai.mcp import MCPServer

        from ._mcp_server import PrefectMCPServer
    except ImportError:
        # MCP support is unavailable, so there is nothing further to wrap.
        return toolset

    if not isinstance(toolset, MCPServer):
        return toolset

    return PrefectMCPServer(wrapped=toolset, task_config=mcp_task_config)
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from prefect.cache_policies import CachePolicy
|
|
4
|
+
from prefect.results import ResultStorage
|
|
5
|
+
from typing_extensions import TypedDict
|
|
6
|
+
|
|
7
|
+
from pydantic_ai.durable_exec.prefect._cache_policies import DEFAULT_PYDANTIC_AI_CACHE_POLICY
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TaskConfig(TypedDict, total=False):
    """Configuration for a task in Prefect.

    These are Prefect task options. The model, MCP-server and function-toolset
    wrappers apply them as keyword arguments to `with_options(...)` on their tasks.
    Every key is optional (`total=False`). Configs are typically layered over
    `default_task_config` with the `|` operator.
    """

    retries: int
    """Maximum number of retries for the task."""

    retry_delay_seconds: float | list[float]
    """Delay between retries in seconds. Can be a single value or a list for custom backoff."""

    timeout_seconds: float
    """Maximum time in seconds for the task to complete."""

    cache_policy: CachePolicy
    """Prefect cache policy for the task."""

    persist_result: bool
    """Whether to persist the task result."""

    result_storage: ResultStorage
    """Prefect result storage for the task. Should be a storage block or a block slug like `s3-bucket/my-storage`."""

    log_prints: bool
    """Whether to log print statements from the task."""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Baseline Prefect task options shared by the Prefect wrappers. User-supplied
# TaskConfig values are layered on top of this with the `|` operator. The wrappers
# do not modify this dict in place.
default_task_config: TaskConfig = {
    'retries': 0,
    'retry_delay_seconds': 1.0,
    'persist_result': True,
    'log_prints': False,
    'cache_policy': DEFAULT_PYDANTIC_AI_CACHE_POLICY,
}
|
pydantic_ai/models/__init__.py
CHANGED
pydantic_ai/models/google.py
CHANGED
|
@@ -73,6 +73,7 @@ try:
|
|
|
73
73
|
GroundingMetadata,
|
|
74
74
|
HttpOptionsDict,
|
|
75
75
|
MediaResolution,
|
|
76
|
+
Modality,
|
|
76
77
|
Part,
|
|
77
78
|
PartDict,
|
|
78
79
|
SafetySettingDict,
|
|
@@ -415,6 +416,10 @@ class GoogleModel(Model):
|
|
|
415
416
|
tool_config = self._get_tool_config(model_request_parameters, tools)
|
|
416
417
|
system_instruction, contents = await self._map_messages(messages)
|
|
417
418
|
|
|
419
|
+
modalities = [Modality.TEXT.value]
|
|
420
|
+
if self.profile.supports_image_output:
|
|
421
|
+
modalities.append(Modality.IMAGE.value)
|
|
422
|
+
|
|
418
423
|
http_options: HttpOptionsDict = {
|
|
419
424
|
'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
|
|
420
425
|
}
|
|
@@ -443,6 +448,7 @@ class GoogleModel(Model):
|
|
|
443
448
|
tool_config=tool_config,
|
|
444
449
|
response_mime_type=response_mime_type,
|
|
445
450
|
response_schema=response_schema,
|
|
451
|
+
response_modalities=modalities,
|
|
446
452
|
)
|
|
447
453
|
return contents, config
|
|
448
454
|
|
pydantic_ai/providers/gateway.py
CHANGED
|
@@ -4,7 +4,6 @@ from __future__ import annotations as _annotations
|
|
|
4
4
|
|
|
5
5
|
import os
|
|
6
6
|
from typing import TYPE_CHECKING, Any, Literal, overload
|
|
7
|
-
from urllib.parse import urljoin
|
|
8
7
|
|
|
9
8
|
import httpx
|
|
10
9
|
|
|
@@ -84,22 +83,22 @@ def gateway_provider(
|
|
|
84
83
|
' to use the Pydantic AI Gateway provider.'
|
|
85
84
|
)
|
|
86
85
|
|
|
87
|
-
base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', 'http://localhost:8787')
|
|
86
|
+
base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', 'http://localhost:8787/proxy')
|
|
88
87
|
http_client = http_client or cached_async_http_client(provider=f'gateway-{upstream_provider}')
|
|
89
88
|
http_client.event_hooks = {'request': [_request_hook]}
|
|
90
89
|
|
|
91
90
|
if upstream_provider in ('openai', 'openai-chat'):
|
|
92
91
|
from .openai import OpenAIProvider
|
|
93
92
|
|
|
94
|
-
return OpenAIProvider(api_key=api_key, base_url=
|
|
93
|
+
return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
|
|
95
94
|
elif upstream_provider == 'openai-responses':
|
|
96
95
|
from .openai import OpenAIProvider
|
|
97
96
|
|
|
98
|
-
return OpenAIProvider(api_key=api_key, base_url=
|
|
97
|
+
return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
|
|
99
98
|
elif upstream_provider == 'groq':
|
|
100
99
|
from .groq import GroqProvider
|
|
101
100
|
|
|
102
|
-
return GroqProvider(api_key=api_key, base_url=
|
|
101
|
+
return GroqProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'groq'), http_client=http_client)
|
|
103
102
|
elif upstream_provider == 'anthropic':
|
|
104
103
|
from anthropic import AsyncAnthropic
|
|
105
104
|
|
|
@@ -108,7 +107,7 @@ def gateway_provider(
|
|
|
108
107
|
return AnthropicProvider(
|
|
109
108
|
anthropic_client=AsyncAnthropic(
|
|
110
109
|
auth_token=api_key,
|
|
111
|
-
base_url=
|
|
110
|
+
base_url=_merge_url_path(base_url, 'anthropic'),
|
|
112
111
|
http_client=http_client,
|
|
113
112
|
)
|
|
114
113
|
)
|
|
@@ -122,7 +121,7 @@ def gateway_provider(
|
|
|
122
121
|
vertexai=True,
|
|
123
122
|
api_key='unset',
|
|
124
123
|
http_options={
|
|
125
|
-
'base_url':
|
|
124
|
+
'base_url': _merge_url_path(base_url, 'google-vertex'),
|
|
126
125
|
'headers': {'User-Agent': get_user_agent(), 'Authorization': api_key},
|
|
127
126
|
# TODO(Marcelo): Until https://github.com/googleapis/python-genai/issues/1357 is solved.
|
|
128
127
|
'async_client_args': {
|
|
@@ -185,3 +184,13 @@ async def _request_hook(request: httpx.Request) -> httpx.Request:
|
|
|
185
184
|
request.headers.update(headers)
|
|
186
185
|
|
|
187
186
|
return request
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _merge_url_path(base_url: str, path: str) -> str:
|
|
190
|
+
"""Merge a base URL and a path.
|
|
191
|
+
|
|
192
|
+
Args:
|
|
193
|
+
base_url: The base URL to merge.
|
|
194
|
+
path: The path to merge.
|
|
195
|
+
"""
|
|
196
|
+
return base_url.rstrip('/') + '/' + path.lstrip('/')
|
pydantic_ai/toolsets/function.py
CHANGED
|
@@ -109,6 +109,7 @@ class FunctionToolset(AbstractToolset[AgentDepsT]):
|
|
|
109
109
|
/,
|
|
110
110
|
*,
|
|
111
111
|
name: str | None = None,
|
|
112
|
+
description: str | None = None,
|
|
112
113
|
retries: int | None = None,
|
|
113
114
|
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
|
|
114
115
|
docstring_format: DocstringFormat | None = None,
|
|
@@ -126,6 +127,7 @@ class FunctionToolset(AbstractToolset[AgentDepsT]):
|
|
|
126
127
|
/,
|
|
127
128
|
*,
|
|
128
129
|
name: str | None = None,
|
|
130
|
+
description: str | None = None,
|
|
129
131
|
retries: int | None = None,
|
|
130
132
|
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
|
|
131
133
|
docstring_format: DocstringFormat | None = None,
|
|
@@ -169,6 +171,7 @@ class FunctionToolset(AbstractToolset[AgentDepsT]):
|
|
|
169
171
|
Args:
|
|
170
172
|
func: The tool function to register.
|
|
171
173
|
name: The name of the tool, defaults to the function name.
|
|
174
|
+
description: The description of the tool,defaults to the function docstring.
|
|
172
175
|
retries: The number of retries to allow for this tool, defaults to the agent's default retries,
|
|
173
176
|
which defaults to 1.
|
|
174
177
|
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
|
|
@@ -197,18 +200,19 @@ class FunctionToolset(AbstractToolset[AgentDepsT]):
|
|
|
197
200
|
) -> ToolFuncEither[AgentDepsT, ToolParams]:
|
|
198
201
|
# noinspection PyTypeChecker
|
|
199
202
|
self.add_function(
|
|
200
|
-
func_,
|
|
201
|
-
None,
|
|
202
|
-
name,
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
203
|
+
func=func_,
|
|
204
|
+
takes_ctx=None,
|
|
205
|
+
name=name,
|
|
206
|
+
description=description,
|
|
207
|
+
retries=retries,
|
|
208
|
+
prepare=prepare,
|
|
209
|
+
docstring_format=docstring_format,
|
|
210
|
+
require_parameter_descriptions=require_parameter_descriptions,
|
|
211
|
+
schema_generator=schema_generator,
|
|
212
|
+
strict=strict,
|
|
213
|
+
sequential=sequential,
|
|
214
|
+
requires_approval=requires_approval,
|
|
215
|
+
metadata=metadata,
|
|
212
216
|
)
|
|
213
217
|
return func_
|
|
214
218
|
|
|
@@ -219,6 +223,7 @@ class FunctionToolset(AbstractToolset[AgentDepsT]):
|
|
|
219
223
|
func: ToolFuncEither[AgentDepsT, ToolParams],
|
|
220
224
|
takes_ctx: bool | None = None,
|
|
221
225
|
name: str | None = None,
|
|
226
|
+
description: str | None = None,
|
|
222
227
|
retries: int | None = None,
|
|
223
228
|
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
|
|
224
229
|
docstring_format: DocstringFormat | None = None,
|
|
@@ -240,6 +245,7 @@ class FunctionToolset(AbstractToolset[AgentDepsT]):
|
|
|
240
245
|
func: The tool function to register.
|
|
241
246
|
takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. If `None`, this is inferred from the function signature.
|
|
242
247
|
name: The name of the tool, defaults to the function name.
|
|
248
|
+
description: The description of the tool, defaults to the function docstring.
|
|
243
249
|
retries: The number of retries to allow for this tool, defaults to the agent's default retries,
|
|
244
250
|
which defaults to 1.
|
|
245
251
|
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
|
|
@@ -279,6 +285,7 @@ class FunctionToolset(AbstractToolset[AgentDepsT]):
|
|
|
279
285
|
func,
|
|
280
286
|
takes_ctx=takes_ctx,
|
|
281
287
|
name=name,
|
|
288
|
+
description=description,
|
|
282
289
|
max_retries=retries,
|
|
283
290
|
prepare=prepare,
|
|
284
291
|
docstring_format=docstring_format,
|