pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

pydantic_ai/messages.py CHANGED
@@ -1,20 +1,18 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
- import json
4
3
  from dataclasses import dataclass, field
5
4
  from datetime import datetime
6
5
  from typing import Annotated, Any, Literal, Union
7
6
 
8
7
  import pydantic
9
8
  import pydantic_core
10
- from pydantic import TypeAdapter
9
+ from typing_extensions import Self
11
10
 
12
- from . import _pydantic
13
11
  from ._utils import now_utc as _now_utc
14
12
 
15
13
 
16
14
  @dataclass
17
- class SystemPrompt:
15
+ class SystemPromptPart:
18
16
  """A system prompt, generally written by the application developer.
19
17
 
20
18
  This gives the model context and guidance on how to respond.
@@ -22,12 +20,13 @@ class SystemPrompt:
22
20
 
23
21
  content: str
24
22
  """The content of the prompt."""
25
- role: Literal['system'] = 'system'
26
- """Message type identifier, this type is available on all message as a discriminator."""
23
+
24
+ part_kind: Literal['system-prompt'] = 'system-prompt'
25
+ """Part type identifier, this is available on all parts as a discriminator."""
27
26
 
28
27
 
29
28
  @dataclass
30
- class UserPrompt:
29
+ class UserPromptPart:
31
30
  """A user prompt, generally written by the end user.
32
31
 
33
32
  Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.Agent.run],
@@ -36,29 +35,35 @@ class UserPrompt:
36
35
 
37
36
  content: str
38
37
  """The content of the prompt."""
38
+
39
39
  timestamp: datetime = field(default_factory=_now_utc)
40
40
  """The timestamp of the prompt."""
41
- role: Literal['user'] = 'user'
42
- """Message type identifier, this type is available on all message as a discriminator."""
41
+
42
+ part_kind: Literal['user-prompt'] = 'user-prompt'
43
+ """Part type identifier, this is available on all parts as a discriminator."""
43
44
 
44
45
 
45
- tool_return_ta: TypeAdapter[Any] = TypeAdapter(Any)
46
+ tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))
46
47
 
47
48
 
48
49
  @dataclass
49
- class ToolReturn:
50
+ class ToolReturnPart:
50
51
  """A tool return message, this encodes the result of running a tool."""
51
52
 
52
53
  tool_name: str
53
54
  """The name of the tool that was called."""
55
+
54
56
  content: Any
55
57
  """The return value."""
56
- tool_id: str | None = None
57
- """Optional tool identifier, this is used by some models including OpenAI."""
58
+
59
+ tool_call_id: str | None = None
60
+ """Optional tool call identifier, this is used by some models including OpenAI."""
61
+
58
62
  timestamp: datetime = field(default_factory=_now_utc)
59
63
  """The timestamp, when the tool returned."""
60
- role: Literal['tool-return'] = 'tool-return'
61
- """Message type identifier, this type is available on all message as a discriminator."""
64
+
65
+ part_kind: Literal['tool-return'] = 'tool-return'
66
+ """Part type identifier, this is available on all parts as a discriminator."""
62
67
 
63
68
  def model_response_str(self) -> str:
64
69
  if isinstance(self.content, str):
@@ -74,8 +79,11 @@ class ToolReturn:
74
79
  return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
75
80
 
76
81
 
82
+ error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
83
+
84
+
77
85
  @dataclass
78
- class RetryPrompt:
86
+ class RetryPromptPart:
79
87
  """A message back to a model asking it to try again.
80
88
 
81
89
  This can be sent for a number of reasons:
@@ -96,36 +104,54 @@ class RetryPrompt:
96
104
  If the retry was triggered by a [`ValidationError`][pydantic_core.ValidationError], this will be a list of
97
105
  error details.
98
106
  """
107
+
99
108
  tool_name: str | None = None
100
109
  """The name of the tool that was called, if any."""
101
- tool_id: str | None = None
102
- """The tool identifier, if any."""
110
+
111
+ tool_call_id: str | None = None
112
+ """Optional tool call identifier, this is used by some models including OpenAI."""
113
+
103
114
  timestamp: datetime = field(default_factory=_now_utc)
104
115
  """The timestamp, when the retry was triggered."""
105
- role: Literal['retry-prompt'] = 'retry-prompt'
106
- """Message type identifier, this type is available on all message as a discriminator."""
116
+
117
+ part_kind: Literal['retry-prompt'] = 'retry-prompt'
118
+ """Part type identifier, this is available on all parts as a discriminator."""
107
119
 
