pydantic-ai-slim 0.0.25__tar.gz → 0.0.27__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/PKG-INFO +3 -2
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/__init__.py +22 -4
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_agent_graph.py +15 -12
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/agent.py +13 -13
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/exceptions.py +42 -1
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/messages.py +90 -1
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/anthropic.py +58 -28
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/cohere.py +22 -13
- pydantic_ai_slim-0.0.27/pydantic_ai/models/fallback.py +116 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/function.py +28 -10
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/gemini.py +78 -10
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/groq.py +59 -27
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/mistral.py +50 -15
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/openai.py +84 -30
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/tools.py +2 -2
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pyproject.toml +4 -5
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/README.md +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/__init__.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/vertexai.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/usage.py +0 -0
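The only brand-new module is pydantic_ai/models/fallback.py (116 lines), whose contents are not shown in the hunks below. As a hedged sketch only, assuming the module exposes a FallbackModel that tries each wrapped model in turn and raises the new FallbackExceptionGroup when every model fails (the constructor shown here is an assumption, not taken from this diff):

# Hedged sketch; FallbackModel's exact signature is assumed, not shown in this diff.
from pydantic_ai import Agent, FallbackExceptionGroup
from pydantic_ai.models.fallback import FallbackModel  # new module in 0.0.27

# Assumed: the first model is tried first, the rest are fallbacks.
model = FallbackModel('openai:gpt-4o', 'anthropic:claude-3-5-sonnet-latest')
agent = Agent(model)
try:
    result = agent.run_sync('What is the capital of France?')
    print(result.data)
except FallbackExceptionGroup as exc_group:
    # Raised when all fallback models fail; contains one exception per model.
    print(exc_group.exceptions)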
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.25
+Version: 0.0.27
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -25,10 +25,11 @@ Classifier: Topic :: Internet
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.9
 Requires-Dist: eval-type-backport>=0.2.0
+Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: logfire-api>=1.2.0
-Requires-Dist: pydantic-graph==0.0.25
+Requires-Dist: pydantic-graph==0.0.27
 Requires-Dist: pydantic>=2.10
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/__init__.py

@@ -1,23 +1,41 @@
 from importlib.metadata import version
 
 from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
-from .exceptions import
+from .exceptions import (
+    AgentRunError,
+    FallbackExceptionGroup,
+    ModelHTTPError,
+    ModelRetry,
+    UnexpectedModelBehavior,
+    UsageLimitExceeded,
+    UserError,
+)
+from .messages import AudioUrl, BinaryContent, ImageUrl
 from .tools import RunContext, Tool
 
 __all__ = (
+    '__version__',
+    # agent
     'Agent',
     'EndStrategy',
     'HandleResponseNode',
     'ModelRequestNode',
     'UserPromptNode',
     'capture_run_messages',
-
-    'Tool',
+    # exceptions
     'AgentRunError',
     'ModelRetry',
+    'ModelHTTPError',
+    'FallbackExceptionGroup',
     'UnexpectedModelBehavior',
     'UsageLimitExceeded',
     'UserError',
-
+    # messages
+    'ImageUrl',
+    'AudioUrl',
+    'BinaryContent',
+    # tools
+    'Tool',
+    'RunContext',
 )
 __version__ = version('pydantic_ai_slim')
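The widened public API can be checked directly. A minimal sketch, not part of the diff itself, importing the names that are newly re-exported from the package root per the __init__.py hunk above (assumes pydantic-ai-slim 0.0.27 is installed):

from pydantic_ai import (
    AudioUrl,
    BinaryContent,
    FallbackExceptionGroup,
    ImageUrl,
    ModelHTTPError,
    __version__,
)

print(__version__)  # e.g. '0.0.27'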
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_agent_graph.py

@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import asyncio
 import dataclasses
 from abc import ABC
-from collections.abc import AsyncIterator, Iterator
+from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
@@ -89,7 +89,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
 
     user_deps: DepsT
 
-    prompt: str
+    prompt: str | Sequence[_messages.UserContent]
     new_message_index: int
 
     model: models.Model
@@ -108,20 +108,20 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
 
 
 @dataclasses.dataclass
-class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
-    user_prompt: str
+class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
+    user_prompt: str | Sequence[_messages.UserContent]
 
     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
     system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
 
     async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT,
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> ModelRequestNode[DepsT, NodeRunEndT]:
         return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
 
     async def _get_first_message(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT,
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> _messages.ModelRequest:
         run_context = build_run_context(ctx)
         history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
@@ -135,7 +135,10 @@ class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
         return next_message
 
     async def _prepare_messages(
-        self,
+        self,
+        user_prompt: str | Sequence[_messages.UserContent],
+        message_history: list[_messages.ModelMessage] | None,
+        run_context: RunContext[DepsT],
     ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
         try:
             ctx_messages = get_captured_run_messages()
@@ -212,7 +215,7 @@ async def _prepare_request_parameters(
 
 
 @dataclasses.dataclass
-class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
+class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
     """Make a request to the model using the last message in state.message_history."""
 
     request: _messages.ModelRequest
@@ -316,7 +319,7 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
 
 
 @dataclasses.dataclass
-class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
+class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
     """Process a model response, and decide whether to end the run or make a new request."""
 
     model_response: _messages.ModelResponse
@@ -338,7 +341,7 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
 
     @asynccontextmanager
     async def stream(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT,
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
         """Process the model response and yield events for the start and end of each function tool call."""
         with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
@@ -363,7 +366,7 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
         handle_span.message = f'handle model response -> {tool_responses_str}'
 
     async def _run_stream(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT,
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> AsyncIterator[_messages.HandleResponseEvent]:
         if self._events_iterator is None:
             # Ensure that the stream is only run once
@@ -667,7 +670,7 @@ def get_captured_run_messages() -> _RunMessages:
 
 
 def build_agent_graph(
     name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT,
+) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
     """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
     nodes = (
         UserPromptNode[DepsT],
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/agent.py

@@ -220,7 +220,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     async def run(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -235,7 +235,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     async def run(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT],
         message_history: list[_messages.ModelMessage] | None = None,
@@ -249,7 +249,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
 
     async def run(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -313,7 +313,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @contextmanager
     def iter(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -365,7 +365,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                 HandleResponseNode(
                     model_response=ModelResponse(
                         parts=[TextPart(content='Paris', part_kind='text')],
-                        model_name='
+                        model_name='gpt-4o',
                         timestamp=datetime.datetime(...),
                         kind='response',
                     )
@@ -466,7 +466,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_sync(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
@@ -480,7 +480,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_sync(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -494,7 +494,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
 
     def run_sync(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -555,7 +555,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_stream(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -570,7 +570,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_stream(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT],
         message_history: list[_messages.ModelMessage] | None = None,
@@ -585,7 +585,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @asynccontextmanager
     async def run_stream(  # noqa C901
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -1214,7 +1214,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
                 HandleResponseNode(
                     model_response=ModelResponse(
                         parts=[TextPart(content='Paris', part_kind='text')],
-                        model_name='
+                        model_name='gpt-4o',
                         timestamp=datetime.datetime(...),
                         kind='response',
                     )
@@ -1357,7 +1357,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
                 HandleResponseNode(
                     model_response=ModelResponse(
                         parts=[TextPart(content='Paris', part_kind='text')],
-                        model_name='
+                        model_name='gpt-4o',
                         timestamp=datetime.datetime(...),
                         kind='response',
                     )
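With user_prompt widened from str to str | Sequence[UserContent], a prompt can now mix text with image or audio parts. A hedged usage sketch; the model name and URL are illustrative, not taken from the diff:

from pydantic_ai import Agent, ImageUrl

agent = Agent('anthropic:claude-3-5-sonnet-latest')  # illustrative model name
result = agent.run_sync(
    [
        'What company is shown in this logo?',
        ImageUrl(url='https://example.com/logo.png'),
    ]
)
print(result.data)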
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/exceptions.py

@@ -1,8 +1,22 @@
 from __future__ import annotations as _annotations
 
 import json
+import sys
 
-
+if sys.version_info < (3, 11):
+    from exceptiongroup import ExceptionGroup
+else:
+    ExceptionGroup = ExceptionGroup
+
+__all__ = (
+    'ModelRetry',
+    'UserError',
+    'AgentRunError',
+    'UnexpectedModelBehavior',
+    'UsageLimitExceeded',
+    'ModelHTTPError',
+    'FallbackExceptionGroup',
+)
 
 
 class ModelRetry(Exception):
@@ -72,3 +86,30 @@ class UnexpectedModelBehavior(AgentRunError):
             return f'{self.message}, body:\n{self.body}'
         else:
             return self.message
+
+
+class ModelHTTPError(AgentRunError):
+    """Raised when an model provider response has a status code of 4xx or 5xx."""
+
+    status_code: int
+    """The HTTP status code returned by the API."""
+
+    model_name: str
+    """The name of the model associated with the error."""
+
+    body: object | None
+    """The body of the response, if available."""
+
+    message: str
+    """The error message with the status code and response body, if available."""
+
+    def __init__(self, status_code: int, model_name: str, body: object | None = None):
+        self.status_code = status_code
+        self.model_name = model_name
+        self.body = body
+        message = f'status_code: {status_code}, model_name: {model_name}, body: {body}'
+        super().__init__(message)
+
+
+class FallbackExceptionGroup(ExceptionGroup):
+    """A group of exceptions that can be raised when all fallback models fail."""
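Provider 4xx/5xx responses now surface as ModelHTTPError instead of an SDK-specific exception (see the Anthropic and Cohere hunks below). A small sketch of catching it; the agent setup and prompt are illustrative:

from pydantic_ai import Agent, ModelHTTPError

agent = Agent('openai:gpt-4o')  # illustrative model name
try:
    result = agent.run_sync('Hello')
except ModelHTTPError as exc:
    # Fields set by the new exception class: status_code, model_name, body.
    print(exc.status_code, exc.model_name, exc.body)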
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/messages.py

@@ -1,12 +1,14 @@
 from __future__ import annotations as _annotations
 
 import uuid
+from collections.abc import Sequence
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Annotated, Any, Literal, Union, cast, overload
 
 import pydantic
 import pydantic_core
+from typing_extensions import TypeAlias
 
 from ._utils import now_utc as _now_utc
 from .exceptions import UnexpectedModelBehavior
@@ -32,6 +34,93 @@ class SystemPromptPart:
     """Part type identifier, this is available on all parts as a discriminator."""
 
 
+@dataclass
+class AudioUrl:
+    """A URL to an audio file."""
+
+    url: str
+    """The URL of the audio file."""
+
+    kind: Literal['audio-url'] = 'audio-url'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+    @property
+    def media_type(self) -> AudioMediaType:
+        """Return the media type of the audio file, based on the url."""
+        if self.url.endswith('.mp3'):
+            return 'audio/mpeg'
+        elif self.url.endswith('.wav'):
+            return 'audio/wav'
+        else:
+            raise ValueError(f'Unknown audio file extension: {self.url}')
+
+
+@dataclass
+class ImageUrl:
+    """A URL to an image."""
+
+    url: str
+    """The URL of the image."""
+
+    kind: Literal['image-url'] = 'image-url'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+    @property
+    def media_type(self) -> ImageMediaType:
+        """Return the media type of the image, based on the url."""
+        if self.url.endswith(('.jpg', '.jpeg')):
+            return 'image/jpeg'
+        elif self.url.endswith('.png'):
+            return 'image/png'
+        elif self.url.endswith('.gif'):
+            return 'image/gif'
+        elif self.url.endswith('.webp'):
+            return 'image/webp'
+        else:
+            raise ValueError(f'Unknown image file extension: {self.url}')
+
+
+AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
+ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
+
+
+@dataclass
+class BinaryContent:
+    """Binary content, e.g. an audio or image file."""
+
+    data: bytes
+    """The binary data."""
+
+    media_type: AudioMediaType | ImageMediaType | str
+    """The media type of the binary data."""
+
+    kind: Literal['binary'] = 'binary'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+    @property
+    def is_audio(self) -> bool:
+        """Return `True` if the media type is an audio type."""
+        return self.media_type.startswith('audio/')
+
+    @property
+    def is_image(self) -> bool:
+        """Return `True` if the media type is an image type."""
+        return self.media_type.startswith('image/')
+
+    @property
+    def audio_format(self) -> Literal['mp3', 'wav']:
+        """Return the audio format given the media type."""
+        if self.media_type == 'audio/mpeg':
+            return 'mp3'
+        elif self.media_type == 'audio/wav':
+            return 'wav'
+        else:
+            raise ValueError(f'Unknown audio media type: {self.media_type}')
+
+
+UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | BinaryContent'
+
+
 @dataclass
 class UserPromptPart:
     """A user prompt, generally written by the end user.
@@ -40,7 +129,7 @@ class UserPromptPart:
     [`Agent.run_sync`][pydantic_ai.Agent.run_sync], and [`Agent.run_stream`][pydantic_ai.Agent.run_stream].
     """
 
-    content: str
+    content: str | Sequence[UserContent]
     """The content of the prompt."""
 
     timestamp: datetime = field(default_factory=_now_utc)
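The new content dataclasses can also be used standalone; this short check exercises the media_type, is_audio, is_image and audio_format helpers exactly as defined in the hunk above (the URLs and bytes are placeholders):

from pydantic_ai.messages import AudioUrl, BinaryContent, ImageUrl

img = ImageUrl(url='https://example.com/diagram.png')
assert img.media_type == 'image/png'  # inferred from the file extension

speech = BinaryContent(data=b'...', media_type='audio/mpeg')
assert speech.is_audio and not speech.is_image
assert speech.audio_format == 'mp3'

clip = AudioUrl(url='https://example.com/clip.wav')
assert clip.media_type == 'audio/wav'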
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/anthropic.py

@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
-
+import io
+from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -10,9 +11,11 @@ from typing import Any, Literal, Union, cast, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -36,8 +39,9 @@ from . import (
 )
 
 try:
-    from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
+    from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
     from anthropic.types import (
+        ImageBlockParam,
         Message as AnthropicMessage,
         MessageParam,
         MetadataParam,
@@ -214,21 +218,26 @@ class AnthropicModel(Model):
         if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
             tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
 
-        system_prompt, anthropic_messages = self._map_message(messages)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        system_prompt, anthropic_messages = await self._map_message(messages)
+
+        try:
+            return await self.client.messages.create(
+                max_tokens=model_settings.get('max_tokens', 1024),
+                system=system_prompt or NOT_GIVEN,
+                messages=anthropic_messages,
+                model=self._model_name,
+                tools=tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                stream=stream,
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
+            )
+        except APIStatusError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
 
     def _process_response(self, response: AnthropicMessage) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
@@ -266,19 +275,19 @@ class AnthropicModel(Model):
             tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
         return tools
 
-    def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
         system_prompt: str = ''
         anthropic_messages: list[MessageParam] = []
         for m in messages:
             if isinstance(m, ModelRequest):
-                user_content_params: list[ToolResultBlockParam | TextBlockParam] = []
+                user_content_params: list[ToolResultBlockParam | TextBlockParam | ImageBlockParam] = []
                 for request_part in m.parts:
                     if isinstance(request_part, SystemPromptPart):
                         system_prompt += request_part.content
                     elif isinstance(request_part, UserPromptPart):
-
-
+                        async for content in self._map_user_prompt(request_part):
+                            user_content_params.append(content)
                     elif isinstance(request_part, ToolReturnPart):
                         tool_result_block_param = ToolResultBlockParam(
                             tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
@@ -298,12 +307,7 @@ class AnthropicModel(Model):
                             is_error=True,
                         )
                         user_content_params.append(retry_param)
-                anthropic_messages.append(
-                    MessageParam(
-                        role='user',
-                        content=user_content_params,
-                    )
-                )
+                anthropic_messages.append(MessageParam(role='user', content=user_content_params))
             elif isinstance(m, ModelResponse):
                 assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
                 for response_part in m.parts:
@@ -322,6 +326,32 @@ class AnthropicModel(Model):
                 assert_never(m)
         return system_prompt, anthropic_messages
 
+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> AsyncGenerator[ImageBlockParam | TextBlockParam]:
+        if isinstance(part.content, str):
+            yield TextBlockParam(text=part.content, type='text')
+        else:
+            for item in part.content:
+                if isinstance(item, str):
+                    yield TextBlockParam(text=item, type='text')
+                elif isinstance(item, BinaryContent):
+                    if item.is_image:
+                        yield ImageBlockParam(
+                            source={'data': io.BytesIO(item.data), 'media_type': item.media_type, 'type': 'base64'},  # type: ignore
+                            type='image',
+                        )
+                    else:
+                        raise RuntimeError('Only images are supported for binary content')
+                elif isinstance(item, ImageUrl):
+                    response = await cached_async_http_client().get(item.url)
+                    response.raise_for_status()
+                    yield ImageBlockParam(
+                        source={'data': io.BytesIO(response.content), 'media_type': 'image/jpeg', 'type': 'base64'},
+                        type='image',
+                    )
+                else:
+                    raise RuntimeError(f'Unsupported content type: {type(item)}')
+
     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolParam:
         return {
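Combined with the new message types, the Anthropic model can now take image input either as raw bytes (BinaryContent) or via ImageUrl, which the new _map_user_prompt fetches and forwards as a base64 image block. A hedged end-to-end sketch; the file path and model name are placeholders:

from pathlib import Path

from pydantic_ai import Agent, BinaryContent

agent = Agent('anthropic:claude-3-5-sonnet-latest')  # illustrative model name
image_bytes = Path('photo.jpg').read_bytes()  # any JPEG/PNG/GIF/WebP image
result = agent.run_sync(
    ['Describe this image.', BinaryContent(data=image_bytes, media_type='image/jpeg')]
)
print(result.data)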
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/cohere.py

@@ -9,7 +9,7 @@ from cohere import TextAssistantMessageContentItem
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import result
+from .. import ModelHTTPError, result
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ModelMessage,
@@ -45,6 +45,7 @@ try:
         ToolV2Function,
         UserChatMessageV2,
     )
+    from cohere.core.api_error import ApiError
     from cohere.v2.client import OMIT
 except ImportError as _import_error:
     raise ImportError(
@@ -154,17 +155,22 @@ class CohereModel(Model):
     ) -> ChatResponse:
         tools = self._get_tools(model_request_parameters)
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            return await self.client.chat(
+                model=self._model_name,
+                messages=cohere_messages,
+                tools=tools or OMIT,
+                max_tokens=model_settings.get('max_tokens', OMIT),
+                temperature=model_settings.get('temperature', OMIT),
+                p=model_settings.get('top_p', OMIT),
+                seed=model_settings.get('seed', OMIT),
+                presence_penalty=model_settings.get('presence_penalty', OMIT),
+                frequency_penalty=model_settings.get('frequency_penalty', OMIT),
+            )
+        except ApiError as e:
+            if (status_code := e.status_code) and status_code >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
 
     def _process_response(self, response: ChatResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
@@ -242,7 +248,10 @@ class CohereModel(Model):
         if isinstance(part, SystemPromptPart):
             yield SystemChatMessageV2(role='system', content=part.content)
         elif isinstance(part, UserPromptPart):
-
+            if isinstance(part.content, str):
+                yield UserChatMessageV2(role='user', content=part.content)
+            else:
+                raise RuntimeError('Cohere does not yet support multi-modal inputs.')
         elif isinstance(part, ToolReturnPart):
             yield ToolChatMessageV2(
                 role='tool',