pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic (the original page linked to further details about this release).

pydantic_ai/exceptions.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import json
4
4
 
5
- __all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
5
+ __all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'
6
6
 
7
7
 
8
8
  class ModelRetry(Exception):
@@ -30,7 +30,25 @@ class UserError(RuntimeError):
30
30
  super().__init__(message)
31
31
 
32
32
 
33
- class UnexpectedModelBehavior(RuntimeError):
33
+ class AgentRunError(RuntimeError):
34
+ """Base class for errors occurring during an agent run."""
35
+
36
+ message: str
37
+ """The error message."""
38
+
39
+ def __init__(self, message: str):
40
+ self.message = message
41
+ super().__init__(message)
42
+
43
+ def __str__(self) -> str:
44
+ return self.message
45
+
46
+
47
+ class UsageLimitExceeded(AgentRunError):
48
+ """Error raised when a Model's usage exceeds the specified limits."""
49
+
50
+
51
+ class UnexpectedModelBehavior(AgentRunError):
34
52
  """Error caused by unexpected Model behavior, e.g. an unexpected response code."""
35
53
 
36
54
  message: str
pydantic_ai/messages.py CHANGED
@@ -2,18 +2,17 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  from dataclasses import dataclass, field
4
4
  from datetime import datetime
5
- from typing import Annotated, Any, Literal, Union
5
+ from typing import Annotated, Any, Literal, Union, cast
6
6
 
7
7
  import pydantic
8
8
  import pydantic_core
9
- from pydantic import TypeAdapter
9
+ from typing_extensions import Self, assert_never
10
10
 
11
- from . import _pydantic
12
11
  from ._utils import now_utc as _now_utc
13
12
 
14
13
 
15
14
  @dataclass
16
- class SystemPrompt:
15
+ class SystemPromptPart:
17
16
  """A system prompt, generally written by the application developer.
18
17
 
19
18
  This gives the model context and guidance on how to respond.
@@ -21,12 +20,13 @@ class SystemPrompt:
21
20
 
22
21
  content: str
23
22
  """The content of the prompt."""
24
- role: Literal['system'] = 'system'
25
- """Message type identifier, this type is available on all message as a discriminator."""
23
+
24
+ part_kind: Literal['system-prompt'] = 'system-prompt'
25
+ """Part type identifier, this is available on all parts as a discriminator."""
26
26
 
27
27
 
28
28
  @dataclass
29
- class UserPrompt:
29
+ class UserPromptPart:
30
30
  """A user prompt, generally written by the end user.
31
31
 
32
32
  Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.Agent.run],
@@ -35,29 +35,35 @@ class UserPrompt:
35
35
 
36
36
  content: str
37
37
  """The content of the prompt."""
38
+
38
39
  timestamp: datetime = field(default_factory=_now_utc)
39
40
  """The timestamp of the prompt."""
40
- role: Literal['user'] = 'user'
41
- """Message type identifier, this type is available on all message as a discriminator."""
41
+
42
+ part_kind: Literal['user-prompt'] = 'user-prompt'
43
+ """Part type identifier, this is available on all parts as a discriminator."""
42
44
 
43
45
 
44
- tool_return_ta: TypeAdapter[Any] = TypeAdapter(Any)
46
+ tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))
45
47
 
46
48
 
47
49
  @dataclass
48
- class ToolReturn:
50
+ class ToolReturnPart:
49
51
  """A tool return message, this encodes the result of running a tool."""
50
52
 
51
53
  tool_name: str
52
54
  """The name of the "tool" was called."""
55
+
53
56
  content: Any
54
57
  """The return value."""
55
- tool_id: str | None = None
56
- """Optional tool identifier, this is used by some models including OpenAI."""
58
+
59
+ tool_call_id: str | None = None
60
+ """Optional tool call identifier, this is used by some models including OpenAI."""
61
+
57
62
  timestamp: datetime = field(default_factory=_now_utc)
58
63
  """The timestamp, when the tool returned."""
59
- role: Literal['tool-return'] = 'tool-return'
60
- """Message type identifier, this type is available on all message as a discriminator."""
64
+
65
+ part_kind: Literal['tool-return'] = 'tool-return'
66
+ """Part type identifier, this is available on all parts as a discriminator."""
61
67
 
62
68
  def model_response_str(self) -> str:
63
69
  if isinstance(self.content, str):
@@ -73,11 +79,11 @@ class ToolReturn:
73
79
  return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
74
80
 
75
81
 
76
- ErrorDetailsTa = _pydantic.LazyTypeAdapter(list[pydantic_core.ErrorDetails])
82
+ error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
77
83
 
78
84
 
79
85
  @dataclass
