pydantic-ai-slim: 0.0.25.tar.gz → 0.0.27.tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

This version of pydantic-ai-slim might be problematic; see the registry page for more details.
Files changed (34)
  1. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/PKG-INFO +3 -2
  2. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/__init__.py +22 -4
  3. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_agent_graph.py +15 -12
  4. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/agent.py +13 -13
  5. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/exceptions.py +42 -1
  6. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/messages.py +90 -1
  7. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/anthropic.py +58 -28
  8. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/cohere.py +22 -13
  9. pydantic_ai_slim-0.0.27/pydantic_ai/models/fallback.py +116 -0 (new module; see the usage sketch after this list)
  10. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/function.py +28 -10
  11. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/gemini.py +78 -10
  12. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/groq.py +59 -27
  13. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/mistral.py +50 -15
  14. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/openai.py +84 -30
  15. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/tools.py +2 -2
  16. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pyproject.toml +4 -5
  17. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/.gitignore +0 -0
  18. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/README.md +0 -0
  19. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_griffe.py +0 -0
  20. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_parts_manager.py +0 -0
  21. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_pydantic.py +0 -0
  22. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_result.py +0 -0
  23. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_system_prompt.py +0 -0
  24. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/_utils.py +0 -0
  25. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/format_as_xml.py +0 -0
  26. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/__init__.py +0 -0
  27. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/instrumented.py +0 -0
  28. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/test.py +0 -0
  29. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/vertexai.py +0 -0
  30. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/wrapper.py +0 -0
  31. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/py.typed +0 -0
  32. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/result.py +0 -0
  33. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/settings.py +0 -0
  34. {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.27}/pydantic_ai/usage.py +0 -0
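
File 9 above, pydantic_ai/models/fallback.py, is entirely new (+116 lines), but its contents are not reproduced in this diff. Judging from the FallbackExceptionGroup and ModelHTTPError exceptions added alongside it, the module appears to implement a model that tries several underlying models in turn. A minimal usage sketch, assuming the module exposes a FallbackModel class that takes models (or known model names) in priority order and falls back when a model raises ModelHTTPError; none of this API is confirmed by the diff below:

    from pydantic_ai import Agent
    from pydantic_ai.models.fallback import FallbackModel

    # Assumed API: first argument is the default model, the rest are fallbacks.
    model = FallbackModel('openai:gpt-4o', 'anthropic:claude-3-5-sonnet-latest')
    agent = Agent(model)
    result = agent.run_sync('What is the capital of France?')
    print(result.data)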
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pydantic-ai-slim
- Version: 0.0.25
+ Version: 0.0.27
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
  Author-email: Samuel Colvin <samuel@pydantic.dev>
  License-Expression: MIT
@@ -25,10 +25,11 @@ Classifier: Topic :: Internet
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
  Requires-Python: >=3.9
  Requires-Dist: eval-type-backport>=0.2.0
+ Requires-Dist: exceptiongroup; python_version < '3.11'
  Requires-Dist: griffe>=1.3.2
  Requires-Dist: httpx>=0.27
  Requires-Dist: logfire-api>=1.2.0
- Requires-Dist: pydantic-graph==0.0.25
+ Requires-Dist: pydantic-graph==0.0.27
  Requires-Dist: pydantic>=2.10
  Provides-Extra: anthropic
  Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
pydantic_ai/__init__.py
@@ -1,23 +1,41 @@
  from importlib.metadata import version

  from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
- from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
+ from .exceptions import (
+     AgentRunError,
+     FallbackExceptionGroup,
+     ModelHTTPError,
+     ModelRetry,
+     UnexpectedModelBehavior,
+     UsageLimitExceeded,
+     UserError,
+ )
+ from .messages import AudioUrl, BinaryContent, ImageUrl
  from .tools import RunContext, Tool

  __all__ = (
+     '__version__',
+     # agent
      'Agent',
      'EndStrategy',
      'HandleResponseNode',
      'ModelRequestNode',
      'UserPromptNode',
      'capture_run_messages',
-     'RunContext',
-     'Tool',
+     # exceptions
      'AgentRunError',
      'ModelRetry',
+     'ModelHTTPError',
+     'FallbackExceptionGroup',
      'UnexpectedModelBehavior',
      'UsageLimitExceeded',
      'UserError',
-     '__version__',
+     # messages
+     'ImageUrl',
+     'AudioUrl',
+     'BinaryContent',
+     # tools
+     'Tool',
+     'RunContext',
  )
  __version__ = version('pydantic_ai_slim')
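
With the reorganised __all__, the new content types and exceptions are importable directly from the package root; all names below are taken from the diff above:

    from pydantic_ai import (
        Agent,
        AudioUrl,
        BinaryContent,
        FallbackExceptionGroup,
        ImageUrl,
        ModelHTTPError,
    )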
pydantic_ai/_agent_graph.py
@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
  import asyncio
  import dataclasses
  from abc import ABC
- from collections.abc import AsyncIterator, Iterator
+ from collections.abc import AsyncIterator, Iterator, Sequence
  from contextlib import asynccontextmanager, contextmanager
  from contextvars import ContextVar
  from dataclasses import field
@@ -89,7 +89,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):

      user_deps: DepsT

-     prompt: str
+     prompt: str | Sequence[_messages.UserContent]
      new_message_index: int

      model: models.Model
@@ -108,20 +108,20 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):


  @dataclasses.dataclass
- class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
-     user_prompt: str
+ class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
+     user_prompt: str | Sequence[_messages.UserContent]

      system_prompts: tuple[str, ...]
      system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
      system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]

      async def run(
-         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
      ) -> ModelRequestNode[DepsT, NodeRunEndT]:
          return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))

      async def _get_first_message(
-         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
      ) -> _messages.ModelRequest:
          run_context = build_run_context(ctx)
          history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
@@ -135,7 +135,10 @@ class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeR
          return next_message

      async def _prepare_messages(
-         self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[DepsT]
+         self,
+         user_prompt: str | Sequence[_messages.UserContent],
+         message_history: list[_messages.ModelMessage] | None,
+         run_context: RunContext[DepsT],
      ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
          try:
              ctx_messages = get_captured_run_messages()
@@ -212,7 +215,7 @@ async def _prepare_request_parameters(


  @dataclasses.dataclass
- class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
+ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
      """Make a request to the model using the last message in state.message_history."""

      request: _messages.ModelRequest
@@ -316,7 +319,7 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], Nod


  @dataclasses.dataclass
- class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
+ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
      """Process a model response, and decide whether to end the run or make a new request."""

      model_response: _messages.ModelResponse
@@ -338,7 +341,7 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N

      @asynccontextmanager
      async def stream(
-         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
      ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
          """Process the model response and yield events for the start and end of each function tool call."""
          with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
@@ -363,7 +366,7 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
          handle_span.message = f'handle model response -> {tool_responses_str}'

      async def _run_stream(
-         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
      ) -> AsyncIterator[_messages.HandleResponseEvent]:
          if self._events_iterator is None:
              # Ensure that the stream is only run once
@@ -667,7 +670,7 @@ def get_captured_run_messages() -> _RunMessages:

  def build_agent_graph(
      name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
- ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]:
+ ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
      """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
      nodes = (
          UserPromptNode[DepsT],
pydantic_ai/agent.py
@@ -220,7 +220,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
      @overload
      async def run(
          self,
-         user_prompt: str,
+         user_prompt: str | Sequence[_messages.UserContent],
          *,
          result_type: None = None,
          message_history: list[_messages.ModelMessage] | None = None,
@@ -235,7 +235,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
      @overload
      async def run(
          self,
-         user_prompt: str,
+         user_prompt: str | Sequence[_messages.UserContent],
          *,
          result_type: type[RunResultDataT],
          message_history: list[_messages.ModelMessage] | None = None,
@@ -249,7 +249,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):

      async def run(
          self,
-         user_prompt: str,
+         user_prompt: str | Sequence[_messages.UserContent],
          *,
          result_type: type[RunResultDataT] | None = None,
          message_history: list[_messages.ModelMessage] | None = None,
@@ -313,7 +313,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
      @contextmanager
      def iter(
          self,
-         user_prompt: str,
+         user_prompt: str | Sequence[_messages.UserContent],
          *,
          result_type: type[RunResultDataT] | None = None,
          message_history: list[_messages.ModelMessage] | None = None,
@@ -365,7 +365,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                  HandleResponseNode(
                      model_response=ModelResponse(
                          parts=[TextPart(content='Paris', part_kind='text')],
-                         model_name='function:model_logic',
+                         model_name='gpt-4o',
                          timestamp=datetime.datetime(...),
                          kind='response',
                      )
@@ -466,7 +466,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
      @overload
      def run_sync(
          self,
-         user_prompt: str,
+         user_prompt: str | Sequence[_messages.UserContent],
          *,
          message_history: list[_messages.ModelMessage] | None = None,
          model: models.Model | models.KnownModelName | None = None,
@@ -480,7 +480,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
      @overload
      def run_sync(
          self,
-         user_prompt: str,
+         user_prompt: str | Sequence[_messages.UserContent],
          *,
          result_type: type[RunResultDataT] | None,
          message_history: list[_messages.ModelMessage] | None = None,
@@ -494,7 +494,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):

      def run_sync(
          self,
-         user_prompt: str,
+         user_prompt: str | Sequence[_messages.UserContent],
          *,
          result_type: type[RunResultDataT] | None = None,
          message_history: list[_messages.ModelMessage] | None = None,
@@ -555,7 +555,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
      @overload
      def run_stream(
          self,
-         user_prompt: str,
+         user_prompt: str | Sequence[_messages.UserContent],
          *,
          result_type: None = None,
          message_history: list[_messages.ModelMessage] | None = None,
@@ -570,7 +570,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
      @overload
      def run_stream(
          self,
-         user_prompt: str,
+         user_prompt: str | Sequence[_messages.UserContent],
          *,
          result_type: type[RunResultDataT],
          message_history: list[_messages.ModelMessage] | None = None,
@@ -585,7 +585,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
      @asynccontextmanager
      async def run_stream( # noqa C901
          self,
-         user_prompt: str,
+         user_prompt: str | Sequence[_messages.UserContent],
          *,
          result_type: type[RunResultDataT] | None = None,
          message_history: list[_messages.ModelMessage] | None = None,
@@ -1214,7 +1214,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
                  HandleResponseNode(
                      model_response=ModelResponse(
                          parts=[TextPart(content='Paris', part_kind='text')],
-                         model_name='function:model_logic',
+                         model_name='gpt-4o',
                          timestamp=datetime.datetime(...),
                          kind='response',
                      )
@@ -1357,7 +1357,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
                  HandleResponseNode(
                      model_response=ModelResponse(
                          parts=[TextPart(content='Paris', part_kind='text')],
-                         model_name='function:model_logic',
+                         model_name='gpt-4o',
                          timestamp=datetime.datetime(...),
                          kind='response',
                      )
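
Every run entry point (run, run_sync, run_stream, iter) now accepts str | Sequence[_messages.UserContent] as the user prompt. A short sketch of what this enables, assuming a vision-capable model and a local image.png (both hypothetical):

    from pathlib import Path

    from pydantic_ai import Agent, BinaryContent, ImageUrl

    agent = Agent('openai:gpt-4o')  # assumed model name

    # Text, a remote image, and raw bytes can now be mixed in one prompt.
    result = agent.run_sync(
        [
            'What do these two images have in common?',
            ImageUrl(url='https://example.com/photo.jpg'),
            BinaryContent(data=Path('image.png').read_bytes(), media_type='image/png'),
        ]
    )
    print(result.data)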
pydantic_ai/exceptions.py
@@ -1,8 +1,22 @@
  from __future__ import annotations as _annotations

  import json
+ import sys

- __all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'
+ if sys.version_info < (3, 11):
+     from exceptiongroup import ExceptionGroup
+ else:
+     ExceptionGroup = ExceptionGroup
+
+ __all__ = (
+     'ModelRetry',
+     'UserError',
+     'AgentRunError',
+     'UnexpectedModelBehavior',
+     'UsageLimitExceeded',
+     'ModelHTTPError',
+     'FallbackExceptionGroup',
+ )


  class ModelRetry(Exception):
@@ -72,3 +86,30 @@ class UnexpectedModelBehavior(AgentRunError):
              return f'{self.message}, body:\n{self.body}'
          else:
              return self.message
+
+
+ class ModelHTTPError(AgentRunError):
+     """Raised when a model provider response has a status code of 4xx or 5xx."""
+
+     status_code: int
+     """The HTTP status code returned by the API."""
+
+     model_name: str
+     """The name of the model associated with the error."""
+
+     body: object | None
+     """The body of the response, if available."""
+
+     message: str
+     """The error message with the status code and response body, if available."""
+
+     def __init__(self, status_code: int, model_name: str, body: object | None = None):
+         self.status_code = status_code
+         self.model_name = model_name
+         self.body = body
+         message = f'status_code: {status_code}, model_name: {model_name}, body: {body}'
+         super().__init__(message)
+
+
+ class FallbackExceptionGroup(ExceptionGroup):
+     """A group of exceptions that can be raised when all fallback models fail."""
pydantic_ai/messages.py
@@ -1,12 +1,14 @@
  from __future__ import annotations as _annotations

  import uuid
+ from collections.abc import Sequence
  from dataclasses import dataclass, field, replace
  from datetime import datetime
  from typing import Annotated, Any, Literal, Union, cast, overload

  import pydantic
  import pydantic_core
+ from typing_extensions import TypeAlias

  from ._utils import now_utc as _now_utc
  from .exceptions import UnexpectedModelBehavior
@@ -32,6 +34,93 @@ class SystemPromptPart:
      """Part type identifier, this is available on all parts as a discriminator."""


+ @dataclass
+ class AudioUrl:
+     """A URL to an audio file."""
+
+     url: str
+     """The URL of the audio file."""
+
+     kind: Literal['audio-url'] = 'audio-url'
+     """Type identifier, this is available on all parts as a discriminator."""
+
+     @property
+     def media_type(self) -> AudioMediaType:
+         """Return the media type of the audio file, based on the url."""
+         if self.url.endswith('.mp3'):
+             return 'audio/mpeg'
+         elif self.url.endswith('.wav'):
+             return 'audio/wav'
+         else:
+             raise ValueError(f'Unknown audio file extension: {self.url}')
+
+
+ @dataclass
+ class ImageUrl:
+     """A URL to an image."""
+
+     url: str
+     """The URL of the image."""
+
+     kind: Literal['image-url'] = 'image-url'
+     """Type identifier, this is available on all parts as a discriminator."""
+
+     @property
+     def media_type(self) -> ImageMediaType:
+         """Return the media type of the image, based on the url."""
+         if self.url.endswith(('.jpg', '.jpeg')):
+             return 'image/jpeg'
+         elif self.url.endswith('.png'):
+             return 'image/png'
+         elif self.url.endswith('.gif'):
+             return 'image/gif'
+         elif self.url.endswith('.webp'):
+             return 'image/webp'
+         else:
+             raise ValueError(f'Unknown image file extension: {self.url}')
+
+
+ AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
+ ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
+
+
+ @dataclass
+ class BinaryContent:
+     """Binary content, e.g. an audio or image file."""
+
+     data: bytes
+     """The binary data."""
+
+     media_type: AudioMediaType | ImageMediaType | str
+     """The media type of the binary data."""
+
+     kind: Literal['binary'] = 'binary'
+     """Type identifier, this is available on all parts as a discriminator."""
+
+     @property
+     def is_audio(self) -> bool:
+         """Return `True` if the media type is an audio type."""
+         return self.media_type.startswith('audio/')
+
+     @property
+     def is_image(self) -> bool:
+         """Return `True` if the media type is an image type."""
+         return self.media_type.startswith('image/')
+
+     @property
+     def audio_format(self) -> Literal['mp3', 'wav']:
+         """Return the audio format given the media type."""
+         if self.media_type == 'audio/mpeg':
+             return 'mp3'
+         elif self.media_type == 'audio/wav':
+             return 'wav'
+         else:
+             raise ValueError(f'Unknown audio media type: {self.media_type}')
+
+
+ UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | BinaryContent'
+
+
  @dataclass
  class UserPromptPart:
      """A user prompt, generally written by the end user.
@@ -40,7 +129,7 @@ class UserPromptPart:
      [`Agent.run_sync`][pydantic_ai.Agent.run_sync], and [`Agent.run_stream`][pydantic_ai.Agent.run_stream].
      """

-     content: str
+     content: str | Sequence[UserContent]
      """The content of the prompt."""

      timestamp: datetime = field(default_factory=_now_utc)
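
The media_type and audio_format properties above are computed purely from the URL extension or the declared media type, so they can be exercised without any model or network call:

    from pydantic_ai import AudioUrl, BinaryContent, ImageUrl

    print(ImageUrl(url='https://example.com/a.png').media_type)  # image/png
    print(AudioUrl(url='https://example.com/b.mp3').media_type)  # audio/mpeg

    content = BinaryContent(data=b'...', media_type='audio/wav')
    print(content.is_audio, content.audio_format)  # True wav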
pydantic_ai/models/anthropic.py
@@ -1,6 +1,7 @@
  from __future__ import annotations as _annotations

- from collections.abc import AsyncIterable, AsyncIterator
+ import io
+ from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
@@ -10,9 +11,11 @@ from typing import Any, Literal, Union, cast, overload
  from httpx import AsyncClient as AsyncHTTPClient
  from typing_extensions import assert_never

- from .. import UnexpectedModelBehavior, _utils, usage
+ from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
  from .._utils import guard_tool_call_id as _guard_tool_call_id
  from ..messages import (
+     BinaryContent,
+     ImageUrl,
      ModelMessage,
      ModelRequest,
      ModelResponse,
@@ -36,8 +39,9 @@ from . import (
  )

  try:
-     from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
+     from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
      from anthropic.types import (
+         ImageBlockParam,
          Message as AnthropicMessage,
          MessageParam,
          MetadataParam,
@@ -214,21 +218,26 @@ class AnthropicModel(Model):
          if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
              tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls

-         system_prompt, anthropic_messages = self._map_message(messages)
-
-         return await self.client.messages.create(
-             max_tokens=model_settings.get('max_tokens', 1024),
-             system=system_prompt or NOT_GIVEN,
-             messages=anthropic_messages,
-             model=self._model_name,
-             tools=tools or NOT_GIVEN,
-             tool_choice=tool_choice or NOT_GIVEN,
-             stream=stream,
-             temperature=model_settings.get('temperature', NOT_GIVEN),
-             top_p=model_settings.get('top_p', NOT_GIVEN),
-             timeout=model_settings.get('timeout', NOT_GIVEN),
-             metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
-         )
+         system_prompt, anthropic_messages = await self._map_message(messages)
+
+         try:
+             return await self.client.messages.create(
+                 max_tokens=model_settings.get('max_tokens', 1024),
+                 system=system_prompt or NOT_GIVEN,
+                 messages=anthropic_messages,
+                 model=self._model_name,
+                 tools=tools or NOT_GIVEN,
+                 tool_choice=tool_choice or NOT_GIVEN,
+                 stream=stream,
+                 temperature=model_settings.get('temperature', NOT_GIVEN),
+                 top_p=model_settings.get('top_p', NOT_GIVEN),
+                 timeout=model_settings.get('timeout', NOT_GIVEN),
+                 metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
+             )
+         except APIStatusError as e:
+             if (status_code := e.status_code) >= 400:
+                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+             raise

      def _process_response(self, response: AnthropicMessage) -> ModelResponse:
          """Process a non-streamed response, and prepare a message to return."""
@@ -266,19 +275,19 @@ class AnthropicModel(Model):
          tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
          return tools

-     def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
          """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
          system_prompt: str = ''
          anthropic_messages: list[MessageParam] = []
          for m in messages:
              if isinstance(m, ModelRequest):
-                 user_content_params: list[ToolResultBlockParam | TextBlockParam] = []
+                 user_content_params: list[ToolResultBlockParam | TextBlockParam | ImageBlockParam] = []
                  for request_part in m.parts:
                      if isinstance(request_part, SystemPromptPart):
                          system_prompt += request_part.content
                      elif isinstance(request_part, UserPromptPart):
-                         text_block_param = TextBlockParam(type='text', text=request_part.content)
-                         user_content_params.append(text_block_param)
+                         async for content in self._map_user_prompt(request_part):
+                             user_content_params.append(content)
                      elif isinstance(request_part, ToolReturnPart):
                          tool_result_block_param = ToolResultBlockParam(
                              tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
@@ -298,12 +307,7 @@ class AnthropicModel(Model):
                              is_error=True,
                          )
                          user_content_params.append(retry_param)
-                 anthropic_messages.append(
-                     MessageParam(
-                         role='user',
-                         content=user_content_params,
-                     )
-                 )
+                 anthropic_messages.append(MessageParam(role='user', content=user_content_params))
              elif isinstance(m, ModelResponse):
                  assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
                  for response_part in m.parts:
@@ -322,6 +326,32 @@ class AnthropicModel(Model):
                  assert_never(m)
          return system_prompt, anthropic_messages

+     @staticmethod
+     async def _map_user_prompt(part: UserPromptPart) -> AsyncGenerator[ImageBlockParam | TextBlockParam]:
+         if isinstance(part.content, str):
+             yield TextBlockParam(text=part.content, type='text')
+         else:
+             for item in part.content:
+                 if isinstance(item, str):
+                     yield TextBlockParam(text=item, type='text')
+                 elif isinstance(item, BinaryContent):
+                     if item.is_image:
+                         yield ImageBlockParam(
+                             source={'data': io.BytesIO(item.data), 'media_type': item.media_type, 'type': 'base64'},  # type: ignore
+                             type='image',
+                         )
+                     else:
+                         raise RuntimeError('Only images are supported for binary content')
+                 elif isinstance(item, ImageUrl):
+                     response = await cached_async_http_client().get(item.url)
+                     response.raise_for_status()
+                     yield ImageBlockParam(
+                         source={'data': io.BytesIO(response.content), 'media_type': 'image/jpeg', 'type': 'base64'},
+                         type='image',
+                     )
+                 else:
+                     raise RuntimeError(f'Unsupported content type: {type(item)}')
+
      @staticmethod
      def _map_tool_definition(f: ToolDefinition) -> ToolParam:
          return {
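
Two boundaries are visible in the new _map_user_prompt: non-image BinaryContent hits the 'Only images are supported for binary content' branch, and ImageUrl content is downloaded eagerly and re-sent as base64 with a hard-coded image/jpeg media type. A sketch of the resulting behaviour, assuming a vision-capable Claude model name:

    from pydantic_ai import Agent, BinaryContent, ImageUrl

    agent = Agent('anthropic:claude-3-5-sonnet-latest')  # assumed model name

    # Supported: the URL is fetched and forwarded as a base64 ImageBlockParam.
    agent.run_sync(['Describe this image.', ImageUrl(url='https://example.com/cat.jpg')])

    # Not supported: binary audio raises the RuntimeError shown above.
    agent.run_sync(['Transcribe this.', BinaryContent(data=b'...', media_type='audio/mpeg')])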
pydantic_ai/models/cohere.py
@@ -9,7 +9,7 @@ from cohere import TextAssistantMessageContentItem
  from httpx import AsyncClient as AsyncHTTPClient
  from typing_extensions import assert_never

- from .. import result
+ from .. import ModelHTTPError, result
  from .._utils import guard_tool_call_id as _guard_tool_call_id
  from ..messages import (
      ModelMessage,
@@ -45,6 +45,7 @@ try:
          ToolV2Function,
          UserChatMessageV2,
      )
+     from cohere.core.api_error import ApiError
      from cohere.v2.client import OMIT
  except ImportError as _import_error:
      raise ImportError(
@@ -154,17 +155,22 @@ class CohereModel(Model):
      ) -> ChatResponse:
          tools = self._get_tools(model_request_parameters)
          cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
-         return await self.client.chat(
-             model=self._model_name,
-             messages=cohere_messages,
-             tools=tools or OMIT,
-             max_tokens=model_settings.get('max_tokens', OMIT),
-             temperature=model_settings.get('temperature', OMIT),
-             p=model_settings.get('top_p', OMIT),
-             seed=model_settings.get('seed', OMIT),
-             presence_penalty=model_settings.get('presence_penalty', OMIT),
-             frequency_penalty=model_settings.get('frequency_penalty', OMIT),
-         )
+         try:
+             return await self.client.chat(
+                 model=self._model_name,
+                 messages=cohere_messages,
+                 tools=tools or OMIT,
+                 max_tokens=model_settings.get('max_tokens', OMIT),
+                 temperature=model_settings.get('temperature', OMIT),
+                 p=model_settings.get('top_p', OMIT),
+                 seed=model_settings.get('seed', OMIT),
+                 presence_penalty=model_settings.get('presence_penalty', OMIT),
+                 frequency_penalty=model_settings.get('frequency_penalty', OMIT),
+             )
+         except ApiError as e:
+             if (status_code := e.status_code) and status_code >= 400:
+                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+             raise

      def _process_response(self, response: ChatResponse) -> ModelResponse:
          """Process a non-streamed response, and prepare a message to return."""
@@ -242,7 +248,10 @@ class CohereModel(Model):
          if isinstance(part, SystemPromptPart):
              yield SystemChatMessageV2(role='system', content=part.content)
          elif isinstance(part, UserPromptPart):
-             yield UserChatMessageV2(role='user', content=part.content)
+             if isinstance(part.content, str):
+                 yield UserChatMessageV2(role='user', content=part.content)
+             else:
+                 raise RuntimeError('Cohere does not yet support multi-modal inputs.')
          elif isinstance(part, ToolReturnPart):
              yield ToolChatMessageV2(
                  role='tool',