108
120
  def model_response(self) -> str:
109
121
  if isinstance(self.content, str):
110
122
  description = self.content
111
123
  else:
112
- description = f'{len(self.content)} validation errors: {json.dumps(self.content, indent=2)}'
124
+ json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
125
+ description = f'{len(self.content)} validation errors: {json_errors.decode()}'
113
126
  return f'{description}\n\nFix the errors and try again.'
114
127
 
115
128
 
129
+ ModelRequestPart = Annotated[
130
+ Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
131
+ ]
132
+ """A message part sent by PydanticAI to a model."""
133
+
134
+
116
135
  @dataclass
117
- class ModelTextResponse:
136
+ class ModelRequest:
137
+ """A request generated by PydanticAI and sent to a model, e.g. a message from the PydanticAI app to the model."""
138
+
139
+ parts: list[ModelRequestPart]
140
+ """The parts of the user message."""
141
+
142
+ kind: Literal['request'] = 'request'
143
+ """Message type identifier, this is available on all messages as a discriminator."""
144
+
145
+
146
+ @dataclass
147
+ class TextPart:
118
148
  """A plain text response from a model."""
119
149
 
120
150
  content: str
121
151
  """The text content of the response."""
122
- timestamp: datetime = field(default_factory=_now_utc)
123
- """The timestamp of the response.
124
152
 
125
- If the model provides a timestamp in the response (as OpenAI does) that will be used.
126
- """
127
- role: Literal['model-text-response'] = 'model-text-response'
128
- """Message type identifier, this type is available on all message as a discriminator."""
153
+ part_kind: Literal['text'] = 'text'
154
+ """Part type identifier, this is available on all parts as a discriminator."""
129
155
 
130
156
 
131
157
  @dataclass
@@ -145,26 +171,31 @@ class ArgsDict:
145
171
 
146
172
 
147
173
  @dataclass
148
- class ToolCall:
149
- """Either a tool call from the agent."""
174
+ class ToolCallPart:
175
+ """A tool call from a model."""
150
176
 
151
177
  tool_name: str
152
178
  """The name of the tool to call."""
179
+
153
180
  args: ArgsJson | ArgsDict
154
181
  """The arguments to pass to the tool.
155
182
 
156
183
  Either as JSON or a Python dictionary depending on how data was returned.
157
184
  """
158
- tool_id: str | None = None
159
- """Optional tool identifier, this is used by some models including OpenAI."""
185
+
186
+ tool_call_id: str | None = None
187
+ """Optional tool call identifier, this is used by some models including OpenAI."""
188
+
189
+ part_kind: Literal['tool-call'] = 'tool-call'
190
+ """Part type identifier, this is available on all parts as a discriminator."""
160
191
 
161
192
  @classmethod
162
- def from_json(cls, tool_name: str, args_json: str, tool_id: str | None = None) -> ToolCall:
163
- return cls(tool_name, ArgsJson(args_json), tool_id)
193
+ def from_json(cls, tool_name: str, args_json: str, tool_call_id: str | None = None) -> Self:
194
+ return cls(tool_name, ArgsJson(args_json), tool_call_id)
164
195
 
165
196
  @classmethod
166
- def from_dict(cls, tool_name: str, args_dict: dict[str, Any]) -> ToolCall:
167
- return cls(tool_name, ArgsDict(args_dict))
197
+ def from_dict(cls, tool_name: str, args_dict: dict[str, Any], tool_call_id: str | None = None) -> Self:
198
+ return cls(tool_name, ArgsDict(args_dict), tool_call_id)
168
199
 
169
200
  def has_content(self) -> bool:
170
201
  if isinstance(self.args, ArgsDict):
@@ -173,28 +204,39 @@ class ToolCall:
173
204
  return bool(self.args.args_json)
174
205
 
175
206
 
207
+ ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
208
+ """A message part returned by a model."""
209
+
210
+
176
211
  @dataclass
177
- class ModelStructuredResponse:
178
- """A structured response from a model.
212
+ class ModelResponse:
213
+ """A response from a model, e.g. a message from the model to the PydanticAI app."""
179
214
 
180
- This is used either to call a tool or to return a structured response from an agent run.
181
- """
215
+ parts: list[ModelResponsePart]
216
+ """The parts of the model message."""
182
217
 
183
- calls: list[ToolCall]
184
- """The tool calls being made."""
185
218
  timestamp: datetime = field(default_factory=_now_utc)