80
- class RetryPrompt:
86
+ class RetryPromptPart:
81
87
  """A message back to a model asking it to try again.
82
88
 
83
89
  This can be sent for a number of reasons:
@@ -98,37 +104,54 @@ class RetryPrompt:
98
104
  If the retry was triggered by a [`ValidationError`][pydantic_core.ValidationError], this will be a list of
99
105
  error details.
100
106
  """
107
+
101
108
  tool_name: str | None = None
102
109
  """The name of the tool that was called, if any."""
103
- tool_id: str | None = None
104
- """The tool identifier, if any."""
110
+
111
+ tool_call_id: str | None = None
112
+ """Optional tool call identifier, this is used by some models including OpenAI."""
113
+
105
114
  timestamp: datetime = field(default_factory=_now_utc)
106
115
  """The timestamp, when the retry was triggered."""
107
- role: Literal['retry-prompt'] = 'retry-prompt'
108
- """Message type identifier, this type is available on all message as a discriminator."""
116
+
117
+ part_kind: Literal['retry-prompt'] = 'retry-prompt'
118
+ """Part type identifier, this is available on all parts as a discriminator."""
109
119
 
110
120
  def model_response(self) -> str:
111
121
  if isinstance(self.content, str):
112
122
  description = self.content
113
123
  else:
114
- json_errors = ErrorDetailsTa.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
124
+ json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
115
125
  description = f'{len(self.content)} validation errors: {json_errors.decode()}'
116
126
  return f'{description}\n\nFix the errors and try again.'
117
127
 
118
128
 
129
+ ModelRequestPart = Annotated[
130
+ Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
131
+ ]
132
+ """A message part sent by PydanticAI to a model."""
133
+
134
+
135
+ @dataclass
136
+ class ModelRequest:
137
+ """A request generated by PydanticAI and sent to a model, e.g. a message from the PydanticAI app to the model."""
138
+
139
+ parts: list[ModelRequestPart]
140
+ """The parts of the user message."""
141
+
142
+ kind: Literal['request'] = 'request'
143
+ """Message type identifier, this is available on all parts as a discriminator."""
144
+
145
+
119
146
  @dataclass
120
- class ModelTextResponse:
147
+ class TextPart:
121
148
  """A plain text response from a model."""
122
149
 
123
150
  content: str
124
151
  """The text content of the response."""
125
- timestamp: datetime = field(default_factory=_now_utc)
126
- """The timestamp of the response.
127
152
 
128
- If the model provides a timestamp in the response (as OpenAI does) that will be used.
129
- """
130
- role: Literal['model-text-response'] = 'model-text-response'
131
- """Message type identifier, this type is available on all message as a discriminator."""
153
+ part_kind: Literal['text'] = 'text'
154
+ """Part type identifier, this is available on all parts as a discriminator."""
132
155
 
133
156
 
134
157
  @dataclass
@@ -148,26 +171,53 @@ class ArgsDict:
148
171
 
149
172
 
150
173
  @dataclass
151
- class ToolCall:
152
- """Either a tool call from the agent."""
174
+ class ToolCallPart:
175
+ """A tool call from a model."""
153
176
 
154
177
  tool_name: str
155
178
  """The name of the tool to call."""
179
+
156
180
  args: ArgsJson | ArgsDict
157
181
  """The arguments to pass to the tool.
158
182
 
159
183
  Either as JSON or a Python dictionary depending on how data was returned.
160
184
  """
161
- tool_id: str | None = None
162
- """Optional tool identifier, this is used by some models including OpenAI."""
163
185
 
164
- @classmethod
165
- def from_json(cls, tool_name: str, args_json: str, tool_id: str | None = None) -> ToolCall:
166
- return cls(tool_name, ArgsJson(args_json), tool_id)
186
+ tool_call_id: str | None = None
187
+ """Optional tool call identifier, this is used by some models including OpenAI."""
188
+
189
+ part_kind: Literal['tool-call'] = 'tool-call'
190
+ """Part type identifier, this is available on all parts as a discriminator."""
167
191
 
168
192
  @classmethod
