pydantic-ai-slim 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +50 -31
- pydantic_ai/_output.py +19 -7
- pydantic_ai/_parts_manager.py +8 -10
- pydantic_ai/_tool_manager.py +21 -0
- pydantic_ai/ag_ui.py +32 -17
- pydantic_ai/agent/__init__.py +3 -0
- pydantic_ai/agent/abstract.py +8 -0
- pydantic_ai/durable_exec/dbos/__init__.py +6 -0
- pydantic_ai/durable_exec/dbos/_agent.py +721 -0
- pydantic_ai/durable_exec/dbos/_mcp_server.py +89 -0
- pydantic_ai/durable_exec/dbos/_model.py +137 -0
- pydantic_ai/durable_exec/dbos/_utils.py +10 -0
- pydantic_ai/durable_exec/temporal/_agent.py +1 -1
- pydantic_ai/mcp.py +1 -1
- pydantic_ai/messages.py +42 -6
- pydantic_ai/models/__init__.py +8 -0
- pydantic_ai/models/anthropic.py +79 -25
- pydantic_ai/models/bedrock.py +82 -31
- pydantic_ai/models/cohere.py +39 -13
- pydantic_ai/models/function.py +8 -1
- pydantic_ai/models/google.py +105 -37
- pydantic_ai/models/groq.py +35 -7
- pydantic_ai/models/huggingface.py +27 -5
- pydantic_ai/models/instrumented.py +27 -14
- pydantic_ai/models/mistral.py +54 -20
- pydantic_ai/models/openai.py +151 -57
- pydantic_ai/profiles/openai.py +7 -0
- pydantic_ai/providers/bedrock.py +20 -4
- pydantic_ai/settings.py +1 -0
- pydantic_ai/tools.py +11 -0
- pydantic_ai/toolsets/function.py +7 -0
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/METADATA +8 -6
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/RECORD +36 -31
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from dbos import DBOS
|
|
8
|
+
from typing_extensions import Self
|
|
9
|
+
|
|
10
|
+
from pydantic_ai.mcp import MCPServer, ToolResult
|
|
11
|
+
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
12
|
+
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
|
|
13
|
+
from pydantic_ai.toolsets.wrapper import WrapperToolset
|
|
14
|
+
|
|
15
|
+
from ._utils import StepConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class DBOSMCPServer(WrapperToolset[AgentDepsT], ABC):
    """Toolset wrapper around an `MCPServer` whose `get_tools` and `call_tool` run as DBOS steps.

    Both step functions are created once, in `__init__`, and registered with
    `DBOS.step` under names derived from `step_name_prefix` and the wrapped
    server's `id` (when it has one).
    """

    def __init__(
        self,
        wrapped: MCPServer,
        *,
        step_name_prefix: str,
        step_config: StepConfig,
    ):
        """Wrap `wrapped` and register the two DBOS steps.

        Args:
            wrapped: The MCP server to wrap.
            step_name_prefix: Prefix used to build the DBOS step names.
            step_config: Keyword arguments forwarded to `DBOS.step`; a falsy value is treated as `{}`.
        """
        super().__init__(wrapped)
        self._step_config = step_config or {}
        self._step_name_prefix = step_name_prefix

        # Step names look like `<prefix>__mcp_server[__<id>].<method>`.
        base_name = f'{step_name_prefix}__mcp_server'
        self._name = f'{base_name}__{wrapped.id}' if wrapped.id else base_name

        # Bind the parent-class view of this instance once, so the step closures
        # below can delegate to the parent implementations.
        parent_impl = super()

        async def wrapped_get_tools_step(
            ctx: RunContext[AgentDepsT],
        ) -> dict[str, ToolsetTool[AgentDepsT]]:
            return await parent_impl.get_tools(ctx)

        # Wrap get_tools in a DBOS step.
        as_get_tools_step = DBOS.step(name=f'{self._name}.get_tools', **self._step_config)
        self._dbos_wrapped_get_tools_step = as_get_tools_step(wrapped_get_tools_step)

        async def wrapped_call_tool_step(
            name: str,
            tool_args: dict[str, Any],
            ctx: RunContext[AgentDepsT],
            tool: ToolsetTool[AgentDepsT],
        ) -> ToolResult:
            return await parent_impl.call_tool(name, tool_args, ctx, tool)

        # Wrap call_tool in a DBOS step.
        as_call_tool_step = DBOS.step(name=f'{self._name}.call_tool', **self._step_config)
        self._dbos_wrapped_call_tool_step = as_call_tool_step(wrapped_call_tool_step)

    @property
    def id(self) -> str | None:
        """The `id` of the wrapped server."""
        return self.wrapped.id

    async def __aenter__(self) -> Self:
        """Return `self` without entering the wrapped server.

        The wrapped MCPServer enters itself around listing and calling tools,
        so it does not need to be entered here (nor could it be, because this
        code is not running inside a DBOS step).
        """
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        """Do nothing on exit; returns `None`, so exceptions are not suppressed."""
        return None

    def visit_and_replace(
        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
    ) -> AbstractToolset[AgentDepsT]:
        """Ignore `visitor` and return `self`: DBOS-ified toolsets cannot be swapped out after the fact."""
        return self

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """List the tools by awaiting the DBOS step created in `__init__`."""
        tools = await self._dbos_wrapped_get_tools_step(ctx)
        return tools

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> ToolResult:
        """Call the tool `name` by awaiting the DBOS step created in `__init__`."""
        result = await self._dbos_wrapped_call_tool_step(name, tool_args, ctx, tool)
        return result
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncIterator
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from dbos import DBOS
|
|
9
|
+
|
|
10
|
+
from pydantic_ai.agent import EventStreamHandler
|
|
11
|
+
from pydantic_ai.messages import (
|
|
12
|
+
ModelMessage,
|
|
13
|
+
ModelResponse,
|
|
14
|
+
ModelResponseStreamEvent,
|
|
15
|
+
)
|
|
16
|
+
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
|
|
17
|
+
from pydantic_ai.models.wrapper import WrapperModel
|
|
18
|
+
from pydantic_ai.settings import ModelSettings
|
|
19
|
+
from pydantic_ai.tools import RunContext
|
|
20
|
+
from pydantic_ai.usage import RequestUsage
|
|
21
|
+
|
|
22
|
+
from ._utils import StepConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class DBOSStreamedResponse(StreamedResponse):
    """A `StreamedResponse` backed by an already available `ModelResponse`.

    It produces no stream events; every accessor reads from the stored response.
    """

    def __init__(self, model_request_parameters: ModelRequestParameters, response: ModelResponse):
        """Store `response` after initialising the base class with `model_request_parameters`."""
        super().__init__(model_request_parameters)
        self.response = response

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Async generator that finishes immediately without yielding any event."""
        # Iterating an empty tuple keeps this an async generator function
        # while never actually producing an item.
        for _never in ():
            yield _never

    def get(self) -> ModelResponse:
        """Return the stored response."""
        return self.response

    def usage(self) -> RequestUsage:
        """Return the usage recorded on the stored response."""
        stored = self.response
        return stored.usage  # pragma: no cover

    @property
    def model_name(self) -> str:
        """Model name of the stored response, or `''` when it is falsy."""
        name = self.response.model_name
        return name if name else ''  # pragma: no cover

    @property
    def provider_name(self) -> str:
        """Provider name of the stored response, or `''` when it is falsy."""
        name = self.response.provider_name
        return name if name else ''  # pragma: no cover

    @property
    def timestamp(self) -> datetime:
        """Timestamp of the stored response."""
        stored = self.response
        return stored.timestamp  # pragma: no cover
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class DBOSModel(WrapperModel):
    """Model wrapper whose `request` and `request_stream` calls are executed as DBOS steps.

    The step functions are built once in `__init__` and registered with
    `DBOS.step` under names derived from `step_name_prefix`.
    """

    def __init__(
        self,
        model: Model,
        *,
        step_name_prefix: str,
        step_config: StepConfig,
        event_stream_handler: EventStreamHandler[Any] | None = None,
    ):
        """Wrap `model` and register the request / request_stream DBOS steps.

        Args:
            model: The model to wrap.
            step_name_prefix: Prefix used to build the DBOS step names.
            step_config: Keyword arguments forwarded to `DBOS.step`.
            event_stream_handler: Optional handler awaited with the run context and the
                streamed response inside the request_stream step.
        """
        super().__init__(model)
        self.step_config = step_config
        self.event_stream_handler = event_stream_handler
        self._step_name_prefix = step_name_prefix

        # Bind the parent-class view of this instance once, so the step closures
        # below can delegate to the parent implementations.
        parent_impl = super()

        async def wrapped_request_step(
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
        ) -> ModelResponse:
            return await parent_impl.request(messages, model_settings, model_request_parameters)

        # Wrap the request in a DBOS step.
        as_request_step = DBOS.step(name=f'{self._step_name_prefix}__model.request', **self.step_config)
        self._dbos_wrapped_request_step = as_request_step(wrapped_request_step)

        async def wrapped_request_stream_step(
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
            run_context: RunContext[Any] | None = None,
        ) -> ModelResponse:
            stream_cm = parent_impl.request_stream(messages, model_settings, model_request_parameters, run_context)
            async with stream_cm as streamed_response:
                handler = self.event_stream_handler
                if handler is not None:
                    assert run_context is not None, (
                        'A DBOS model cannot be used with `pydantic_ai.direct.model_request_stream()` as it requires a `run_context`. Set an `event_stream_handler` on the agent and use `agent.run()` instead.'
                    )
                    await handler(run_context, streamed_response)

                # Drain whatever is left of the stream before asking for the response.
                async for _ in streamed_response:
                    pass
                return streamed_response.get()

        # Wrap the request_stream in a DBOS step.
        as_request_stream_step = DBOS.step(
            name=f'{self._step_name_prefix}__model.request_stream', **self.step_config
        )
        self._dbos_wrapped_request_stream_step = as_request_stream_step(wrapped_request_stream_step)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Make a model request by awaiting the DBOS step created in `__init__`."""
        response = await self._dbos_wrapped_request_step(messages, model_settings, model_request_parameters)
        return response

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[Any] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Yield a streamed response, going through the DBOS step only when directly inside a workflow.

        When there is no current workflow, or when already inside a step, the parent
        class's `request_stream` is used directly. Otherwise the request_stream step
        is awaited and its `ModelResponse` is yielded wrapped in a `DBOSStreamedResponse`.
        """
        run_as_step = DBOS.workflow_id is not None and DBOS.step_id is None

        if run_as_step:
            response = await self._dbos_wrapped_request_stream_step(
                messages, model_settings, model_request_parameters, run_context
            )
            yield DBOSStreamedResponse(model_request_parameters, response)
        else:
            # If not in a workflow (could be in a step), just call the wrapped request_stream method.
            async with super().request_stream(
                messages, model_settings, model_request_parameters, run_context
            ) as streamed_response:
                yield streamed_response
|
|
@@ -21,7 +21,6 @@ from pydantic_ai import (
|
|
|
21
21
|
models,
|
|
22
22
|
usage as _usage,
|
|
23
23
|
)
|
|
24
|
-
from pydantic_ai._run_context import AgentDepsT
|
|
25
24
|
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
|
|
26
25
|
from pydantic_ai.exceptions import UserError
|
|
27
26
|
from pydantic_ai.models import Model
|
|
@@ -29,6 +28,7 @@ from pydantic_ai.output import OutputDataT, OutputSpec
|
|
|
29
28
|
from pydantic_ai.result import StreamedRunResult
|
|
30
29
|
from pydantic_ai.settings import ModelSettings
|
|
31
30
|
from pydantic_ai.tools import (
|
|
31
|
+
AgentDepsT,
|
|
32
32
|
DeferredToolResults,
|
|
33
33
|
RunContext,
|
|
34
34
|
Tool,
|
pydantic_ai/mcp.py
CHANGED
|
@@ -517,7 +517,7 @@ class MCPServerStdio(MCPServer):
|
|
|
517
517
|
f'args={self.args!r}',
|
|
518
518
|
]
|
|
519
519
|
if self.id:
|
|
520
|
-
repr_args.append(f'id={self.id!r}')
|
|
520
|
+
repr_args.append(f'id={self.id!r}')
|
|
521
521
|
return f'{self.__class__.__name__}({", ".join(repr_args)})'
|
|
522
522
|
|
|
523
523
|
|
pydantic_ai/messages.py
CHANGED
|
@@ -52,6 +52,15 @@ ImageFormat: TypeAlias = Literal['jpeg', 'png', 'gif', 'webp']
|
|
|
52
52
|
DocumentFormat: TypeAlias = Literal['csv', 'doc', 'docx', 'html', 'md', 'pdf', 'txt', 'xls', 'xlsx']
|
|
53
53
|
VideoFormat: TypeAlias = Literal['mkv', 'mov', 'mp4', 'webm', 'flv', 'mpeg', 'mpg', 'wmv', 'three_gp']
|
|
54
54
|
|
|
55
|
+
FinishReason: TypeAlias = Literal[
|
|
56
|
+
'stop',
|
|
57
|
+
'length',
|
|
58
|
+
'content_filter',
|
|
59
|
+
'tool_call',
|
|
60
|
+
'error',
|
|
61
|
+
]
|
|
62
|
+
"""Reason the model finished generating the response, normalized to OpenTelemetry values."""
|
|
63
|
+
|
|
55
64
|
|
|
56
65
|
@dataclass(repr=False)
|
|
57
66
|
class SystemPromptPart:
|
|
@@ -886,7 +895,18 @@ class ThinkingPart:
|
|
|
886
895
|
signature: str | None = None
|
|
887
896
|
"""The signature of the thinking.
|
|
888
897
|
|
|
889
|
-
|
|
898
|
+
Supported by:
|
|
899
|
+
|
|
900
|
+
* Anthropic (corresponds to the `signature` field)
|
|
901
|
+
* Bedrock (corresponds to the `signature` field)
|
|
902
|
+
* Google (corresponds to the `thought_signature` field)
|
|
903
|
+
* OpenAI (corresponds to the `encrypted_content` field)
|
|
904
|
+
"""
|
|
905
|
+
|
|
906
|
+
provider_name: str | None = None
|
|
907
|
+
"""The name of the provider that generated the response.
|
|
908
|
+
|
|
909
|
+
Signatures are only sent back to the same provider.
|
|
890
910
|
"""
|
|
891
911
|
|
|
892
912
|
part_kind: Literal['thinking'] = 'thinking'
|
|
@@ -971,7 +991,10 @@ class BuiltinToolCallPart(BaseToolCallPart):
|
|
|
971
991
|
_: KW_ONLY
|
|
972
992
|
|
|
973
993
|
provider_name: str | None = None
|
|
974
|
-
"""The name of the provider that generated the response.
|
|
994
|
+
"""The name of the provider that generated the response.
|
|
995
|
+
|
|
996
|
+
Built-in tool calls are only sent back to the same provider.
|
|
997
|
+
"""
|
|
975
998
|
|
|
976
999
|
part_kind: Literal['builtin-tool-call'] = 'builtin-tool-call'
|
|
977
1000
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
@@ -1032,6 +1055,9 @@ class ModelResponse:
|
|
|
1032
1055
|
] = None
|
|
1033
1056
|
"""request ID as specified by the model provider. This can be used to track the specific request to the model."""
|
|
1034
1057
|
|
|
1058
|
+
finish_reason: FinishReason | None = None
|
|
1059
|
+
"""Reason the model finished generating the response, normalized to OpenTelemetry values."""
|
|
1060
|
+
|
|
1035
1061
|
@deprecated('`price` is deprecated, use `cost` instead')
|
|
1036
1062
|
def price(self) -> genai_types.PriceCalculation: # pragma: no cover
|
|
1037
1063
|
return self.cost()
|
|
@@ -1186,6 +1212,12 @@ class ThinkingPartDelta:
|
|
|
1186
1212
|
Note this is never treated as a delta — it can replace None.
|
|
1187
1213
|
"""
|
|
1188
1214
|
|
|
1215
|
+
provider_name: str | None = None
|
|
1216
|
+
"""Optional provider name for the thinking part.
|
|
1217
|
+
|
|
1218
|
+
Signatures are only sent back to the same provider.
|
|
1219
|
+
"""
|
|
1220
|
+
|
|
1189
1221
|
part_delta_kind: Literal['thinking'] = 'thinking'
|
|
1190
1222
|
"""Part delta type identifier, used as a discriminator."""
|
|
1191
1223
|
|
|
@@ -1210,14 +1242,18 @@ class ThinkingPartDelta:
|
|
|
1210
1242
|
if isinstance(part, ThinkingPart):
|
|
1211
1243
|
new_content = part.content + self.content_delta if self.content_delta else part.content
|
|
1212
1244
|
new_signature = self.signature_delta if self.signature_delta is not None else part.signature
|
|
1213
|
-
|
|
1245
|
+
new_provider_name = self.provider_name if self.provider_name is not None else part.provider_name
|
|
1246
|
+
return replace(part, content=new_content, signature=new_signature, provider_name=new_provider_name)
|
|
1214
1247
|
elif isinstance(part, ThinkingPartDelta):
|
|
1215
1248
|
if self.content_delta is None and self.signature_delta is None:
|
|
1216
1249
|
raise ValueError('Cannot apply ThinkingPartDelta with no content or signature')
|
|
1217
|
-
if self.signature_delta is not None:
|
|
1218
|
-
return replace(part, signature_delta=self.signature_delta)
|
|
1219
1250
|
if self.content_delta is not None:
|
|
1220
|
-
|
|
1251
|
+
part = replace(part, content_delta=(part.content_delta or '') + self.content_delta)
|
|
1252
|
+
if self.signature_delta is not None:
|
|
1253
|
+
part = replace(part, signature_delta=self.signature_delta)
|
|
1254
|
+
if self.provider_name is not None:
|
|
1255
|
+
part = replace(part, provider_name=self.provider_name)
|
|
1256
|
+
return part
|
|
1221
1257
|
raise ValueError( # pragma: no cover
|
|
1222
1258
|
f'Cannot apply ThinkingPartDeltas to non-ThinkingParts or non-ThinkingPartDeltas ({part=}, {self=})'
|
|
1223
1259
|
)
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -28,6 +28,7 @@ from ..exceptions import UserError
|
|
|
28
28
|
from ..messages import (
|
|
29
29
|
FileUrl,
|
|
30
30
|
FinalResultEvent,
|
|
31
|
+
FinishReason,
|
|
31
32
|
ModelMessage,
|
|
32
33
|
ModelRequest,
|
|
33
34
|
ModelResponse,
|
|
@@ -555,6 +556,10 @@ class StreamedResponse(ABC):
|
|
|
555
556
|
|
|
556
557
|
final_result_event: FinalResultEvent | None = field(default=None, init=False)
|
|
557
558
|
|
|
559
|
+
provider_response_id: str | None = field(default=None, init=False)
|
|
560
|
+
provider_details: dict[str, Any] | None = field(default=None, init=False)
|
|
561
|
+
finish_reason: FinishReason | None = field(default=None, init=False)
|
|
562
|
+
|
|
558
563
|
_parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
|
|
559
564
|
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
|
|
560
565
|
_usage: RequestUsage = field(default_factory=RequestUsage, init=False)
|
|
@@ -609,6 +614,9 @@ class StreamedResponse(ABC):
|
|
|
609
614
|
timestamp=self.timestamp,
|
|
610
615
|
usage=self.usage(),
|
|
611
616
|
provider_name=self.provider_name,
|
|
617
|
+
provider_response_id=self.provider_response_id,
|
|
618
|
+
provider_details=self.provider_details,
|
|
619
|
+
finish_reason=self.finish_reason,
|
|
612
620
|
)
|
|
613
621
|
|
|
614
622
|
def usage(self) -> RequestUsage:
|
pydantic_ai/models/anthropic.py
CHANGED
|
@@ -1,11 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import io
|
|
4
|
-
import warnings
|
|
5
4
|
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
|
|
6
5
|
from contextlib import asynccontextmanager
|
|
7
6
|
from dataclasses import dataclass, field
|
|
8
|
-
from datetime import datetime
|
|
7
|
+
from datetime import datetime
|
|
9
8
|
from typing import Any, Literal, cast, overload
|
|
10
9
|
|
|
11
10
|
from typing_extensions import assert_never
|
|
@@ -21,6 +20,7 @@ from ..messages import (
|
|
|
21
20
|
BuiltinToolCallPart,
|
|
22
21
|
BuiltinToolReturnPart,
|
|
23
22
|
DocumentUrl,
|
|
23
|
+
FinishReason,
|
|
24
24
|
ImageUrl,
|
|
25
25
|
ModelMessage,
|
|
26
26
|
ModelRequest,
|
|
@@ -42,6 +42,16 @@ from ..settings import ModelSettings
|
|
|
42
42
|
from ..tools import ToolDefinition
|
|
43
43
|
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
|
|
44
44
|
|
|
45
|
+
_FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = {
|
|
46
|
+
'end_turn': 'stop',
|
|
47
|
+
'max_tokens': 'length',
|
|
48
|
+
'stop_sequence': 'stop',
|
|
49
|
+
'tool_use': 'tool_call',
|
|
50
|
+
'pause_turn': 'stop',
|
|
51
|
+
'refusal': 'content_filter',
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
|
|
45
55
|
try:
|
|
46
56
|
from anthropic import NOT_GIVEN, APIStatusError, AsyncStream
|
|
47
57
|
from anthropic.types.beta import (
|
|
@@ -67,9 +77,11 @@ try:
|
|
|
67
77
|
BetaRawMessageStopEvent,
|
|
68
78
|
BetaRawMessageStreamEvent,
|
|
69
79
|
BetaRedactedThinkingBlock,
|
|
80
|
+
BetaRedactedThinkingBlockParam,
|
|
70
81
|
BetaServerToolUseBlock,
|
|
71
82
|
BetaServerToolUseBlockParam,
|
|
72
83
|
BetaSignatureDelta,
|
|
84
|
+
BetaStopReason,
|
|
73
85
|
BetaTextBlock,
|
|
74
86
|
BetaTextBlockParam,
|
|
75
87
|
BetaTextDelta,
|
|
@@ -293,7 +305,7 @@ class AnthropicModel(Model):
|
|
|
293
305
|
elif isinstance(item, BetaWebSearchToolResultBlock | BetaCodeExecutionToolResultBlock):
|
|
294
306
|
items.append(
|
|
295
307
|
BuiltinToolReturnPart(
|
|
296
|
-
provider_name=
|
|
308
|
+
provider_name=self.system,
|
|
297
309
|
tool_name=item.type,
|
|
298
310
|
content=item.content,
|
|
299
311
|
tool_call_id=item.tool_use_id,
|
|
@@ -302,20 +314,18 @@ class AnthropicModel(Model):
|
|
|
302
314
|
elif isinstance(item, BetaServerToolUseBlock):
|
|
303
315
|
items.append(
|
|
304
316
|
BuiltinToolCallPart(
|
|
305
|
-
provider_name=
|
|
317
|
+
provider_name=self.system,
|
|
306
318
|
tool_name=item.name,
|
|
307
319
|
args=cast(dict[str, Any], item.input),
|
|
308
320
|
tool_call_id=item.id,
|
|
309
321
|
)
|
|
310
322
|
)
|
|
311
|
-
elif isinstance(item, BetaRedactedThinkingBlock):
|
|
312
|
-
|
|
313
|
-
'
|
|
314
|
-
'If you have a suggestion on how we should handle them, please open an issue.',
|
|
315
|
-
UserWarning,
|
|
323
|
+
elif isinstance(item, BetaRedactedThinkingBlock):
|
|
324
|
+
items.append(
|
|
325
|
+
ThinkingPart(id='redacted_thinking', content='', signature=item.data, provider_name=self.system)
|
|
316
326
|
)
|
|
317
327
|
elif isinstance(item, BetaThinkingBlock):
|
|
318
|
-
items.append(ThinkingPart(content=item.thinking, signature=item.signature))
|
|
328
|
+
items.append(ThinkingPart(content=item.thinking, signature=item.signature, provider_name=self.system))
|
|
319
329
|
else:
|
|
320
330
|
assert isinstance(item, BetaToolUseBlock), f'unexpected item type {type(item)}'
|
|
321
331
|
items.append(
|
|
@@ -326,12 +336,20 @@ class AnthropicModel(Model):
|
|
|
326
336
|
)
|
|
327
337
|
)
|
|
328
338
|
|
|
339
|
+
finish_reason: FinishReason | None = None
|
|
340
|
+
provider_details: dict[str, Any] | None = None
|
|
341
|
+
if raw_finish_reason := response.stop_reason: # pragma: no branch
|
|
342
|
+
provider_details = {'finish_reason': raw_finish_reason}
|
|
343
|
+
finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
|
|
344
|
+
|
|
329
345
|
return ModelResponse(
|
|
330
346
|
parts=items,
|
|
331
347
|
usage=_map_usage(response),
|
|
332
348
|
model_name=response.model,
|
|
333
349
|
provider_response_id=response.id,
|
|
334
350
|
provider_name=self._provider.name,
|
|
351
|
+
finish_reason=finish_reason,
|
|
352
|
+
provider_details=provider_details,
|
|
335
353
|
)
|
|
336
354
|
|
|
337
355
|
async def _process_streamed_response(
|
|
@@ -342,13 +360,13 @@ class AnthropicModel(Model):
|
|
|
342
360
|
if isinstance(first_chunk, _utils.Unset):
|
|
343
361
|
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') # pragma: no cover
|
|
344
362
|
|
|
345
|
-
|
|
346
|
-
|
|
363
|
+
assert isinstance(first_chunk, BetaRawMessageStartEvent)
|
|
364
|
+
|
|
347
365
|
return AnthropicStreamedResponse(
|
|
348
366
|
model_request_parameters=model_request_parameters,
|
|
349
|
-
_model_name=
|
|
367
|
+
_model_name=first_chunk.message.model,
|
|
350
368
|
_response=peekable_response,
|
|
351
|
-
_timestamp=
|
|
369
|
+
_timestamp=_utils.now_utc(),
|
|
352
370
|
_provider_name=self._provider.name,
|
|
353
371
|
)
|
|
354
372
|
|
|
@@ -425,6 +443,7 @@ class AnthropicModel(Model):
|
|
|
425
443
|
| BetaWebSearchToolResultBlockParam
|
|
426
444
|
| BetaCodeExecutionToolResultBlockParam
|
|
427
445
|
| BetaThinkingBlockParam
|
|
446
|
+
| BetaRedactedThinkingBlockParam
|
|
428
447
|
] = []
|
|
429
448
|
for response_part in m.parts:
|
|
430
449
|
if isinstance(response_part, TextPart):
|
|
@@ -439,15 +458,33 @@ class AnthropicModel(Model):
|
|
|
439
458
|
)
|
|
440
459
|
assistant_content_params.append(tool_use_block_param)
|
|
441
460
|
elif isinstance(response_part, ThinkingPart):
|
|
442
|
-
|
|
443
|
-
|
|
461
|
+
if (
|
|
462
|
+
response_part.provider_name == self.system and response_part.signature is not None
|
|
463
|
+
): # pragma: no branch
|
|
464
|
+
if response_part.id == 'redacted_thinking':
|
|
465
|
+
assistant_content_params.append(
|
|
466
|
+
BetaRedactedThinkingBlockParam(
|
|
467
|
+
data=response_part.signature,
|
|
468
|
+
type='redacted_thinking',
|
|
469
|
+
)
|
|
470
|
+
)
|
|
471
|
+
else:
|
|
472
|
+
assistant_content_params.append(
|
|
473
|
+
BetaThinkingBlockParam(
|
|
474
|
+
thinking=response_part.content,
|
|
475
|
+
signature=response_part.signature,
|
|
476
|
+
type='thinking',
|
|
477
|
+
)
|
|
478
|
+
)
|
|
479
|
+
elif response_part.content: # pragma: no branch
|
|
480
|
+
start_tag, end_tag = self.profile.thinking_tags
|
|
444
481
|
assistant_content_params.append(
|
|
445
|
-
|
|
446
|
-
|
|
482
|
+
BetaTextBlockParam(
|
|
483
|
+
text='\n'.join([start_tag, response_part.content, end_tag]), type='text'
|
|
447
484
|
)
|
|
448
485
|
)
|
|
449
486
|
elif isinstance(response_part, BuiltinToolCallPart):
|
|
450
|
-
if response_part.provider_name ==
|
|
487
|
+
if response_part.provider_name == self.system:
|
|
451
488
|
server_tool_use_block_param = BetaServerToolUseBlockParam(
|
|
452
489
|
id=_guard_tool_call_id(t=response_part),
|
|
453
490
|
type='server_tool_use',
|
|
@@ -456,7 +493,7 @@ class AnthropicModel(Model):
|
|
|
456
493
|
)
|
|
457
494
|
assistant_content_params.append(server_tool_use_block_param)
|
|
458
495
|
elif isinstance(response_part, BuiltinToolReturnPart):
|
|
459
|
-
if response_part.provider_name ==
|
|
496
|
+
if response_part.provider_name == self.system:
|
|
460
497
|
tool_use_id = _guard_tool_call_id(t=response_part)
|
|
461
498
|
if response_part.tool_name == 'web_search_tool_result':
|
|
462
499
|
server_tool_result_block_param = BetaWebSearchToolResultBlockParam(
|
|
@@ -583,20 +620,30 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
583
620
|
async for event in self._response:
|
|
584
621
|
if isinstance(event, BetaRawMessageStartEvent):
|
|
585
622
|
self._usage = _map_usage(event)
|
|
623
|
+
self.provider_response_id = event.message.id
|
|
586
624
|
|
|
587
625
|
elif isinstance(event, BetaRawContentBlockStartEvent):
|
|
588
626
|
current_block = event.content_block
|
|
589
627
|
if isinstance(current_block, BetaTextBlock) and current_block.text:
|
|
590
628
|
maybe_event = self._parts_manager.handle_text_delta(
|
|
591
|
-
vendor_part_id=
|
|
629
|
+
vendor_part_id=event.index, content=current_block.text
|
|
592
630
|
)
|
|
593
631
|
if maybe_event is not None: # pragma: no branch
|
|
594
632
|
yield maybe_event
|
|
595
633
|
elif isinstance(current_block, BetaThinkingBlock):
|
|
596
634
|
yield self._parts_manager.handle_thinking_delta(
|
|
597
|
-
vendor_part_id=
|
|
635
|
+
vendor_part_id=event.index,
|
|
598
636
|
content=current_block.thinking,
|
|
599
637
|
signature=current_block.signature,
|
|
638
|
+
provider_name=self.provider_name,
|
|
639
|
+
)
|
|
640
|
+
elif isinstance(current_block, BetaRedactedThinkingBlock):
|
|
641
|
+
yield self._parts_manager.handle_thinking_delta(
|
|
642
|
+
vendor_part_id=event.index,
|
|
643
|
+
id='redacted_thinking',
|
|
644
|
+
content='',
|
|
645
|
+
signature=current_block.data,
|
|
646
|
+
provider_name=self.provider_name,
|
|
600
647
|
)
|
|
601
648
|
elif isinstance(current_block, BetaToolUseBlock):
|
|
602
649
|
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
@@ -613,17 +660,21 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
613
660
|
elif isinstance(event, BetaRawContentBlockDeltaEvent):
|
|
614
661
|
if isinstance(event.delta, BetaTextDelta):
|
|
615
662
|
maybe_event = self._parts_manager.handle_text_delta(
|
|
616
|
-
vendor_part_id=
|
|
663
|
+
vendor_part_id=event.index, content=event.delta.text
|
|
617
664
|
)
|
|
618
665
|
if maybe_event is not None: # pragma: no branch
|
|
619
666
|
yield maybe_event
|
|
620
667
|
elif isinstance(event.delta, BetaThinkingDelta):
|
|
621
668
|
yield self._parts_manager.handle_thinking_delta(
|
|
622
|
-
vendor_part_id=
|
|
669
|
+
vendor_part_id=event.index,
|
|
670
|
+
content=event.delta.thinking,
|
|
671
|
+
provider_name=self.provider_name,
|
|
623
672
|
)
|
|
624
673
|
elif isinstance(event.delta, BetaSignatureDelta):
|
|
625
674
|
yield self._parts_manager.handle_thinking_delta(
|
|
626
|
-
vendor_part_id=
|
|
675
|
+
vendor_part_id=event.index,
|
|
676
|
+
signature=event.delta.signature,
|
|
677
|
+
provider_name=self.provider_name,
|
|
627
678
|
)
|
|
628
679
|
elif (
|
|
629
680
|
current_block
|
|
@@ -646,6 +697,9 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
646
697
|
|
|
647
698
|
elif isinstance(event, BetaRawMessageDeltaEvent):
|
|
648
699
|
self._usage = _map_usage(event)
|
|
700
|
+
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
|
|
701
|
+
self.provider_details = {'finish_reason': raw_finish_reason}
|
|
702
|
+
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
|
|
649
703
|
|
|
650
704
|
elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent): # pragma: no branch
|
|
651
705
|
current_block = None
|