186
219
  """The timestamp of the response.
187
220
 
188
221
  If the model provides a timestamp in the response (as OpenAI does) that will be used.
189
222
  """
190
- role: Literal['model-structured-response'] = 'model-structured-response'
191
- """Message type identifier, this type is available on all message as a discriminator."""
223
+
224
+ kind: Literal['response'] = 'response'
225
+ """Message type identifier, this is available on all messages as a discriminator."""
226
+
227
+ @classmethod
228
+ def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
229
+ return cls([TextPart(content)], timestamp=timestamp or _now_utc())
230
+
231
+ @classmethod
232
+ def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
233
+ return cls([tool_call])
192
234
 
193
235
 
194
- ModelAnyResponse = Union[ModelTextResponse, ModelStructuredResponse]
195
- """Any response from a model."""
196
- Message = Union[SystemPrompt, UserPrompt, ToolReturn, RetryPrompt, ModelTextResponse, ModelStructuredResponse]
236
+ ModelMessage = Union[ModelRequest, ModelResponse]
197
237
  """Any message sent to or returned by a model."""
198
238
 
199
- MessagesTypeAdapter = _pydantic.LazyTypeAdapter(list[Annotated[Message, pydantic.Field(discriminator='role')]])
239
+ ModelMessagesTypeAdapter = pydantic.TypeAdapter(
240
+ list[Annotated[ModelMessage, pydantic.Discriminator('kind')]], config=pydantic.ConfigDict(defer_build=True)
241
+ )
200
242
  """Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
@@ -7,20 +7,21 @@ specific LLM being used.
7
7
  from __future__ import annotations as _annotations
8
8
 
9
9
  from abc import ABC, abstractmethod
10
- from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence
10
+ from collections.abc import AsyncIterator, Iterable, Iterator
11
11
  from contextlib import asynccontextmanager, contextmanager
12
12
  from datetime import datetime
13
13
  from functools import cache
14
- from typing import TYPE_CHECKING, Literal, Protocol, Union
14
+ from typing import TYPE_CHECKING, Literal, Union
15
15
 
16
16
  import httpx
17
17
 
18
18
  from ..exceptions import UserError
19
- from ..messages import Message, ModelAnyResponse, ModelStructuredResponse
19
+ from ..messages import ModelMessage, ModelResponse
20
+ from ..settings import ModelSettings
20
21
 
21
22
  if TYPE_CHECKING:
22
- from .._utils import ObjectJsonSchema
23
23
  from ..result import Cost
24
+ from ..tools import ToolDefinition
24
25
 
25
26
 
26
27
  KnownModelName = Literal[
@@ -31,6 +32,7 @@ KnownModelName = Literal[
31
32
  'openai:o1-preview',
32
33
  'openai:o1-mini',
33
34
  'openai:gpt-3.5-turbo',
35
+ 'groq:llama-3.3-70b-versatile',
34
36
  'groq:llama-3.1-70b-versatile',
35
37
  'groq:llama3-groq-70b-8192-tool-use-preview',
36
38
  'groq:llama3-groq-8b-8192-tool-use-preview',
@@ -47,8 +49,35 @@ KnownModelName = Literal[
47
49
  'groq:gemma-7b-it',
48
50
  'gemini-1.5-flash',
49
51
  'gemini-1.5-pro',
52
+ 'gemini-2.0-flash-exp',
50
53
  'vertexai:gemini-1.5-flash',
51
54
  'vertexai:gemini-1.5-pro',
55
+ # since mistral models are supported by other providers (e.g. ollama), and some of their models (e.g. "codestral")
56
+ # don't start with "mistral", we add the "mistral:" prefix to all to be explicit
57
+ 'mistral:mistral-small-latest',
58
+ 'mistral:mistral-large-latest',
59
+ 'mistral:codestral-latest',
60
+ 'mistral:mistral-moderation-latest',
61
+ 'ollama:codellama',
62
+ 'ollama:gemma',
63
+ 'ollama:gemma2',
64
+ 'ollama:llama3',
65
+ 'ollama:llama3.1',
66
+ 'ollama:llama3.2',
67
+ 'ollama:llama3.2-vision',
68
+ 'ollama:llama3.3',
69
+ 'ollama:mistral',
70
+ 'ollama:mistral-nemo',
71
+ 'ollama:mixtral',
72
+ 'ollama:phi3',
73
+ 'ollama:qwq',
74
+ 'ollama:qwen',
75
+ 'ollama:qwen2',
76
+ 'ollama:qwen2.5',
77
+ 'ollama:starcoder2',
78
+ 'claude-3-5-haiku-latest',
79
+ 'claude-3-5-sonnet-latest',
80
+ 'claude-3-opus-latest',
52
81
  'test',
53
82
  ]
54
83
  """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -63,11 +92,12 @@ class Model(ABC):
