pydantic-ai-slim 1.0.18__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +13 -0
- pydantic_ai/agent/__init__.py +43 -12
- pydantic_ai/agent/abstract.py +12 -0
- pydantic_ai/durable_exec/prefect/__init__.py +15 -0
- pydantic_ai/durable_exec/prefect/_agent.py +833 -0
- pydantic_ai/durable_exec/prefect/_cache_policies.py +102 -0
- pydantic_ai/durable_exec/prefect/_function_toolset.py +58 -0
- pydantic_ai/durable_exec/prefect/_mcp_server.py +60 -0
- pydantic_ai/durable_exec/prefect/_model.py +152 -0
- pydantic_ai/durable_exec/prefect/_toolset.py +67 -0
- pydantic_ai/durable_exec/prefect/_types.py +44 -0
- pydantic_ai/models/__init__.py +3 -0
- pydantic_ai/models/google.py +6 -0
- pydantic_ai/models/openai.py +42 -45
- pydantic_ai/providers/gateway.py +16 -7
- pydantic_ai/toolsets/function.py +19 -12
- {pydantic_ai_slim-1.0.18.dist-info → pydantic_ai_slim-1.2.0.dist-info}/METADATA +7 -5
- {pydantic_ai_slim-1.0.18.dist-info → pydantic_ai_slim-1.2.0.dist-info}/RECORD +21 -13
- {pydantic_ai_slim-1.0.18.dist-info → pydantic_ai_slim-1.2.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.18.dist-info → pydantic_ai_slim-1.2.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.18.dist-info → pydantic_ai_slim-1.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from dataclasses import fields, is_dataclass
|
|
2
|
+
from typing import Any, TypeGuard
|
|
3
|
+
|
|
4
|
+
from prefect.cache_policies import INPUTS, RUN_ID, TASK_SOURCE, CachePolicy
|
|
5
|
+
from prefect.context import TaskRunContext
|
|
6
|
+
|
|
7
|
+
from pydantic_ai import ToolsetTool
|
|
8
|
+
from pydantic_ai.tools import RunContext
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _is_dict(obj: Any) -> TypeGuard[dict[str, Any]]:
|
|
12
|
+
return isinstance(obj, dict)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _is_list(obj: Any) -> TypeGuard[list[Any]]:
|
|
16
|
+
return isinstance(obj, list)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _is_tuple(obj: Any) -> TypeGuard[tuple[Any, ...]]:
|
|
20
|
+
return isinstance(obj, tuple)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _is_toolset_tool(obj: Any) -> TypeGuard[ToolsetTool]:
    """Report whether ``obj`` is a ``ToolsetTool`` instance, narrowing its type for checkers."""
    matches_tool = isinstance(obj, (ToolsetTool,))
    return matches_tool
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _replace_run_context(
    inputs: dict[str, Any],
) -> dict[str, Any]:
    """Return a copy of ``inputs`` with each ``RunContext`` value reduced to a plain dict.

    Only the ``RunContext`` attributes listed below are kept, so the rest of the context does
    not take part in cache-key hashing. Values that are not ``RunContext`` instances are passed
    through unchanged. The caller's mapping is never mutated, which matches ``_replace_toolsets``.

    Args:
        inputs: Task input mapping, keyed by parameter name.

    Returns:
        A new dict with the same keys as ``inputs``. Every ``RunContext`` value is replaced by a
        dict of ``retries``, ``tool_call_id``, ``tool_name``, ``tool_call_approved``, ``retry``,
        ``max_retries`` and ``run_step``.
    """
    # Work on a copy. Writing back into ``inputs`` while iterating it would also leak the
    # substitution to the caller.
    replaced = inputs.copy()
    for key, value in inputs.items():
        if isinstance(value, RunContext):
            replaced[key] = {
                'retries': value.retries,
                'tool_call_id': value.tool_call_id,
                'tool_name': value.tool_name,
                'tool_call_approved': value.tool_call_approved,
                'retry': value.retry,
                'max_retries': value.max_retries,
                'run_step': value.run_step,
            }
    return replaced
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _strip_timestamps(
    obj: Any | dict[str, Any] | list[Any] | tuple[Any, ...],
) -> Any:
    """Recursively rebuild ``obj`` with every ``timestamp`` entry removed.

    - Dataclass *instances* (not dataclass classes) become dicts of their fields, minus any
      field named ``timestamp``.
    - Dicts keep every key except ``'timestamp'``.
    - Lists and tuples are rebuilt element by element (tuples come back as plain ``tuple``).
    - Any other value is returned as-is.

    Nested values are processed the same way.
    """
    if is_dataclass(obj) and not isinstance(obj, type):
        return {
            dc_field.name: _strip_timestamps(getattr(obj, dc_field.name))
            for dc_field in fields(obj)
            if dc_field.name != 'timestamp'
        }
    if _is_dict(obj):
        return {key: _strip_timestamps(val) for key, val in obj.items() if key != 'timestamp'}
    if _is_list(obj):
        return list(map(_strip_timestamps, obj))
    if _is_tuple(obj):
        return tuple(map(_strip_timestamps, obj))
    return obj
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _replace_toolsets(
    inputs: dict[str, Any],
) -> Any:
    """Return a shallow copy of ``inputs`` in which each ``ToolsetTool`` is flattened to a dict.

    The flattened dict holds every dataclass field of the tool except ``toolset``, so the owning
    toolset is left out of the result. All other values are carried over unchanged.
    """
    sanitized = inputs.copy()
    for param_name, candidate in inputs.items():
        if not _is_toolset_tool(candidate):
            continue
        sanitized[param_name] = {
            tool_field.name: getattr(candidate, tool_field.name)
            for tool_field in fields(candidate)
            if tool_field.name != 'toolset'
        }
    return sanitized
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class PrefectAgentInputs(CachePolicy):
    """Prefect cache policy that derives a cache key from sanitized task inputs.

    Before delegating to Prefect's ``INPUTS`` policy, the inputs are normalized in three steps:
    ``ToolsetTool`` values lose their ``toolset`` field, ``RunContext`` values are reduced to a
    small dict of selected fields, and every nested ``timestamp`` entry is dropped.
    """

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: dict[str, Any],
        flow_parameters: dict[str, Any],
        **kwargs: Any,
    ) -> str | None:
        """Return the cache key for the sanitized ``inputs``, or ``None`` when ``inputs`` is empty."""
        if not inputs:
            return None

        sanitized = _strip_timestamps(_replace_run_context(_replace_toolsets(inputs)))
        return INPUTS.compute_key(task_ctx, sanitized, flow_parameters, **kwargs)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Compound policy: sanitized task inputs, combined via `+` with Prefect's built-in
