pydantic-ai-slim 0.0.26__tar.gz → 0.0.27__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of pydantic-ai-slim might be problematic.

Files changed (34)
  1. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/PKG-INFO +3 -2
  2. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/__init__.py +11 -1
  3. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/agent.py +3 -3
  4. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/exceptions.py +42 -1
  5. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/anthropic.py +20 -15
  6. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/cohere.py +18 -12
  7. pydantic_ai_slim-0.0.27/pydantic_ai/models/fallback.py +116 -0
  8. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/function.py +16 -7
  9. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/gemini.py +30 -14
  10. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/groq.py +24 -19
  11. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/mistral.py +21 -14
  12. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/openai.py +26 -21
  13. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pyproject.toml +4 -5
  14. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/.gitignore +0 -0
  15. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/README.md +0 -0
  16. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_agent_graph.py +0 -0
  17. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_griffe.py +0 -0
  18. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_parts_manager.py +0 -0
  19. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_pydantic.py +0 -0
  20. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_result.py +0 -0
  21. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_system_prompt.py +0 -0
  22. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_utils.py +0 -0
  23. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/format_as_xml.py +0 -0
  24. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/messages.py +0 -0
  25. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/__init__.py +0 -0
  26. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/instrumented.py +0 -0
  27. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/test.py +0 -0
  28. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/vertexai.py +0 -0
  29. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/wrapper.py +0 -0
  30. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/py.typed +0 -0
  31. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/result.py +0 -0
  32. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/settings.py +0 -0
  33. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/tools.py +0 -0
  34. {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/usage.py +0 -0

--- pydantic_ai_slim-0.0.26/PKG-INFO
+++ pydantic_ai_slim-0.0.27/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.26
+Version: 0.0.27
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -25,10 +25,11 @@ Classifier: Topic :: Internet
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.9
 Requires-Dist: eval-type-backport>=0.2.0
+Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: logfire-api>=1.2.0
-Requires-Dist: pydantic-graph==0.0.26
+Requires-Dist: pydantic-graph==0.0.27
 Requires-Dist: pydantic>=2.10
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'

--- pydantic_ai_slim-0.0.26/pydantic_ai/__init__.py
+++ pydantic_ai_slim-0.0.27/pydantic_ai/__init__.py
@@ -1,7 +1,15 @@
 from importlib.metadata import version
 
 from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
-from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
+from .exceptions import (
+    AgentRunError,
+    FallbackExceptionGroup,
+    ModelHTTPError,
+    ModelRetry,
+    UnexpectedModelBehavior,
+    UsageLimitExceeded,
+    UserError,
+)
 from .messages import AudioUrl, BinaryContent, ImageUrl
 from .tools import RunContext, Tool
 
@@ -17,6 +25,8 @@ __all__ = (
     # exceptions
     'AgentRunError',
     'ModelRetry',
+    'ModelHTTPError',
+    'FallbackExceptionGroup',
     'UnexpectedModelBehavior',
     'UsageLimitExceeded',
     'UserError',

--- pydantic_ai_slim-0.0.26/pydantic_ai/agent.py
+++ pydantic_ai_slim-0.0.27/pydantic_ai/agent.py
@@ -365,7 +365,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             HandleResponseNode(
                 model_response=ModelResponse(
                     parts=[TextPart(content='Paris', part_kind='text')],
-                    model_name='function:model_logic',
+                    model_name='gpt-4o',
                     timestamp=datetime.datetime(...),
                     kind='response',
                 )
@@ -1214,7 +1214,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
             HandleResponseNode(
                 model_response=ModelResponse(
                     parts=[TextPart(content='Paris', part_kind='text')],
-                    model_name='function:model_logic',
+                    model_name='gpt-4o',
                     timestamp=datetime.datetime(...),
                     kind='response',
                 )
@@ -1357,7 +1357,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
             HandleResponseNode(
                 model_response=ModelResponse(
                     parts=[TextPart(content='Paris', part_kind='text')],
-                    model_name='function:model_logic',
+                    model_name='gpt-4o',
                     timestamp=datetime.datetime(...),
                     kind='response',
                 )

--- pydantic_ai_slim-0.0.26/pydantic_ai/exceptions.py
+++ pydantic_ai_slim-0.0.27/pydantic_ai/exceptions.py
@@ -1,8 +1,22 @@
 from __future__ import annotations as _annotations
 
 import json
+import sys
 
-__all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'
+if sys.version_info < (3, 11):
+    from exceptiongroup import ExceptionGroup
+else:
+    ExceptionGroup = ExceptionGroup
+
+__all__ = (
+    'ModelRetry',
+    'UserError',
+    'AgentRunError',
+    'UnexpectedModelBehavior',
+    'UsageLimitExceeded',
+    'ModelHTTPError',
+    'FallbackExceptionGroup',
+)
 
 
 class ModelRetry(Exception):
@@ -72,3 +86,30 @@ class UnexpectedModelBehavior(AgentRunError):
             return f'{self.message}, body:\n{self.body}'
         else:
             return self.message
+
+
+class ModelHTTPError(AgentRunError):
+    """Raised when a model provider response has a status code of 4xx or 5xx."""
+
+    status_code: int
+    """The HTTP status code returned by the API."""
+
+    model_name: str
+    """The name of the model associated with the error."""
+
+    body: object | None
+    """The body of the response, if available."""
+
+    message: str
+    """The error message with the status code and response body, if available."""
+
+    def __init__(self, status_code: int, model_name: str, body: object | None = None):
+        self.status_code = status_code
+        self.model_name = model_name
+        self.body = body
+        message = f'status_code: {status_code}, model_name: {model_name}, body: {body}'
+        super().__init__(message)
+
+
+class FallbackExceptionGroup(ExceptionGroup):
+    """A group of exceptions that can be raised when all fallback models fail."""

--- pydantic_ai_slim-0.0.26/pydantic_ai/models/anthropic.py
+++ pydantic_ai_slim-0.0.27/pydantic_ai/models/anthropic.py
@@ -11,7 +11,7 @@ from typing import Any, Literal, Union, cast, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     BinaryContent,
@@ -39,7 +39,7 @@ from . import (
 )
 
 try:
-    from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
+    from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
     from anthropic.types import (
         ImageBlockParam,
         Message as AnthropicMessage,
@@ -220,19 +220,24 @@ class AnthropicModel(Model):
 
         system_prompt, anthropic_messages = await self._map_message(messages)
 
-        return await self.client.messages.create(
-            max_tokens=model_settings.get('max_tokens', 1024),
-            system=system_prompt or NOT_GIVEN,
-            messages=anthropic_messages,
-            model=self._model_name,
-            tools=tools or NOT_GIVEN,
-            tool_choice=tool_choice or NOT_GIVEN,
-            stream=stream,
-            temperature=model_settings.get('temperature', NOT_GIVEN),
-            top_p=model_settings.get('top_p', NOT_GIVEN),
-            timeout=model_settings.get('timeout', NOT_GIVEN),
-            metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
-        )
+        try:
+            return await self.client.messages.create(
+                max_tokens=model_settings.get('max_tokens', 1024),
+                system=system_prompt or NOT_GIVEN,
+                messages=anthropic_messages,
+                model=self._model_name,
+                tools=tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                stream=stream,
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
+            )
+        except APIStatusError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
 
     def _process_response(self, response: AnthropicMessage) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""

--- pydantic_ai_slim-0.0.26/pydantic_ai/models/cohere.py
+++ pydantic_ai_slim-0.0.27/pydantic_ai/models/cohere.py
@@ -9,7 +9,7 @@ from cohere import TextAssistantMessageContentItem
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import result
+from .. import ModelHTTPError, result
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ModelMessage,
@@ -45,6 +45,7 @@ try:
         ToolV2Function,
         UserChatMessageV2,
     )
+    from cohere.core.api_error import ApiError
     from cohere.v2.client import OMIT
 except ImportError as _import_error:
     raise ImportError(
@@ -154,17 +155,22 @@ class CohereModel(Model):
     ) -> ChatResponse:
         tools = self._get_tools(model_request_parameters)
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
-        return await self.client.chat(
-            model=self._model_name,
-            messages=cohere_messages,
-            tools=tools or OMIT,
-            max_tokens=model_settings.get('max_tokens', OMIT),
-            temperature=model_settings.get('temperature', OMIT),
-            p=model_settings.get('top_p', OMIT),
-            seed=model_settings.get('seed', OMIT),
-            presence_penalty=model_settings.get('presence_penalty', OMIT),
-            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
-        )
+        try:
+            return await self.client.chat(
+                model=self._model_name,
+                messages=cohere_messages,
+                tools=tools or OMIT,
+                max_tokens=model_settings.get('max_tokens', OMIT),
+                temperature=model_settings.get('temperature', OMIT),
+                p=model_settings.get('top_p', OMIT),
+                seed=model_settings.get('seed', OMIT),
+                presence_penalty=model_settings.get('presence_penalty', OMIT),
+                frequency_penalty=model_settings.get('frequency_penalty', OMIT),
+            )
+        except ApiError as e:
+            if (status_code := e.status_code) and status_code >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
 
     def _process_response(self, response: ChatResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""

--- /dev/null
+++ pydantic_ai_slim-0.0.27/pydantic_ai/models/fallback.py
@@ -0,0 +1,116 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import AsyncIterator
+from contextlib import AsyncExitStack, asynccontextmanager
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Callable
+
+from ..exceptions import FallbackExceptionGroup, ModelHTTPError
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
+
+if TYPE_CHECKING:
+    from ..messages import ModelMessage, ModelResponse
+    from ..settings import ModelSettings
+    from ..usage import Usage
+
+
+@dataclass(init=False)
+class FallbackModel(Model):
+    """A model that uses one or more fallback models upon failure.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    models: list[Model]
+
+    _model_name: str = field(repr=False)
+    _fallback_on: Callable[[Exception], bool]
+
+    def __init__(
+        self,
+        default_model: Model | KnownModelName,
+        *fallback_models: Model | KnownModelName,
+        fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
+    ):
+        """Initialize a fallback model instance.
+
+        Args:
+            default_model: The name or instance of the default model to use.
+            fallback_models: The names or instances of the fallback models to use upon failure.
+            fallback_on: A callable or tuple of exceptions that should trigger a fallback.
+        """
+        self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]
+        self._model_name = f'FallBackModel[{", ".join(model.model_name for model in self.models)}]'
+
+        if isinstance(fallback_on, tuple):
+            self._fallback_on = _default_fallback_condition_factory(fallback_on)
+        else:
+            self._fallback_on = fallback_on
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        """Try each model in sequence until one succeeds.
+
+        In case of failure, raise a FallbackExceptionGroup with all exceptions.
+        """
+        exceptions: list[Exception] = []
+
+        for model in self.models:
+            try:
+                return await model.request(messages, model_settings, model_request_parameters)
+            except Exception as exc:
+                if self._fallback_on(exc):
+                    exceptions.append(exc)
+                    continue
+                raise exc
+
+        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        """Try each model in sequence until one succeeds."""
+        exceptions: list[Exception] = []
+
+        for model in self.models:
+            async with AsyncExitStack() as stack:
+                try:
+                    response = await stack.enter_async_context(
+                        model.request_stream(messages, model_settings, model_request_parameters)
+                    )
+                except Exception as exc:
+                    if self._fallback_on(exc):
+                        exceptions.append(exc)
+                        continue
+                    raise exc
+                yield response
+                return
+
+        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider, n/a for fallback models."""
+        return None
+
+
+def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
+    """Create a default fallback condition for the given exceptions."""
+
+    def fallback_condition(exception: Exception) -> bool:
+        return isinstance(exception, exceptions)
+
+    return fallback_condition
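
A minimal usage sketch of the new FallbackModel (the concrete model classes and model names are illustrative assumptions; the constructor signature matches the new file above):

from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel

# The default model is tried first; on ModelHTTPError (the default fallback_on
# trigger) each fallback model is tried in turn.
model = FallbackModel(OpenAIModel('gpt-4o'), AnthropicModel('claude-3-5-sonnet-latest'))
agent = Agent(model)
result = agent.run_sync('What is the capital of France?')
print(result.data)

# fallback_on also accepts a predicate, e.g. to fall back on any exception at all:
lenient = FallbackModel('openai:gpt-4o', 'anthropic:claude-3-5-sonnet-latest', fallback_on=lambda exc: True)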

--- pydantic_ai_slim-0.0.26/pydantic_ai/models/function.py
+++ pydantic_ai_slim-0.0.27/pydantic_ai/models/function.py
@@ -48,15 +48,23 @@ class FunctionModel(Model):
     _system: str | None = field(default=None, repr=False)
 
     @overload
-    def __init__(self, function: FunctionDef) -> None: ...
+    def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ...
 
     @overload
-    def __init__(self, *, stream_function: StreamFunctionDef) -> None: ...
+    def __init__(self, *, stream_function: StreamFunctionDef, model_name: str | None = None) -> None: ...
 
     @overload
-    def __init__(self, function: FunctionDef, *, stream_function: StreamFunctionDef) -> None: ...
+    def __init__(
+        self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None
+    ) -> None: ...
 
-    def __init__(self, function: FunctionDef | None = None, *, stream_function: StreamFunctionDef | None = None):
+    def __init__(
+        self,
+        function: FunctionDef | None = None,
+        *,
+        stream_function: StreamFunctionDef | None = None,
+        model_name: str | None = None,
+    ):
         """Initialize a `FunctionModel`.
 
         Either `function` or `stream_function` must be provided, providing both is allowed.
@@ -64,6 +72,7 @@ class FunctionModel(Model):
         Args:
             function: The function to call for non-streamed requests.
             stream_function: The function to call for streamed requests.
+            model_name: The name of the model. If not provided, a name is generated from the function names.
         """
         if function is None and stream_function is None:
             raise TypeError('Either `function` or `stream_function` must be provided')
@@ -72,7 +81,7 @@ class FunctionModel(Model):
 
         function_name = self.function.__name__ if self.function is not None else ''
         stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
-        self._model_name = f'function:{function_name}:{stream_function_name}'
+        self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
 
     async def request(
         self,
@@ -95,7 +104,7 @@ class FunctionModel(Model):
         response_ = await _utils.run_in_executor(self.function, messages, agent_info)
         assert isinstance(response_, ModelResponse), response_
         response = response_
-        response.model_name = f'function:{self.function.__name__}'
+        response.model_name = self._model_name
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_usage(chain(messages, [response]))
 
@@ -123,7 +132,7 @@ class FunctionModel(Model):
         if isinstance(first, _utils.Unset):
             raise ValueError('Stream function must return at least one item')
 
-        yield FunctionStreamedResponse(_model_name=f'function:{self.stream_function.__name__}', _iter=response_stream)
+        yield FunctionStreamedResponse(_model_name=self._model_name, _iter=response_stream)
 
     @property
     def model_name(self) -> str:
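
For context, the new model_name parameter mainly changes the label recorded on responses in message history, which is what the agent.py docstring updates above reflect. A hedged sketch, with an illustrative stub function:

from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel

def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # Illustrative stub: always answer 'Paris'.
    return ModelResponse(parts=[TextPart(content='Paris')])

# Previously responses were labelled 'function:model_logic'; with model_name set,
# message history records the custom name instead.
model = FunctionModel(model_logic, model_name='gpt-4o')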

--- pydantic_ai_slim-0.0.26/pydantic_ai/models/gemini.py
+++ pydantic_ai_slim-0.0.27/pydantic_ai/models/gemini.py
@@ -15,7 +15,7 @@ import pydantic
 from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, assert_never
 
-from .. import UnexpectedModelBehavior, _utils, exceptions, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -112,7 +112,7 @@ class GeminiModel(Model):
         if env_api_key := os.getenv('GEMINI_API_KEY'):
             api_key = env_api_key
         else:
-            raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
+            raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
         self.http_client = http_client or cached_async_http_client()
         self._auth = ApiKeyAuth(api_key)
         self._url = url_template.format(model=model_name)
@@ -233,9 +233,11 @@ class GeminiModel(Model):
             headers=headers,
             timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
         ) as r:
-            if r.status_code != 200:
+            if (status_code := r.status_code) != 200:
                 await r.aread()
-                raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
+                if status_code >= 400:
+                    raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text)
+                raise UnexpectedModelBehavior(f'Unexpected response from gemini {status_code}', r.text)
             yield r
 
     def _process_response(self, response: _GeminiResponse) -> ModelResponse:
@@ -318,10 +320,14 @@ class GeminiModel(Model):
                 content.append({'text': item})
             elif isinstance(item, BinaryContent):
                 base64_encoded = base64.b64encode(item.data).decode('utf-8')
-                content.append(_GeminiInlineDataPart(data=base64_encoded, mime_type=item.media_type))
+                content.append(
+                    _GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type})
+                )
             elif isinstance(item, (AudioUrl, ImageUrl)):
                 try:
-                    content.append(_GeminiFileDataData(file_uri=item.url, mime_type=item.media_type))
+                    content.append(
+                        _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
+                    )
                 except ValueError:
                     # Download the file if can't find the mime type.
                     client = cached_async_http_client()
@@ -329,7 +335,9 @@ class GeminiModel(Model):
                     response.raise_for_status()
                     base64_encoded = base64.b64encode(response.content).decode('utf-8')
                     content.append(
-                        _GeminiInlineDataPart(data=base64_encoded, mime_type=response.headers['Content-Type'])
+                        _GeminiInlineDataPart(
+                            inline_data={'data': base64_encoded, 'mime_type': response.headers['Content-Type']}
+                        )
                     )
             else:
                 assert_never(item)
@@ -526,20 +534,28 @@ class _GeminiTextPart(TypedDict):
     text: str
 
 
+class _GeminiInlineData(TypedDict):
+    data: str
+    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+
+
 class _GeminiInlineDataPart(TypedDict):
     """See <https://ai.google.dev/api/caching#Blob>."""
 
-    data: str
-    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+    inline_data: Annotated[_GeminiInlineData, pydantic.Field(alias='inlineData')]
 
 
-class _GeminiFileDataData(TypedDict):
+class _GeminiFileData(TypedDict):
     """See <https://ai.google.dev/api/caching#FileData>."""
 
     file_uri: Annotated[str, pydantic.Field(alias='fileUri')]
     mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
 
 
+class _GeminiFileDataPart(TypedDict):
+    file_data: Annotated[_GeminiFileData, pydantic.Field(alias='fileData')]
+
+
 class _GeminiFunctionCallPart(TypedDict):
     function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
 
@@ -563,7 +579,7 @@ def _process_response_from_parts(
                 )
             )
         elif 'function_response' in part:
-            raise exceptions.UnexpectedModelBehavior(
+            raise UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
     return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())
@@ -615,7 +631,7 @@ _GeminiPartUnion = Annotated[
         Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
         Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
         Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
-        Annotated[_GeminiFileDataData, pydantic.Tag('file_data')],
+        Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')],
     ],
     pydantic.Discriminator(_part_discriminator),
 ]
@@ -778,7 +794,7 @@ class _GeminiJsonSchema:
             # noinspection PyTypeChecker
             key = re.sub(r'^#/\$defs/', '', ref)
             if key in refs_stack:
-                raise exceptions.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
+                raise UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
             refs_stack += (key,)
             schema_def = self.defs[key]
             self._simplify(schema_def, refs_stack)
@@ -812,7 +828,7 @@ class _GeminiJsonSchema:
     def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         ad_props = schema.pop('additionalProperties', None)
         if ad_props:
-            raise exceptions.UserError('Additional properties in JSON Schema are not supported by Gemini')
+            raise UserError('Additional properties in JSON Schema are not supported by Gemini')
 
         if properties := schema.get('properties'):  # pragma: no branch
             for value in properties.values():
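
The TypedDict restructuring above changes the shape of the JSON the library sends for Gemini media parts. Roughly, as plain Python dicts (the field aliases come from the diff above; the surrounding request envelope is not shown here and is an assumption):

# 0.0.26 shape: blob fields sat directly on the part.
old_inline_part = {'data': '<base64>', 'mimeType': 'image/png'}

# 0.0.27 shape: the blob / file reference is nested under inlineData / fileData,
# matching the Gemini API's Blob and FileData objects linked in the docstrings.
new_inline_part = {'inlineData': {'data': '<base64>', 'mimeType': 'image/png'}}
new_file_part = {'fileData': {'fileUri': 'https://example.com/audio.mp3', 'mimeType': 'audio/mpeg'}}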

--- pydantic_ai_slim-0.0.26/pydantic_ai/models/groq.py
+++ pydantic_ai_slim-0.0.27/pydantic_ai/models/groq.py
@@ -11,7 +11,7 @@ from typing import Literal, Union, cast, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     BinaryContent,
@@ -39,7 +39,7 @@ from . import (
 )
 
 try:
-    from groq import NOT_GIVEN, AsyncGroq, AsyncStream
+    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
@@ -197,23 +197,28 @@ class GroqModel(Model):
 
         groq_messages = list(chain(*(self._map_message(m) for m in messages)))
 
-        return await self.client.chat.completions.create(
-            model=str(self._model_name),
-            messages=groq_messages,
-            n=1,
-            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=tools or NOT_GIVEN,
-            tool_choice=tool_choice or NOT_GIVEN,
-            stream=stream,
-            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
-            temperature=model_settings.get('temperature', NOT_GIVEN),
-            top_p=model_settings.get('top_p', NOT_GIVEN),
-            timeout=model_settings.get('timeout', NOT_GIVEN),
-            seed=model_settings.get('seed', NOT_GIVEN),
-            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
-            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
-            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
-        )
+        try:
+            return await self.client.chat.completions.create(
+                model=str(self._model_name),
+                messages=groq_messages,
+                n=1,
+                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
+                tools=tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                stream=stream,
+                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                seed=model_settings.get('seed', NOT_GIVEN),
+                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+            )
+        except APIStatusError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""

--- pydantic_ai_slim-0.0.26/pydantic_ai/models/mistral.py
+++ pydantic_ai_slim-0.0.27/pydantic_ai/models/mistral.py
@@ -13,7 +13,7 @@ import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
 from .._utils import now_utc as _now_utc
 from ..messages import (
     BinaryContent,
@@ -59,6 +59,7 @@ try:
         ChatCompletionResponse as MistralChatCompletionResponse,
         CompletionEvent as MistralCompletionEvent,
         Messages as MistralMessages,
+        SDKError,
         Tool as MistralTool,
         ToolCall as MistralToolCall,
     )
@@ -184,19 +185,25 @@ class MistralModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
-        response = await self.client.chat.complete_async(
-            model=str(self._model_name),
-            messages=list(chain(*(self._map_message(m) for m in messages))),
-            n=1,
-            tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
-            tool_choice=self._get_tool_choice(model_request_parameters),
-            stream=False,
-            max_tokens=model_settings.get('max_tokens', UNSET),
-            temperature=model_settings.get('temperature', UNSET),
-            top_p=model_settings.get('top_p', 1),
-            timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
-            random_seed=model_settings.get('seed', UNSET),
-        )
+        try:
+            response = await self.client.chat.complete_async(
+                model=str(self._model_name),
+                messages=list(chain(*(self._map_message(m) for m in messages))),
+                n=1,
+                tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+                tool_choice=self._get_tool_choice(model_request_parameters),
+                stream=False,
+                max_tokens=model_settings.get('max_tokens', UNSET),
+                temperature=model_settings.get('temperature', UNSET),
+                top_p=model_settings.get('top_p', 1),
+                timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+                random_seed=model_settings.get('seed', UNSET),
+            )
+        except SDKError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
+
         assert response, 'A unexpected empty response from Mistral.'
         return response
 

--- pydantic_ai_slim-0.0.26/pydantic_ai/models/openai.py
+++ pydantic_ai_slim-0.0.27/pydantic_ai/models/openai.py
@@ -11,7 +11,7 @@ from typing import Literal, Union, cast, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     AudioUrl,
@@ -40,7 +40,7 @@ from . import (
 )
 
 try:
-    from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
+    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
    from openai.types import ChatModel, chat
    from openai.types.chat import (
        ChatCompletionChunk,
@@ -224,25 +224,30 @@ class OpenAIModel(Model):
             async for msg in self._map_message(m):
                 openai_messages.append(msg)
 
-        return await self.client.chat.completions.create(
-            model=self._model_name,
-            messages=openai_messages,
-            n=1,
-            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=tools or NOT_GIVEN,
-            tool_choice=tool_choice or NOT_GIVEN,
-            stream=stream,
-            stream_options={'include_usage': True} if stream else NOT_GIVEN,
-            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
-            temperature=model_settings.get('temperature', NOT_GIVEN),
-            top_p=model_settings.get('top_p', NOT_GIVEN),
-            timeout=model_settings.get('timeout', NOT_GIVEN),
-            seed=model_settings.get('seed', NOT_GIVEN),
-            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
-            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
-            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
-            reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
-        )
+        try:
+            return await self.client.chat.completions.create(
+                model=self._model_name,
+                messages=openai_messages,
+                n=1,
+                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
+                tools=tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                stream=stream,
+                stream_options={'include_usage': True} if stream else NOT_GIVEN,
+                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                seed=model_settings.get('seed', NOT_GIVEN),
+                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+                reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
+            )
+        except APIStatusError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""

--- pydantic_ai_slim-0.0.26/pyproject.toml
+++ pydantic_ai_slim-0.0.27/pyproject.toml
@@ -4,11 +4,9 @@ build-backend = "hatchling.build"
 
 [project]
 name = "pydantic-ai-slim"
-version = "0.0.26"
+version = "0.0.27"
 description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
-authors = [
-    { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
-]
+authors = [{ name = "Samuel Colvin", email = "samuel@pydantic.dev" }]
 license = "MIT"
 readme = "README.md"
 classifiers = [
@@ -39,7 +37,8 @@ dependencies = [
     "httpx>=0.27",
     "logfire-api>=1.2.0",
     "pydantic>=2.10",
-    "pydantic-graph==0.0.26",
+    "pydantic-graph==0.0.27",
+    "exceptiongroup; python_version < '3.11'",
 ]
 
 [project.optional-dependencies]