pydantic-ai-slim 1.0.17__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_parts_manager.py +3 -0
- pydantic_ai/agent/__init__.py +43 -12
- pydantic_ai/durable_exec/prefect/__init__.py +15 -0
- pydantic_ai/durable_exec/prefect/_agent.py +833 -0
- pydantic_ai/durable_exec/prefect/_cache_policies.py +102 -0
- pydantic_ai/durable_exec/prefect/_function_toolset.py +58 -0
- pydantic_ai/durable_exec/prefect/_mcp_server.py +60 -0
- pydantic_ai/durable_exec/prefect/_model.py +152 -0
- pydantic_ai/durable_exec/prefect/_toolset.py +67 -0
- pydantic_ai/durable_exec/prefect/_types.py +44 -0
- pydantic_ai/durable_exec/temporal/__init__.py +2 -0
- pydantic_ai/messages.py +7 -0
- pydantic_ai/models/__init__.py +2 -0
- pydantic_ai/models/google.py +6 -0
- pydantic_ai/models/openai.py +17 -10
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/gateway.py +16 -7
- pydantic_ai/providers/nebius.py +102 -0
- pydantic_ai/toolsets/function.py +19 -12
- {pydantic_ai_slim-1.0.17.dist-info → pydantic_ai_slim-1.1.0.dist-info}/METADATA +5 -3
- {pydantic_ai_slim-1.0.17.dist-info → pydantic_ai_slim-1.1.0.dist-info}/RECORD +24 -15
- {pydantic_ai_slim-1.0.17.dist-info → pydantic_ai_slim-1.1.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.17.dist-info → pydantic_ai_slim-1.1.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.17.dist-info → pydantic_ai_slim-1.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from dataclasses import fields, is_dataclass
|
|
2
|
+
from typing import Any, TypeGuard
|
|
3
|
+
|
|
4
|
+
from prefect.cache_policies import INPUTS, RUN_ID, TASK_SOURCE, CachePolicy
|
|
5
|
+
from prefect.context import TaskRunContext
|
|
6
|
+
|
|
7
|
+
from pydantic_ai import ToolsetTool
|
|
8
|
+
from pydantic_ai.tools import RunContext
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _is_dict(obj: Any) -> TypeGuard[dict[str, Any]]:
|
|
12
|
+
return isinstance(obj, dict)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _is_list(obj: Any) -> TypeGuard[list[Any]]:
|
|
16
|
+
return isinstance(obj, list)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _is_tuple(obj: Any) -> TypeGuard[tuple[Any, ...]]:
|
|
20
|
+
return isinstance(obj, tuple)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _is_toolset_tool(obj: Any) -> TypeGuard[ToolsetTool]:
    """Tell the type checker ``obj`` is a ``ToolsetTool`` when it is an instance of one."""
    matches = isinstance(obj, ToolsetTool)
    return matches
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _replace_run_context(
|
|
28
|
+
inputs: dict[str, Any],
|
|
29
|
+
) -> Any:
|
|
30
|
+
"""Replace RunContext objects with a dict containing only hashable fields."""
|
|
31
|
+
for key, value in inputs.items():
|
|
32
|
+
if isinstance(value, RunContext):
|
|
33
|
+
inputs[key] = {
|
|
34
|
+
'retries': value.retries,
|
|
35
|
+
'tool_call_id': value.tool_call_id,
|
|
36
|
+
'tool_name': value.tool_name,
|
|
37
|
+
'tool_call_approved': value.tool_call_approved,
|
|
38
|
+
'retry': value.retry,
|
|
39
|
+
'max_retries': value.max_retries,
|
|
40
|
+
'run_step': value.run_step,
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
return inputs
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _strip_timestamps(
    obj: Any | dict[str, Any] | list[Any] | tuple[Any, ...],
) -> Any:
    """Return ``obj`` with every ``timestamp`` field or key removed, recursing into containers.

    Dataclass instances (not dataclass classes) become plain dicts of their remaining fields.
    Dicts lose any ``'timestamp'`` key. Lists and tuples are rebuilt as ``list``/``tuple`` with
    each item processed. Every other value is returned untouched.
    """
    if is_dataclass(obj) and not isinstance(obj, type):
        return {
            fld.name: _strip_timestamps(getattr(obj, fld.name))
            for fld in fields(obj)
            if fld.name != 'timestamp'
        }
    if _is_dict(obj):
        return {key: _strip_timestamps(val) for key, val in obj.items() if key != 'timestamp'}
    if _is_list(obj):
        return list(map(_strip_timestamps, obj))
    if _is_tuple(obj):
        return tuple(map(_strip_timestamps, obj))
    return obj
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _replace_toolsets(
    inputs: dict[str, Any],
) -> Any:
    """Return a shallow copy of ``inputs`` with each ``ToolsetTool`` value flattened to a dict.

    The flattened dict holds every dataclass field of the tool except ``toolset``, leaving only
    hashable fields for the cache key. ``inputs`` itself is left unmodified.
    """
    replaced = inputs.copy()
    for key, value in inputs.items():
        if not _is_toolset_tool(value):
            continue
        replaced[key] = {fld.name: getattr(value, fld.name) for fld in fields(value) if fld.name != 'toolset'}
    return replaced
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class PrefectAgentInputs(CachePolicy):
    """Prefect cache policy that hashes normalized task inputs for PrefectAgent cache keys.

    Before delegating to Prefect's ``INPUTS`` policy, the inputs are normalized in three steps:
    ``ToolsetTool`` values are flattened without their ``toolset`` field, ``RunContext`` values are
    reduced to a dict of hashable fields, and every nested ``timestamp`` field or key is dropped.
    """

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: dict[str, Any],
        flow_parameters: dict[str, Any],
        **kwargs: Any,
    ) -> str | None:
        """Return the ``INPUTS`` cache key of the normalized inputs, or ``None`` if there are no inputs."""
        if not inputs:
            return None

        normalized = _strip_timestamps(_replace_run_context(_replace_toolsets(inputs)))
        return INPUTS.compute_key(task_ctx, normalized, flow_parameters, **kwargs)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Default cache policy for pydantic-ai Prefect tasks. It composes, with Prefect's `+` policy
# combination, the normalized-input policy with the task source (TASK_SOURCE) and the run id
# (RUN_ID).
DEFAULT_PYDANTIC_AI_CACHE_POLICY = (
    PrefectAgentInputs()
    + TASK_SOURCE
    + RUN_ID
)
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from prefect import task
|
|
6
|
+
|
|
7
|
+
from pydantic_ai import FunctionToolset, ToolsetTool
|
|
8
|
+
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
9
|
+
|
|
10
|
+
from ._toolset import PrefectWrapperToolset
|
|
11
|
+
from ._types import TaskConfig, default_task_config
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class PrefectFunctionToolset(PrefectWrapperToolset[AgentDepsT]):
    """A wrapper for FunctionToolset that integrates with Prefect, turning tool calls into Prefect tasks.

    Each tool call runs as a Prefect task configured by ``task_config`` (layered over the package
    defaults). Individual tools can override that configuration through ``tool_task_config``, and
    mapping a tool name to ``None`` disables task wrapping for that tool.
    """

    def __init__(
        self,
        wrapped: FunctionToolset[AgentDepsT],
        *,
        task_config: TaskConfig,
        tool_task_config: dict[str, TaskConfig | None],
    ):
        super().__init__(wrapped)
        # Toolset-wide config: package defaults overlaid with the caller's values (caller's keys win).
        self._task_config = default_task_config | (task_config or {})
        # Per-tool overrides keyed by tool name; a value of None opts that tool out of task wrapping.
        self._tool_task_config = tool_task_config or {}

        @task
        async def _call_tool_task(
            tool_name: str,
            tool_args: dict[str, Any],
            ctx: RunContext[AgentDepsT],
            tool: ToolsetTool[AgentDepsT],
        ) -> Any:
            # Two-argument super(): zero-argument super() does not work inside a nested function.
            return await super(PrefectFunctionToolset, self).call_tool(tool_name, tool_args, ctx, tool)

        self._call_tool_task = _call_tool_task

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> Any:
        """Call a tool, wrapped as a Prefect task named ``Call Tool: <name>``.

        The task uses the toolset-wide config overlaid with the tool's entry in
        ``tool_task_config``, if there is one. If that entry is ``None``, the tool is called directly,
        without a Prefect task.
        """
        # A tool without an entry gets an empty override, so the toolset-wide config (which already
        # includes the caller's `task_config`) applies unchanged. Falling back to
        # `default_task_config` here would overwrite the caller's values with the package defaults.
        tool_specific_config = self._tool_task_config.get(name, TaskConfig())
        if tool_specific_config is None:
            # None means this tool should not be wrapped as a task
            return await super().call_tool(name, tool_args, ctx, tool)

        # Right-hand operand wins: tool-specific values override the toolset-wide config.
        merged_config = self._task_config | tool_specific_config

        return await self._call_tool_task.with_options(name=f'Call Tool: {name}', **merged_config)(
            name, tool_args, ctx, tool
        )
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from prefect import task
|
|
7
|
+
from typing_extensions import Self
|
|
8
|
+
|
|
9
|
+
from pydantic_ai import ToolsetTool
|
|
10
|
+
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
11
|
+
|
|
12
|
+
from ._toolset import PrefectWrapperToolset
|
|
13
|
+
from ._types import TaskConfig, default_task_config
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from pydantic_ai.mcp import MCPServer, ToolResult
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PrefectMCPServer(PrefectWrapperToolset[AgentDepsT], ABC):
    """Prefect integration for an ``MCPServer`` that runs each tool call as a Prefect task.

    Entering or exiting this wrapper as an async context manager enters or exits the wrapped
    server. Only ``call_tool`` is routed through a Prefect task.
    """

    def __init__(
        self,
        wrapped: MCPServer,
        *,
        task_config: TaskConfig,
    ):
        super().__init__(wrapped)
        # Package defaults overlaid with the caller's config (caller's keys win).
        self._task_config = default_task_config | (task_config or {})
        self._mcp_id = wrapped.id

        @task
        async def _call_tool_task(
            tool_name: str,
            tool_args: dict[str, Any],
            ctx: RunContext[AgentDepsT],
            tool: ToolsetTool[AgentDepsT],
        ) -> ToolResult:
            # Two-argument super(): zero-argument super() does not work inside a nested function.
            return await super(PrefectMCPServer, self).call_tool(tool_name, tool_args, ctx, tool)

        self._call_tool_task = _call_tool_task

    async def __aenter__(self) -> Self:
        await self.wrapped.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        return await self.wrapped.__aexit__(*args)

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> ToolResult:
        """Run the MCP tool call inside a Prefect task named ``Call MCP Tool: <name>``."""
        named_task = self._call_tool_task.with_options(name=f'Call MCP Tool: {name}', **self._task_config)
        return await named_task(name, tool_args, ctx, tool)
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncIterator
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from prefect import task
|
|
9
|
+
from prefect.context import FlowRunContext
|
|
10
|
+
|
|
11
|
+
from pydantic_ai import (
|
|
12
|
+
ModelMessage,
|
|
13
|
+
ModelResponse,
|
|
14
|
+
ModelResponseStreamEvent,
|
|
15
|
+
)
|
|
16
|
+
from pydantic_ai.agent import EventStreamHandler
|
|
17
|
+
from pydantic_ai.models import ModelRequestParameters, StreamedResponse
|
|
18
|
+
from pydantic_ai.models.wrapper import WrapperModel
|
|
19
|
+
from pydantic_ai.settings import ModelSettings
|
|
20
|
+
from pydantic_ai.tools import RunContext
|
|
21
|
+
from pydantic_ai.usage import RequestUsage
|
|
22
|
+
|
|
23
|
+
from ._types import TaskConfig, default_task_config
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PrefectStreamedResponse(StreamedResponse):
    """A ``StreamedResponse`` wrapping a ``ModelResponse`` that has already been completed.

    Used when a model request runs inside a Prefect flow. The real stream is fully consumed inside
    a Prefect task, and this object then reports the final response without emitting any events.
    """

    def __init__(self, model_request_parameters: ModelRequestParameters, response: ModelResponse):
        super().__init__(model_request_parameters)
        # The final response produced by the Prefect task.
        self.response = response

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Produce no events, because the stream was already consumed inside the Prefect task."""
        return
        # The unreachable `yield` makes this function an async generator rather than a coroutine.
        # noinspection PyUnreachableCode
        yield

    def get(self) -> ModelResponse:
        """Return the wrapped final response."""
        return self.response

    def usage(self) -> RequestUsage:
        """Return the usage recorded on the wrapped response."""
        return self.response.usage  # pragma: no cover

    @property
    def model_name(self) -> str:
        """Model name from the wrapped response, or ``''`` if it has none."""
        return self.response.model_name or ''  # pragma: no cover

    @property
    def provider_name(self) -> str:
        """Provider name from the wrapped response, or ``''`` if it has none."""
        return self.response.provider_name or ''  # pragma: no cover

    @property
    def timestamp(self) -> datetime:
        """Timestamp of the wrapped response."""
        return self.response.timestamp  # pragma: no cover
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class PrefectModel(WrapperModel):
    """A ``WrapperModel`` whose model requests run as Prefect tasks.

    ``request`` always goes through a Prefect task. ``request_stream`` uses a task only while a
    Prefect flow run is active. In that case the stream is consumed inside the task, optionally
    passing through ``event_stream_handler``, and a ``PrefectStreamedResponse`` carrying the final
    response is yielded instead of a live stream.
    """

    def __init__(
        self,
        model: Any,
        *,
        task_config: TaskConfig,
        event_stream_handler: EventStreamHandler[Any] | None = None,
    ):
        super().__init__(model)
        # Package defaults overlaid with the caller's config (caller's keys win).
        self.task_config = default_task_config | (task_config or {})
        self.event_stream_handler = event_stream_handler

        @task
        async def wrapped_request(
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
        ) -> ModelResponse:
            # Two-argument super(): zero-argument super() does not work inside a nested function.
            return await super(PrefectModel, self).request(messages, model_settings, model_request_parameters)

        self._wrapped_request = wrapped_request

        @task
        async def request_stream_task(
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
            ctx: RunContext[Any] | None,
        ) -> ModelResponse:
            async with super(PrefectModel, self).request_stream(
                messages, model_settings, model_request_parameters, ctx
            ) as streamed_response:
                handler = self.event_stream_handler
                if handler is not None:
                    assert ctx is not None, (
                        'A Prefect model cannot be used with `pydantic_ai.direct.model_request_stream()` as it requires a `run_context`. '
                        'Set an `event_stream_handler` on the agent and use `agent.run()` instead.'
                    )
                    await handler(ctx, streamed_response)

                # Consume the entire stream, then hand back the accumulated response.
                async for _ in streamed_response:
                    pass
                return streamed_response.get()

        self._wrapped_request_stream = request_stream_task

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Make a model request, executed as a Prefect task named after the wrapped model."""
        request_task = self._wrapped_request.with_options(
            name=f'Model Request: {self.wrapped.model_name}', **self.task_config
        )
        return await request_task(messages, model_settings, model_request_parameters)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[Any] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming model request.

        With an active Prefect flow run, the stream is consumed inside a task and a non-streaming
        ``PrefectStreamedResponse`` is yielded. Outside a flow, the wrapped model's stream is
        yielded unchanged.
        """
        if FlowRunContext.get() is not None:
            # In a flow: run the whole stream inside a task and yield only the final response.
            streaming_task = self._wrapped_request_stream.with_options(
                name=f'Model Request (Streaming): {self.wrapped.model_name}', **self.task_config
            )
            response = await streaming_task(messages, model_settings, model_request_parameters, run_context)
            yield PrefectStreamedResponse(model_request_parameters, response)
            return

        # Not in a flow: pass the wrapped model's stream straight through.
        async with super().request_stream(
            messages, model_settings, model_request_parameters, run_context
        ) as streamed_response:
            yield streamed_response
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset
|
|
8
|
+
from pydantic_ai.tools import AgentDepsT
|
|
9
|
+
|
|
10
|
+
from ._types import TaskConfig
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
pass
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PrefectWrapperToolset(WrapperToolset[AgentDepsT], ABC):
    """Shared base for toolsets wrapped for Prefect execution.

    It reports the wrapped toolset's ``id`` and refuses replacement by visitors.
    """

    @property
    def id(self) -> str | None:
        # Forward the wrapped toolset's id, which gives Prefect toolsets better task naming.
        return self.wrapped.id

    def visit_and_replace(
        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
    ) -> AbstractToolset[AgentDepsT]:
        # Intentionally never calls `visitor`: once wrapped for Prefect, this toolset cannot be swapped out.
        return self
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def prefectify_toolset(
    toolset: AbstractToolset[AgentDepsT],
    mcp_task_config: TaskConfig,
    tool_task_config: TaskConfig,
    tool_task_config_by_name: dict[str, TaskConfig | None],
) -> AbstractToolset[AgentDepsT]:
    """Wrap a toolset so its tool calls run as Prefect tasks, where a wrapper exists.

    A ``FunctionToolset`` becomes a ``PrefectFunctionToolset``. If ``pydantic_ai.mcp`` can be
    imported, an ``MCPServer`` becomes a ``PrefectMCPServer``. Any other toolset is returned as is.

    Args:
        toolset: The toolset to wrap.
        mcp_task_config: The Prefect task config to use for MCP server tasks.
        tool_task_config: The default Prefect task config to use for tool calls.
        tool_task_config_by_name: Per-tool task configuration. Keys are tool names, values are TaskConfig or None.

    Returns:
        The Prefect wrapper, or ``toolset`` itself when no wrapper applies.
    """
    if isinstance(toolset, FunctionToolset):
        from ._function_toolset import PrefectFunctionToolset

        return PrefectFunctionToolset(
            wrapped=toolset,
            task_config=tool_task_config,
            tool_task_config=tool_task_config_by_name,
        )

    try:
        from pydantic_ai.mcp import MCPServer

        from ._mcp_server import PrefectMCPServer
    except ImportError:
        # MCP support could not be imported, so there is no MCP wrapper to apply.
        return toolset

    if not isinstance(toolset, MCPServer):
        return toolset
    return PrefectMCPServer(
        wrapped=toolset,
        task_config=mcp_task_config,
    )
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from prefect.cache_policies import CachePolicy
|
|
4
|
+
from prefect.results import ResultStorage
|
|
5
|
+
from typing_extensions import TypedDict
|
|
6
|
+
|
|
7
|
+
from pydantic_ai.durable_exec.prefect._cache_policies import DEFAULT_PYDANTIC_AI_CACHE_POLICY
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TaskConfig(TypedDict, total=False):
    """Configuration for a task in Prefect.

    Every key is optional (`total=False`). The keys mirror keyword arguments of Prefect's `@task`
    decorator. pydantic-ai's Prefect wrappers overlay a config on top of `default_task_config` and
    unpack the result into `Task.with_options(...)` when they invoke a task.
    """

    retries: int
    """Maximum number of retries for the task."""

    retry_delay_seconds: float | list[float]
    """Delay between retries in seconds. Can be a single value or a list for custom backoff."""

    timeout_seconds: float
    """Maximum time in seconds for the task to complete."""

    cache_policy: CachePolicy
    """Prefect cache policy for the task."""

    persist_result: bool
    """Whether to persist the task result."""

    result_storage: ResultStorage
    """Prefect result storage for the task. Should be a storage block or a block slug like `s3-bucket/my-storage`."""

    log_prints: bool
    """Whether to log print statements from the task."""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Package-wide defaults for Prefect tasks. The Prefect wrappers overlay caller-supplied TaskConfig
# values on top of these with `default_task_config | task_config`, so keys from the caller win.
default_task_config: TaskConfig = {
    'retries': 0,
    'retry_delay_seconds': 1.0,
    'persist_result': True,
    'log_prints': False,
    'cache_policy': DEFAULT_PYDANTIC_AI_CACHE_POLICY,
}
|
|
@@ -62,6 +62,8 @@ class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
|
|
|
62
62
|
'logfire',
|
|
63
63
|
'rich',
|
|
64
64
|
'httpx',
|
|
65
|
+
'anyio',
|
|
66
|
+
'httpcore',
|
|
65
67
|
# Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize
|
|
66
68
|
'attrs',
|
|
67
69
|
# Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize
|
pydantic_ai/messages.py
CHANGED
|
@@ -1052,6 +1052,13 @@ class BaseToolCallPart:
|
|
|
1052
1052
|
In case the tool call id is not provided by the model, Pydantic AI will generate a random one.
|
|
1053
1053
|
"""
|
|
1054
1054
|
|
|
1055
|
+
_: KW_ONLY
|
|
1056
|
+
|
|
1057
|
+
id: str | None = None
|
|
1058
|
+
"""An optional identifier of the tool call part, separate from the tool call ID.
|
|
1059
|
+
|
|
1060
|
+
This is used by some APIs like OpenAI Responses."""
|
|
1061
|
+
|
|
1055
1062
|
def args_as_dict(self) -> dict[str, Any]:
|
|
1056
1063
|
"""Return the arguments as a Python dictionary.
|
|
1057
1064
|
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -685,12 +685,14 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
|
|
|
685
685
|
'grok',
|
|
686
686
|
'heroku',
|
|
687
687
|
'moonshotai',
|
|
688
|
+
'ollama',
|
|
688
689
|
'openai',
|
|
689
690
|
'openai-chat',
|
|
690
691
|
'openrouter',
|
|
691
692
|
'together',
|
|
692
693
|
'vercel',
|
|
693
694
|
'litellm',
|
|
695
|
+
'nebius',
|
|
694
696
|
):
|
|
695
697
|
from .openai import OpenAIChatModel
|
|
696
698
|
|
pydantic_ai/models/google.py
CHANGED
|
@@ -73,6 +73,7 @@ try:
|
|
|
73
73
|
GroundingMetadata,
|
|
74
74
|
HttpOptionsDict,
|
|
75
75
|
MediaResolution,
|
|
76
|
+
Modality,
|
|
76
77
|
Part,
|
|
77
78
|
PartDict,
|
|
78
79
|
SafetySettingDict,
|
|
@@ -415,6 +416,10 @@ class GoogleModel(Model):
|
|
|
415
416
|
tool_config = self._get_tool_config(model_request_parameters, tools)
|
|
416
417
|
system_instruction, contents = await self._map_messages(messages)
|
|
417
418
|
|
|
419
|
+
modalities = [Modality.TEXT.value]
|
|
420
|
+
if self.profile.supports_image_output:
|
|
421
|
+
modalities.append(Modality.IMAGE.value)
|
|
422
|
+
|
|
418
423
|
http_options: HttpOptionsDict = {
|
|
419
424
|
'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
|
|
420
425
|
}
|
|
@@ -443,6 +448,7 @@ class GoogleModel(Model):
|
|
|
443
448
|
tool_config=tool_config,
|
|
444
449
|
response_mime_type=response_mime_type,
|
|
445
450
|
response_schema=response_schema,
|
|
451
|
+
response_modalities=modalities,
|
|
446
452
|
)
|
|
447
453
|
return contents, config
|
|
448
454
|
|
pydantic_ai/models/openai.py
CHANGED
|
@@ -284,6 +284,7 @@ class OpenAIChatModel(Model):
|
|
|
284
284
|
'together',
|
|
285
285
|
'vercel',
|
|
286
286
|
'litellm',
|
|
287
|
+
'nebius',
|
|
287
288
|
]
|
|
288
289
|
| Provider[AsyncOpenAI] = 'openai',
|
|
289
290
|
profile: ModelProfileSpec | None = None,
|
|
@@ -312,6 +313,7 @@ class OpenAIChatModel(Model):
|
|
|
312
313
|
'together',
|
|
313
314
|
'vercel',
|
|
314
315
|
'litellm',
|
|
316
|
+
'nebius',
|
|
315
317
|
]
|
|
316
318
|
| Provider[AsyncOpenAI] = 'openai',
|
|
317
319
|
profile: ModelProfileSpec | None = None,
|
|
@@ -339,6 +341,7 @@ class OpenAIChatModel(Model):
|
|
|
339
341
|
'together',
|
|
340
342
|
'vercel',
|
|
341
343
|
'litellm',
|
|
344
|
+
'nebius',
|
|
342
345
|
]
|
|
343
346
|
| Provider[AsyncOpenAI] = 'openai',
|
|
344
347
|
profile: ModelProfileSpec | None = None,
|
|
@@ -899,7 +902,7 @@ class OpenAIResponsesModel(Model):
|
|
|
899
902
|
self,
|
|
900
903
|
model_name: OpenAIModelName,
|
|
901
904
|
*,
|
|
902
|
-
provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
|
|
905
|
+
provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'nebius']
|
|
903
906
|
| Provider[AsyncOpenAI] = 'openai',
|
|
904
907
|
profile: ModelProfileSpec | None = None,
|
|
905
908
|
settings: ModelSettings | None = None,
|
|
@@ -1005,7 +1008,12 @@ class OpenAIResponsesModel(Model):
|
|
|
1005
1008
|
items.append(TextPart(content.text, id=item.id))
|
|
1006
1009
|
elif isinstance(item, responses.ResponseFunctionToolCall):
|
|
1007
1010
|
items.append(
|
|
1008
|
-
ToolCallPart(
|
|
1011
|
+
ToolCallPart(
|
|
1012
|
+
item.name,
|
|
1013
|
+
item.arguments,
|
|
1014
|
+
tool_call_id=item.call_id,
|
|
1015
|
+
id=item.id,
|
|
1016
|
+
)
|
|
1009
1017
|
)
|
|
1010
1018
|
elif isinstance(item, responses.ResponseCodeInterpreterToolCall):
|
|
1011
1019
|
call_part, return_part, file_parts = _map_code_interpreter_tool_call(item, self.system)
|
|
@@ -1178,7 +1186,7 @@ class OpenAIResponsesModel(Model):
|
|
|
1178
1186
|
truncation=model_settings.get('openai_truncation', NOT_GIVEN),
|
|
1179
1187
|
timeout=model_settings.get('timeout', NOT_GIVEN),
|
|
1180
1188
|
service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
|
|
1181
|
-
previous_response_id=previous_response_id,
|
|
1189
|
+
previous_response_id=previous_response_id or NOT_GIVEN,
|
|
1182
1190
|
reasoning=reasoning,
|
|
1183
1191
|
user=model_settings.get('openai_user', NOT_GIVEN),
|
|
1184
1192
|
text=text or NOT_GIVEN,
|
|
@@ -1361,6 +1369,7 @@ class OpenAIResponsesModel(Model):
|
|
|
1361
1369
|
elif isinstance(item, ToolCallPart):
|
|
1362
1370
|
call_id = _guard_tool_call_id(t=item)
|
|
1363
1371
|
call_id, id = _split_combined_tool_call_id(call_id)
|
|
1372
|
+
id = id or item.id
|
|
1364
1373
|
|
|
1365
1374
|
param = responses.ResponseFunctionToolCallParam(
|
|
1366
1375
|
name=item.tool_name,
|
|
@@ -1724,7 +1733,8 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
|
|
|
1724
1733
|
vendor_part_id=chunk.item.id,
|
|
1725
1734
|
tool_name=chunk.item.name,
|
|
1726
1735
|
args=chunk.item.arguments,
|
|
1727
|
-
tool_call_id=
|
|
1736
|
+
tool_call_id=chunk.item.call_id,
|
|
1737
|
+
id=chunk.item.id,
|
|
1728
1738
|
)
|
|
1729
1739
|
elif isinstance(chunk.item, responses.ResponseReasoningItem):
|
|
1730
1740
|
pass
|
|
@@ -1963,18 +1973,15 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
|
|
|
1963
1973
|
return u
|
|
1964
1974
|
|
|
1965
1975
|
|
|
1966
|
-
def
|
|
1976
|
+
def _split_combined_tool_call_id(combined_id: str) -> tuple[str, str | None]:
|
|
1967
1977
|
# When reasoning, the Responses API requires the `ResponseFunctionToolCall` to be returned with both the `call_id` and `id` fields.
|
|
1968
|
-
#
|
|
1969
|
-
return f'{call_id}|{id}' if id else call_id
|
|
1978
|
+
# Before our `ToolCallPart` gained the `id` field alongside `tool_call_id` field, we combined the two fields into a single string stored on `tool_call_id`.
|
|
1970
1979
|
|
|
1971
|
-
|
|
1972
|
-
def _split_combined_tool_call_id(combined_id: str) -> tuple[str, str | None]:
|
|
1973
1980
|
if '|' in combined_id:
|
|
1974
1981
|
call_id, id = combined_id.split('|', 1)
|
|
1975
1982
|
return call_id, id
|
|
1976
1983
|
else:
|
|
1977
|
-
return combined_id, None
|
|
1984
|
+
return combined_id, None
|
|
1978
1985
|
|
|
1979
1986
|
|
|
1980
1987
|
def _map_code_interpreter_tool_call(
|
|
@@ -142,6 +142,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
|
|
|
142
142
|
from .litellm import LiteLLMProvider
|
|
143
143
|
|
|
144
144
|
return LiteLLMProvider
|
|
145
|
+
elif provider == 'nebius':
|
|
146
|
+
from .nebius import NebiusProvider
|
|
147
|
+
|
|
148
|
+
return NebiusProvider
|
|
145
149
|
else: # pragma: no cover
|
|
146
150
|
raise ValueError(f'Unknown provider: {provider}')
|
|
147
151
|
|