63
92
  @abstractmethod
64
93
  async def agent_model(
65
94
  self,
66
- function_tools: Mapping[str, AbstractToolDefinition],
95
+ *,
96
+ function_tools: list[ToolDefinition],
67
97
  allow_text_result: bool,
68
- result_tools: Sequence[AbstractToolDefinition] | None,
98
+ result_tools: list[ToolDefinition],
69
99
  ) -> AgentModel:
70
- """Create an agent model.
100
+ """Create an agent model, this is called for each step of an agent run.
71
101
 
72
102
  This is async in case slow/async config checks need to be performed that can't be done in `__init__`.
73
103
 
@@ -87,15 +117,19 @@ class Model(ABC):
87
117
 
88
118
 
89
119
  class AgentModel(ABC):
90
- """Model configured for a specific agent."""
120
+ """Model configured for each step of an Agent run."""
91
121
 
92
122
  @abstractmethod
93
- async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
123
+ async def request(
124
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
125
+ ) -> tuple[ModelResponse, Cost]:
94
126
  """Make a request to the model."""
95
127
  raise NotImplementedError()
96
128
 
97
129
  @asynccontextmanager
98
- async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
130
+ async def request_stream(
131
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
132
+ ) -> AsyncIterator[EitherStreamedResponse]:
99
133
  """Make a request to the model and return a streaming response."""
100
134
  raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
101
135
  # yield is required to make this a generator for type checking
@@ -160,10 +194,10 @@ class StreamStructuredResponse(ABC):
160
194
  raise NotImplementedError()
161
195
 
162
196
  @abstractmethod
163
- def get(self, *, final: bool = False) -> ModelStructuredResponse:
164
- """Get the `ModelStructuredResponse` at this point.
197
+ def get(self, *, final: bool = False) -> ModelResponse:
198
+ """Get the `ModelResponse` at this point.
165
199
 
166
- The `ModelStructuredResponse` may or may not be complete, depending on whether the stream is finished.
200
+ The `ModelResponse` may or may not be complete, depending on whether the stream is finished.
167
201
 
168
202
  Args:
169
203
  final: If True, this is the final call, after iteration is complete, the response should be fully validated.
@@ -238,7 +272,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
238
272
  elif model.startswith('openai:'):
239
273
  from .openai import OpenAIModel
240
274
 
241
- return OpenAIModel(model[7:]) # pyright: ignore[reportArgumentType]
275
+ return OpenAIModel(model[7:])
242
276
  elif model.startswith('gemini'):
243
277
  from .gemini import GeminiModel
244
278
 
@@ -252,37 +286,20 @@ def infer_model(model: Model | KnownModelName) -> Model:
252
286
  from .vertexai import VertexAIModel
253
287
 
254
288
  return VertexAIModel(model[9:]) # pyright: ignore[reportArgumentType]
255
- else:
256
- raise UserError(f'Unknown model: {model}')
257
-
258
-
259
- class AbstractToolDefinition(Protocol):
260
- """Abstract definition of a function/tool.
261
-
262
- This is used for both function tools and result tools.
263
- """
289
+ elif model.startswith('mistral:'):
290
+ from .mistral import MistralModel
264
291
 
265
- @property
266
- def name(self) -> str:
267
- """The name of the tool."""
268
- ...
269
-
270
- @property
271
- def description(self) -> str:
272
- """The description of the tool."""
273
- ...
274
-
275
- @property
276
- def json_schema(self) -> ObjectJsonSchema:
277
- """The JSON schema for the tool's arguments."""
278
- ...
292
+ return MistralModel(model[8:])
293
+ elif model.startswith('ollama:'):
294
+ from .ollama import OllamaModel
279
295
 
280
- @property
281
- def outer_typed_dict_key(self) -> str | None:
282
- """The key in the outer [TypedDict] that wraps a result tool.
296
+ return OllamaModel(model[7:])
297
+ elif model.startswith('claude'):
298
+ from .anthropic import AnthropicModel
283
299
 
284
- This will only be set for result tools which don't have an `object` JSON schema.
285
- """
300
+ return AnthropicModel(model)
301
+ else:
302
+ raise UserError(f'Unknown model: {model}')
286
303
 
287
304
 
288
305
  @cache