169
- def from_dict(cls, tool_name: str, args_dict: dict[str, Any]) -> ToolCall:
170
- return cls(tool_name, ArgsDict(args_dict))
193
+ def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
194
+ """Create a `ToolCallPart` from raw arguments."""
195
+ if isinstance(args, str):
196
+ return cls(tool_name, ArgsJson(args), tool_call_id)
197
+ elif isinstance(args, dict):
198
+ return cls(tool_name, ArgsDict(args), tool_call_id)
199
+ else:
200
+ assert_never(args)
201
+
202
+ def args_as_dict(self) -> dict[str, Any]:
203
+ """Return the arguments as a Python dictionary.
204
+
205
+ This is just for convenience with models that require dicts as input.
206
+ """
207
+ if isinstance(self.args, ArgsDict):
208
+ return self.args.args_dict
209
+ args = pydantic_core.from_json(self.args.args_json)
210
+ assert isinstance(args, dict), 'args should be a dict'
211
+ return cast(dict[str, Any], args)
212
+
213
+ def args_as_json_str(self) -> str:
214
+ """Return the arguments as a JSON string.
215
+
216
+ This is just for convenience with models that require JSON strings as input.
217
+ """
218
+ if isinstance(self.args, ArgsJson):
219
+ return self.args.args_json
220
+ return pydantic_core.to_json(self.args.args_dict).decode()
171
221
 
172
222
  def has_content(self) -> bool:
173
223
  if isinstance(self.args, ArgsDict):
@@ -176,28 +226,39 @@ class ToolCall:
176
226
  return bool(self.args.args_json)
177
227
 
178
228
 
229
+ ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
230
+ """A message part returned by a model."""
231
+
232
+
179
233
  @dataclass
180
- class ModelStructuredResponse:
181
- """A structured response from a model.
234
+ class ModelResponse:
235
+ """A response from a model, e.g. a message from the model to the PydanticAI app."""
182
236
 
183
- This is used either to call a tool or to return a structured response from an agent run.
184
- """
237
+ parts: list[ModelResponsePart]
238
+ """The parts of the model message."""
185
239
 
186
- calls: list[ToolCall]
187
- """The tool calls being made."""
188
240
  timestamp: datetime = field(default_factory=_now_utc)
189
241
  """The timestamp of the response.
190
242
 
191
243
  If the model provides a timestamp in the response (as OpenAI does) that will be used.
192
244
  """
193
- role: Literal['model-structured-response'] = 'model-structured-response'
194
- """Message type identifier, this type is available on all message as a discriminator."""
245
+
246
+ kind: Literal['response'] = 'response'
247
+ """Message type identifier, this is available on all parts as a discriminator."""
248
+
249
+ @classmethod
250
+ def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
251
+ return cls([TextPart(content)], timestamp=timestamp or _now_utc())
252
+
253
+ @classmethod
254
+ def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
255
+ return cls([tool_call])
195
256
 
196
257
 
197
- ModelAnyResponse = Union[ModelTextResponse, ModelStructuredResponse]
198
- """Any response from a model."""
199
- Message = Union[SystemPrompt, UserPrompt, ToolReturn, RetryPrompt, ModelTextResponse, ModelStructuredResponse]
258
+ ModelMessage = Union[ModelRequest, ModelResponse]
200
259
  """Any message send to or returned by a model."""
201
260
 
202
- MessagesTypeAdapter = _pydantic.LazyTypeAdapter(list[Annotated[Message, pydantic.Field(discriminator='role')]])
261
+ ModelMessagesTypeAdapter = pydantic.TypeAdapter(
262
+ list[Annotated[ModelMessage, pydantic.Discriminator('kind')]], config=pydantic.ConfigDict(defer_build=True)
263
+ )
203
264
  """Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
@@ -16,10 +16,11 @@ from typing import TYPE_CHECKING, Literal, Union
16
16
  import httpx
17
17
 
18
18
  from ..exceptions import UserError
19
- from ..messages import Message, ModelAnyResponse, ModelStructuredResponse
19
+ from ..messages import ModelMessage, ModelResponse
20
+ from ..settings import ModelSettings
20
21
 
21
22
  if TYPE_CHECKING:
22
- from ..result import Cost
23
+ from ..result import Usage
23
24
  from ..tools import ToolDefinition
24
25
 
25
26
 
@@ -30,7 +31,9 @@ KnownModelName = Literal[
30
31
  'openai:gpt-4',
31
32
  'openai:o1-preview',
32
33
  'openai:o1-mini',
34
+ 'openai:o1',
33
35
  'openai:gpt-3.5-turbo',
36
+ 'groq:llama-3.3-70b-versatile',
34
37
  'groq:llama-3.1-70b-versatile',
35
38
  'groq:llama3-groq-70b-8192-tool-use-preview',
36
39
  'groq:llama3-groq-8b-8192-tool-use-preview',
@@ -47,8 +50,15 @@ KnownModelName = Literal[
47
50
  'groq:gemma-7b-it',
48
51
  'gemini-1.5-flash',
49
52
  'gemini-1.5-pro',
53
+ 'gemini-2.0-flash-exp',
50
54
  'vertexai:gemini-1.5-flash',
51
55
  'vertexai:gemini-1.5-pro',
56
+ # since mistral models are supported by other providers (e.g. ollama), and some of their models (e.g. "codestral")
57
+ # don't start with "mistral", we add the "mistral:" prefix to all to be explicit
58
+ 'mistral:mistral-small-latest',
59
+ 'mistral:mistral-large-latest',
60
+ 'mistral:codestral-latest',
61
+ 'mistral:mistral-moderation-latest',
52
62
  'ollama:codellama',
53
63
  'ollama:gemma',
54
64
  'ollama:gemma2',
@@ -66,6 +76,9 @@ KnownModelName = Literal[
66
76
  'ollama:qwen2',
67
77
  'ollama:qwen2.5',
68
78
  'ollama:starcoder2',
79
+ 'claude-3-5-haiku-latest',
80
+ 'claude-3-5-sonnet-latest',
81
+ 'claude-3-opus-latest',
69
82
  'test',
70
83
  ]
71
84
  """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -108,12 +121,16 @@ class AgentModel(ABC):