# TASK_SOURCE and RUN_ID policies.
DEFAULT_PYDANTIC_AI_CACHE_POLICY = (
    PrefectAgentInputs()
    + TASK_SOURCE
    + RUN_ID
)
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from prefect import task
|
|
6
|
+
|
|
7
|
+
from pydantic_ai import FunctionToolset, ToolsetTool
|
|
8
|
+
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
9
|
+
|
|
10
|
+
from ._toolset import PrefectWrapperToolset
|
|
11
|
+
from ._types import TaskConfig, default_task_config
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class PrefectFunctionToolset(PrefectWrapperToolset[AgentDepsT]):
    """A wrapper for FunctionToolset that integrates with Prefect, turning tool calls into Prefect tasks."""

    def __init__(
        self,
        wrapped: FunctionToolset[AgentDepsT],
        *,
        task_config: TaskConfig,
        tool_task_config: dict[str, TaskConfig | None],
    ):
        """Wrap ``wrapped`` so each tool call runs as a Prefect task.

        Args:
            wrapped: The function toolset whose tool calls are executed as Prefect tasks.
            task_config: Task options applied to every tool call, layered over ``default_task_config``.
            tool_task_config: Per-tool overrides keyed by tool name. A ``TaskConfig`` value is
                merged on top of ``task_config`` for that tool. A ``None`` value makes that
                tool run directly, without a Prefect task.
        """
        super().__init__(wrapped)
        self._task_config = default_task_config | (task_config or {})
        self._tool_task_config = tool_task_config or {}

        # The source of this task function is kept as-is: Prefect's TASK_SOURCE policy (part of
        # the default cache policy) hashes it into the cache key.
        @task
        async def _call_tool_task(
            tool_name: str,
            tool_args: dict[str, Any],
            ctx: RunContext[AgentDepsT],
            tool: ToolsetTool[AgentDepsT],
        ) -> Any:
            return await super(PrefectFunctionToolset, self).call_tool(tool_name, tool_args, ctx, tool)

        self._call_tool_task = _call_tool_task

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> Any:
        """Call a tool, wrapped as a Prefect task with a descriptive name.

        The task options are ``self._task_config`` overlaid with the tool's entry in
        ``tool_task_config`` (if any). A ``None`` entry bypasses the task and calls the tool directly.
        """
        # A missing entry means "no per-tool overrides". Falling back to ``default_task_config``
        # would merge the library defaults back on top of ``self._task_config`` and silently
        # discard the caller's toolset-wide settings (e.g. ``retries``, ``cache_policy``).
        tool_specific_config = self._tool_task_config.get(name, TaskConfig())
        if tool_specific_config is None:
            # None means this tool should not be wrapped as a task
            return await super().call_tool(name, tool_args, ctx, tool)

        # Per-tool settings take precedence over the toolset-wide config.
        merged_config = self._task_config | tool_specific_config

        return await self._call_tool_task.with_options(name=f'Call Tool: {name}', **merged_config)(
            name, tool_args, ctx, tool
        )
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from prefect import task
|
|
7
|
+
from typing_extensions import Self
|
|
8
|
+
|
|
9
|
+
from pydantic_ai import ToolsetTool
|
|
10
|
+
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
11
|
+
|
|
12
|
+
from ._toolset import PrefectWrapperToolset
|
|
13
|
+
from ._types import TaskConfig, default_task_config
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from pydantic_ai.mcp import MCPServer, ToolResult
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PrefectMCPServer(PrefectWrapperToolset[AgentDepsT], ABC):
    """Prefect-aware wrapper around an ``MCPServer``.

    Each ``call_tool`` invocation runs as a Prefect task named after the tool. Entering and
    exiting the async context is forwarded directly to the wrapped server.
    """

    def __init__(
        self,
        wrapped: MCPServer,
        *,
        task_config: TaskConfig,
    ):
        """Wrap ``wrapped``; ``task_config`` is layered over ``default_task_config`` for every tool call."""
        super().__init__(wrapped)
        self._mcp_id = wrapped.id
        self._task_config = default_task_config | (task_config or {})

        # The source of this task function is kept as-is: Prefect's TASK_SOURCE policy (part of
        # the default cache policy) hashes it into the cache key.
        @task
        async def _call_tool_task(
            tool_name: str,
            tool_args: dict[str, Any],
            ctx: RunContext[AgentDepsT],
            tool: ToolsetTool[AgentDepsT],
        ) -> ToolResult:
            return await super(PrefectMCPServer, self).call_tool(tool_name, tool_args, ctx, tool)

        self._call_tool_task = _call_tool_task

    async def __aenter__(self) -> Self:
        """Enter the wrapped server's async context and return this wrapper."""
        await self.wrapped.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        """Exit the wrapped server's async context, returning whatever it returns."""
        exit_result = await self.wrapped.__aexit__(*args)
        return exit_result

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> ToolResult:
        """Run the MCP tool ``name`` inside a Prefect task labelled ``Call MCP Tool: <name>``."""
        configured_task = self._call_tool_task.with_options(name=f'Call MCP Tool: {name}', **self._task_config)
        return await configured_task(name, tool_args, ctx, tool)
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncIterator
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from prefect import task
|
|
9
|
+
from prefect.context import FlowRunContext
|
|
10
|
+
|
|
11
|
+
from pydantic_ai import (
|
|
12
|
+
ModelMessage,
|
|
13
|
+
ModelResponse,
|
|
14
|
+
ModelResponseStreamEvent,
|
|
15
|
+
)
|
|
16
|
+
from pydantic_ai.agent import EventStreamHandler
|
|
17
|
+
from pydantic_ai.models import ModelRequestParameters, StreamedResponse
|
|
18
|
+
from pydantic_ai.models.wrapper import WrapperModel
|
|
19
|
+
from pydantic_ai.settings import ModelSettings
|
|
20
|
+
from pydantic_ai.tools import RunContext
|
|
21
|
+
from pydantic_ai.usage import RequestUsage
|
|
22
|
+
|
|
23
|
+
from ._types import TaskConfig, default_task_config
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PrefectStreamedResponse(StreamedResponse):
    """Already-completed stand-in for a streamed model response.

    When the request runs inside a Prefect flow, the real stream is drained within a task. This
    object then carries the final ``ModelResponse`` back to the caller and yields no events.
    """

    def __init__(self, model_request_parameters: ModelRequestParameters, response: ModelResponse):
        super().__init__(model_request_parameters)
        self.response = response

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Yield nothing: every event was already consumed inside the Prefect task."""
        # Iterating an empty tuple keeps this an async generator that finishes immediately.
        for event in ():
            yield event

    def get(self) -> ModelResponse:
        """Return the completed response."""
        return self.response

    def usage(self) -> RequestUsage:
        """Return the usage recorded on the completed response."""
        return self.response.usage  # pragma: no cover

    @property
    def model_name(self) -> str:
        """Model name from the response, or ``''`` when it is unset/empty."""
        name = self.response.model_name
        return name if name else ''  # pragma: no cover

    @property
    def provider_name(self) -> str:
        """Provider name from the response, or ``''`` when it is unset/empty."""
        provider = self.response.provider_name
        return provider if provider else ''  # pragma: no cover

    @property
    def timestamp(self) -> datetime:
        """Timestamp of the completed response."""
        return self.response.timestamp  # pragma: no cover
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class PrefectModel(WrapperModel):
    """Model wrapper that routes requests through Prefect tasks.

    ``request`` always runs as a Prefect task. ``request_stream`` runs as a task only when called
    inside a Prefect flow; outside a flow it streams directly from the wrapped model.
    """

    def __init__(
        self,
        model: Any,
        *,
        task_config: TaskConfig,
        event_stream_handler: EventStreamHandler[Any] | None = None,
    ):
        """Wrap ``model``.

        Args:
            model: The model to wrap.
            task_config: Task options layered over ``default_task_config`` for every request.
            event_stream_handler: Optional handler invoked with the live stream inside the
                streaming task.
        """
        super().__init__(model)
        self.task_config = default_task_config | (task_config or {})
        self.event_stream_handler = event_stream_handler

        # The source of the two task functions below is kept as-is: Prefect's TASK_SOURCE policy
        # (part of the default cache policy) hashes it into the cache key.
        @task
        async def wrapped_request(
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
        ) -> ModelResponse:
            response = await super(PrefectModel, self).request(messages, model_settings, model_request_parameters)
            return response

        self._wrapped_request = wrapped_request

        @task
        async def request_stream_task(
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
            ctx: RunContext[Any] | None,
        ) -> ModelResponse:
            async with super(PrefectModel, self).request_stream(
                messages, model_settings, model_request_parameters, ctx
            ) as streamed_response:
                if self.event_stream_handler is not None:
                    assert ctx is not None, (
                        'A Prefect model cannot be used with `pydantic_ai.direct.model_request_stream()` as it requires a `run_context`. '
                        'Set an `event_stream_handler` on the agent and use `agent.run()` instead.'
                    )
                    await self.event_stream_handler(ctx, streamed_response)

                # Consume the entire stream
                async for _ in streamed_response:
                    pass
            response = streamed_response.get()
            return response

        self._wrapped_request_stream = request_stream_task

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Make a model request, executed as a Prefect task named after the wrapped model."""
        configured_task = self._wrapped_request.with_options(
            name=f'Model Request: {self.wrapped.model_name}', **self.task_config
        )
        return await configured_task(messages, model_settings, model_request_parameters)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[Any] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming model request.

        Inside a Prefect flow, the stream is drained within a task and a ``PrefectStreamedResponse``
        holding the final response is yielded. Outside a flow, the wrapped model's stream is
        yielded directly.
        """
        if FlowRunContext.get() is not None:
            # In a flow: consume the stream in a task and hand back the final response.
            configured_task = self._wrapped_request_stream.with_options(
                name=f'Model Request (Streaming): {self.wrapped.model_name}', **self.task_config
            )
            response = await configured_task(messages, model_settings, model_request_parameters, run_context)
            yield PrefectStreamedResponse(model_request_parameters, response)
            return

        # Not in a flow: delegate straight to the wrapped model's streaming request.
        async with super().request_stream(
            messages, model_settings, model_request_parameters, run_context
        ) as streamed_response:
            yield streamed_response
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset
|
|
8
|
+
from pydantic_ai.tools import AgentDepsT
|
|
9
|
+
|
|
10
|
+
from ._types import TaskConfig
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
pass
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PrefectWrapperToolset(WrapperToolset[AgentDepsT], ABC):
    """Shared base for toolsets whose work is routed through Prefect tasks.

    It re-exposes the wrapped toolset's ``id`` and ignores ``visit_and_replace`` requests.
    """

    @property
    def id(self) -> str | None:
        """The wrapped toolset's id (Prefect toolsets should have IDs for better task naming)."""
        wrapped_id = self.wrapped.id
        return wrapped_id

    def visit_and_replace(
        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
    ) -> AbstractToolset[AgentDepsT]:
        """Return ``self`` without applying ``visitor``.

        Prefect-ified toolsets cannot be swapped out after the fact.
        """
        return self
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def prefectify_toolset(
    toolset: AbstractToolset[AgentDepsT],
    mcp_task_config: TaskConfig,
    tool_task_config: TaskConfig,
    tool_task_config_by_name: dict[str, TaskConfig | None],
) -> AbstractToolset[AgentDepsT]:
    """Wrap a toolset to integrate it with Prefect.

    ``FunctionToolset`` instances become ``PrefectFunctionToolset`` and ``MCPServer`` instances
    become ``PrefectMCPServer``. Any other toolset is returned unchanged.

    Args:
        toolset: The toolset to wrap.
        mcp_task_config: The Prefect task config to use for MCP server tasks.
        tool_task_config: The default Prefect task config to use for tool calls.
        tool_task_config_by_name: Per-tool task configuration. Keys are tool names, values are TaskConfig or None.
    """
    if isinstance(toolset, FunctionToolset):
        from ._function_toolset import PrefectFunctionToolset

        return PrefectFunctionToolset(
            wrapped=toolset,
            task_config=tool_task_config,
            tool_task_config=tool_task_config_by_name,
        )

    try:
        from pydantic_ai.mcp import MCPServer

        from ._mcp_server import PrefectMCPServer
    except ImportError:
        # The MCP modules could not be imported, so there is no MCP server type to match against.
        return toolset

    if isinstance(toolset, MCPServer):
        return PrefectMCPServer(
            wrapped=toolset,
            task_config=mcp_task_config,
        )
    return toolset
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from prefect.cache_policies import CachePolicy
|
|
4
|
+
from prefect.results import ResultStorage
|
|
5
|
+
from typing_extensions import TypedDict
|
|
6
|
+
|
|
7
|
+
from pydantic_ai.durable_exec.prefect._cache_policies import DEFAULT_PYDANTIC_AI_CACHE_POLICY
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TaskConfig(TypedDict, total=False):
    """Prefect task options used by the pydantic-ai Prefect integration.

    The keys mirror arguments of Prefect's ``@task`` decorator. Every key is optional (``total=False``).
    """

    retries: int
    """How many times a failed task is retried at most."""

    retry_delay_seconds: float | list[float]
    """Seconds to wait between retries; either one value or a list describing a custom backoff."""

    timeout_seconds: float
    """Upper bound, in seconds, on how long the task may run."""

    cache_policy: CachePolicy
    """Prefect cache policy applied to the task."""

    persist_result: bool
    """Whether the task's result is persisted."""

    result_storage: ResultStorage
    """Where Prefect stores task results: a storage block, or a block slug such as `s3-bucket/my-storage`."""

    log_prints: bool
    """Whether `print` calls inside the task are captured in the logs."""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Baseline task options that the Prefect wrappers merge beneath any user-supplied TaskConfig:
# no retries (1 s delay if retries are enabled), persisted results, print capture off, and the
# pydantic-ai cache policy.
default_task_config: TaskConfig = {
    'retries': 0,
    'retry_delay_seconds': 1.0,
    'persist_result': True,
    'log_prints': False,
    'cache_policy': DEFAULT_PYDANTIC_AI_CACHE_POLICY,
}
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -55,6 +55,8 @@ KnownModelName = TypeAliasType(
|
|
|
55
55
|
'anthropic:claude-3-5-sonnet-20240620',
|
|
56
56
|
'anthropic:claude-3-5-sonnet-20241022',
|
|
57
57
|
'anthropic:claude-3-5-sonnet-latest',
|
|
58
|
+
'anthropic:claude-haiku-4-5',
|
|
59
|
+
'anthropic:claude-haiku-4-5-20251001',
|
|
58
60
|
'anthropic:claude-3-7-sonnet-20250219',
|
|
59
61
|
'anthropic:claude-3-7-sonnet-latest',
|
|
60
62
|
'anthropic:claude-3-haiku-20240307',
|
|
@@ -685,6 +687,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
|
|
|
685
687
|
'grok',
|
|
686
688
|
'heroku',
|
|
687
689
|
'moonshotai',
|
|
690
|
+
'ollama',
|
|
688
691
|
'openai',
|
|
689
692
|
'openai-chat',
|
|
690
693
|
'openrouter',
|
pydantic_ai/models/google.py
CHANGED
|
@@ -73,6 +73,7 @@ try:
|
|
|
73
73
|
GroundingMetadata,
|
|
74
74
|
HttpOptionsDict,
|
|
75
75
|
MediaResolution,
|
|
76
|
+
Modality,
|
|
76
77
|
Part,
|
|
77
78
|
PartDict,
|
|
78
79
|
SafetySettingDict,
|
|
@@ -415,6 +416,10 @@ class GoogleModel(Model):
|
|
|
415
416
|
tool_config = self._get_tool_config(model_request_parameters, tools)
|
|
416
417
|
system_instruction, contents = await self._map_messages(messages)
|
|
417
418
|
|
|
419
|
+
modalities = [Modality.TEXT.value]
|
|
420
|
+
if self.profile.supports_image_output:
|
|
421
|
+
modalities.append(Modality.IMAGE.value)
|
|
422
|
+
|
|
418
423
|
http_options: HttpOptionsDict = {
|
|
419
424
|
'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
|
|
420
425
|
}
|
|
@@ -443,6 +448,7 @@ class GoogleModel(Model):
|
|
|
443
448
|
tool_config=tool_config,
|
|
444
449
|
response_mime_type=response_mime_type,
|
|
445
450
|
response_schema=response_schema,
|
|
451
|
+
response_modalities=modalities,
|
|
446
452
|
)
|
|
447
453
|
return contents, config
|
|
448
454
|
|