108
121
  """Model configured for each step of an Agent run."""
109
122
 
110
123
  @abstractmethod
111
- async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
124
+ async def request(
125
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
126
+ ) -> tuple[ModelResponse, Usage]:
112
127
  """Make a request to the model."""
113
128
  raise NotImplementedError()
114
129
 
115
130
  @asynccontextmanager
116
- async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
131
+ async def request_stream(
132
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
133
+ ) -> AsyncIterator[EitherStreamedResponse]:
117
134
  """Make a request to the model and return a streaming response."""
118
135
  raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
119
136
  # yield is required to make this a generator for type checking
@@ -148,10 +165,10 @@ class StreamTextResponse(ABC):
148
165
  raise NotImplementedError()
149
166
 
150
167
  @abstractmethod
151
- def cost(self) -> Cost:
152
- """Return the cost of the request.
168
+ def usage(self) -> Usage:
169
+ """Return the usage of the request.
153
170
 
154
- NOTE: this won't return the ful cost until the stream is finished.
171
+ NOTE: this won't return the full usage until the stream is finished.
155
172
  """
156
173
  raise NotImplementedError()
157
174
 
@@ -178,10 +195,10 @@ class StreamStructuredResponse(ABC):
178
195
  raise NotImplementedError()
179
196
 
180
197
  @abstractmethod
181
- def get(self, *, final: bool = False) -> ModelStructuredResponse:
182
- """Get the `ModelStructuredResponse` at this point.
198
+ def get(self, *, final: bool = False) -> ModelResponse:
199
+ """Get the `ModelResponse` at this point.
183
200
 
184
- The `ModelStructuredResponse` may or may not be complete, depending on whether the stream is finished.
201
+ The `ModelResponse` may or may not be complete, depending on whether the stream is finished.
185
202
 
186
203
  Args:
187
204
  final: If True, this is the final call, after iteration is complete, the response should be fully validated.
@@ -189,10 +206,10 @@ class StreamStructuredResponse(ABC):
189
206
  raise NotImplementedError()
190
207
 
191
208
  @abstractmethod
192
- def cost(self) -> Cost:
193
- """Get the cost of the request.
209
+ def usage(self) -> Usage:
210
+ """Get the usage of the request.
194
211
 
195
- NOTE: this won't return the full cost until the stream is finished.
212
+ NOTE: this won't return the full usage until the stream is finished.
196
213
  """
197
214
  raise NotImplementedError()
198
215
 
@@ -219,7 +236,7 @@ The testing models [`TestModel`][pydantic_ai.models.test.TestModel] and
219
236
  def check_allow_model_requests() -> None:
220
237
  """Check if model requests are allowed.
221
238
 
222
- If you're defining your own models that have cost or latency associated with their use, you should call this in
239
+ If you're defining your own models that have costs or latency associated with their use, you should call this in
223
240
  [`Model.agent_model`][pydantic_ai.models.Model.agent_model].
224
241
 
225
242
  Raises:
@@ -270,10 +287,18 @@ def infer_model(model: Model | KnownModelName) -> Model:
270
287
  from .vertexai import VertexAIModel
271
288
 
272
289
  return VertexAIModel(model[9:]) # pyright: ignore[reportArgumentType]
290
+ elif model.startswith('mistral:'):
291
+ from .mistral import MistralModel
292
+
293
+ return MistralModel(model[8:])
273
294
  elif model.startswith('ollama:'):
274
295
  from .ollama import OllamaModel
275
296
 
276
297
  return OllamaModel(model[7:])
298
+ elif model.startswith('claude'):
299
+ from .anthropic import AnthropicModel
300
+
301
+ return AnthropicModel(model)
277
302
  else:
278
303
  raise UserError(f'Unknown model: {model}')
279
304