pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. (The original page linked to more details; that link is not preserved in this copy.)

pydantic_ai/format_as_xml.py CHANGED
@@ -37,7 +37,8 @@ def format_as_xml(
37
37
  none_str: String to use for `None` values.
38
38
  indent: Indentation string to use for pretty printing.
39
39
 
40
- Returns: XML representation of the object.
40
+ Returns:
41
+ XML representation of the object.
41
42
 
42
43
  Example:
43
44
  ```python {title="format_as_xml_example.py" lint="skip"}
pydantic_ai/messages.py CHANGED
@@ -1,14 +1,15 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
- from dataclasses import dataclass, field
3
+ from dataclasses import dataclass, field, replace
4
4
  from datetime import datetime
5
- from typing import Annotated, Any, Literal, Union, cast
5
+ from typing import Annotated, Any, Literal, Union, cast, overload
6
6
 
7
7
  import pydantic
8
8
  import pydantic_core
9
9
  from typing_extensions import Self, assert_never
10
10
 
11
11
  from ._utils import now_utc as _now_utc
12
+ from .exceptions import UnexpectedModelBehavior
12
13
 
13
14
 
14
15
  @dataclass
@@ -72,12 +73,14 @@ class ToolReturnPart:
72
73
  """Part type identifier, this is available on all parts as a discriminator."""
73
74
 
74
75
  def model_response_str(self) -> str:
76
+ """Return a string representation of the content for the model."""
75
77
  if isinstance(self.content, str):
76
78
  return self.content
77
79
  else:
78
80
  return tool_return_ta.dump_json(self.content).decode()
79
81
 
80
82
  def model_response_object(self) -> dict[str, Any]:
83
+ """Return a dictionary representation of the content, wrapping non-dict types appropriately."""
81
84
  # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
82
85
  if isinstance(self.content, dict):
83
86
  return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
@@ -124,6 +127,7 @@ class RetryPromptPart:
124
127
  """Part type identifier, this is available on all parts as a discriminator."""
125
128
 
126
129
  def model_response(self) -> str:
130
+ """Return a string message describing why the retry is requested."""
127
131
  if isinstance(self.content, str):
128
132
  description = self.content
129
133
  else:
@@ -159,6 +163,10 @@ class TextPart:
159
163
  part_kind: Literal['text'] = 'text'
160
164
  """Part type identifier, this is available on all parts as a discriminator."""
161
165
 
166
+ def has_content(self) -> bool:
167
+ """Return `True` if the text content is non-empty."""
168
+ return bool(self.content)
169
+
162
170
 
163
171
  @dataclass
164
172
  class ArgsJson:
@@ -197,7 +205,7 @@ class ToolCallPart:
197
205
 
198
206
  @classmethod
199
207
  def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
200
- """Create a `ToolCallPart` from raw arguments."""
208
+ """Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
201
209
  if isinstance(args, str):
202
210
  return cls(tool_name, ArgsJson(args), tool_call_id)
203
211
  elif isinstance(args, dict):
@@ -226,6 +234,7 @@ class ToolCallPart:
226
234
  return pydantic_core.to_json(self.args.args_dict).decode()
227
235
 
228
236
  def has_content(self) -> bool:
237
+ """Return `True` if the arguments contain any data."""
229
238
  if isinstance(self.args, ArgsDict):
230
239
  return any(self.args.args_dict.values())
231
240
  else:
@@ -243,6 +252,9 @@ class ModelResponse:
243
252
  parts: list[ModelResponsePart]
244
253
  """The parts of the model message."""
245
254
 
255
+ model_name: str | None = None
256
+ """The name of the model that generated the response."""
257
+
246
258
  timestamp: datetime = field(default_factory=_now_utc)
247
259
  """The timestamp of the response.
248
260
 
@@ -252,19 +264,209 @@ class ModelResponse:
252
264
  kind: Literal['response'] = 'response'
253
265
  """Message type identifier, this is available on all parts as a discriminator."""
254
266
 
255
- @classmethod
256
- def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
257
- return cls([TextPart(content)], timestamp=timestamp or _now_utc())
258
267
 
259
- @classmethod
260
- def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
261
- return cls([tool_call])
268
+ ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
269
+ """Any message sent to or returned by a model."""
262
270
 
271
+ ModelMessagesTypeAdapter = pydantic.TypeAdapter(list[ModelMessage], config=pydantic.ConfigDict(defer_build=True))
272
+ """Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
263
273
 
264
- ModelMessage = Union[ModelRequest, ModelResponse]
265
- """Any message send to or returned by a model."""
266
274
 
267
- ModelMessagesTypeAdapter = pydantic.TypeAdapter(
268
- list[Annotated[ModelMessage, pydantic.Discriminator('kind')]], config=pydantic.ConfigDict(defer_build=True)
269
- )
270
- """Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
275
+ @dataclass
276
+ class TextPartDelta:
277
+ """A partial update (delta) for a `TextPart` to append new text content."""
278
+
279
+ content_delta: str
280
+ """The incremental text content to add to the existing `TextPart` content."""
281
+
282
+ part_delta_kind: Literal['text'] = 'text'
283
+ """Part delta type identifier, used as a discriminator."""
284
+
285
+ def apply(self, part: ModelResponsePart) -> TextPart:
286
+ """Apply this text delta to an existing `TextPart`.
287
+
288
+ Args:
289
+ part: The existing model response part, which must be a `TextPart`.
290
+
291
+ Returns:
292
+ A new `TextPart` with updated text content.
293
+
294
+ Raises:
295
+ ValueError: If `part` is not a `TextPart`.
296
+ """
297
+ if not isinstance(part, TextPart):
298
+ raise ValueError('Cannot apply TextPartDeltas to non-TextParts')
299
+ return replace(part, content=part.content + self.content_delta)
300
+
301
+
302
+ @dataclass
303
+ class ToolCallPartDelta:
304
+ """A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
305
+
306
+ tool_name_delta: str | None = None
307
+ """Incremental text to add to the existing tool name, if any."""
308
+
309
+ args_delta: str | dict[str, Any] | None = None
310
+ """Incremental data to add to the tool arguments.
311
+
312
+ If this is a string, it will be appended to existing JSON arguments.
313
+ If this is a dict, it will be merged with existing dict arguments.
314
+ """
315
+
316
+ tool_call_id: str | None = None
317
+ """Optional tool call identifier, this is used by some models including OpenAI.
318
+
319
+ Note this is never treated as a delta — it can replace None, but otherwise if a
320
+ non-matching value is provided an error will be raised."""
321
+
322
+ part_delta_kind: Literal['tool_call'] = 'tool_call'
323
+ """Part delta type identifier, used as a discriminator."""
324
+
325
+ def as_part(self) -> ToolCallPart | None:
326
+ """Convert this delta to a fully formed `ToolCallPart` if possible, otherwise return `None`.
327
+
328
+ Returns:
329
+ A `ToolCallPart` if both `tool_name_delta` and `args_delta` are set, otherwise `None`.
330
+ """
331
+ if self.tool_name_delta is None or self.args_delta is None:
332
+ return None
333
+
334
+ return ToolCallPart.from_raw_args(
335
+ self.tool_name_delta,
336
+ self.args_delta,
337
+ self.tool_call_id,
338
+ )
339
+
340
+ @overload
341
+ def apply(self, part: ModelResponsePart) -> ToolCallPart: ...
342
+
343
+ @overload
344
+ def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: ...
345
+
346
+ def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
347
+ """Apply this delta to a part or delta, returning a new part or delta with the changes applied.
348
+
349
+ Args:
350
+ part: The existing model response part or delta to update.
351
+
352
+ Returns:
353
+ Either a new `ToolCallPart` or an updated `ToolCallPartDelta`.
354
+
355
+ Raises:
356
+ ValueError: If `part` is neither a `ToolCallPart` nor a `ToolCallPartDelta`.
357
+ UnexpectedModelBehavior: If applying JSON deltas to dict arguments or vice versa.
358
+ """
359
+ if isinstance(part, ToolCallPart):
360
+ return self._apply_to_part(part)
361
+
362
+ if isinstance(part, ToolCallPartDelta):
363
+ return self._apply_to_delta(part)
364
+
365
+ raise ValueError(f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}')
366
+
367
+ def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
368
+ """Internal helper to apply this delta to another delta."""
369
+ if self.tool_name_delta:
370
+ # Append incremental text to the existing tool_name_delta
371
+ updated_tool_name_delta = (delta.tool_name_delta or '') + self.tool_name_delta
372
+ delta = replace(delta, tool_name_delta=updated_tool_name_delta)
373
+
374
+ if isinstance(self.args_delta, str):
375
+ if isinstance(delta.args_delta, dict):
376
+ raise UnexpectedModelBehavior(
377
+ f'Cannot apply JSON deltas to non-JSON tool arguments ({delta=}, {self=})'
378
+ )
379
+ updated_args_delta = (delta.args_delta or '') + self.args_delta
380
+ delta = replace(delta, args_delta=updated_args_delta)
381
+ elif isinstance(self.args_delta, dict):
382
+ if isinstance(delta.args_delta, str):
383
+ raise UnexpectedModelBehavior(
384
+ f'Cannot apply dict deltas to non-dict tool arguments ({delta=}, {self=})'
385
+ )
386
+ updated_args_delta = {**(delta.args_delta or {}), **self.args_delta}
387
+ delta = replace(delta, args_delta=updated_args_delta)
388
+
389
+ if self.tool_call_id:
390
+ # Set the tool_call_id if it wasn't present, otherwise error if it has changed
391
+ if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
392
+ raise UnexpectedModelBehavior(
393
+ f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
394
+ )
395
+ delta = replace(delta, tool_call_id=self.tool_call_id)
396
+
397
+ # If we now have enough data to create a full ToolCallPart, do so
398
+ if delta.tool_name_delta is not None and delta.args_delta is not None:
399
+ return ToolCallPart.from_raw_args(
400
+ delta.tool_name_delta,
401
+ delta.args_delta,
402
+ delta.tool_call_id,
403
+ )
404
+
405
+ return delta
406
+
407
+ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
408
+ """Internal helper to apply this delta directly to a `ToolCallPart`."""
409
+ if self.tool_name_delta:
410
+ # Append incremental text to the existing tool_name
411
+ tool_name = part.tool_name + self.tool_name_delta
412
+ part = replace(part, tool_name=tool_name)
413
+
414
+ if isinstance(self.args_delta, str):
415
+ if not isinstance(part.args, ArgsJson):
416
+ raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
417
+ updated_json = part.args.args_json + self.args_delta
418
+ part = replace(part, args=ArgsJson(updated_json))
419
+ elif isinstance(self.args_delta, dict):
420
+ if not isinstance(part.args, ArgsDict):
421
+ raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
422
+ updated_dict = {**(part.args.args_dict or {}), **self.args_delta}
423
+ part = replace(part, args=ArgsDict(updated_dict))
424
+
425
+ if self.tool_call_id:
426
+ # Replace the tool_call_id entirely if given
427
+ if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
428
+ raise UnexpectedModelBehavior(
429
+ f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({part=}, {self=})'
430
+ )
431
+ part = replace(part, tool_call_id=self.tool_call_id)
432
+ return part
433
+
434
+
435
+ ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')]
436
+ """A partial update (delta) for any model response part."""
437
+
438
+
439
+ @dataclass
440
+ class PartStartEvent:
441
+ """An event indicating that a new part has started.
442
+
443
+ If multiple `PartStartEvent`s are received with the same index,
444
+ the new one should fully replace the old one.
445
+ """
446
+
447
+ index: int
448
+ """The index of the part within the overall response parts list."""
449
+
450
+ part: ModelResponsePart
451
+ """The newly started `ModelResponsePart`."""
452
+
453
+ event_kind: Literal['part_start'] = 'part_start'
454
+ """Event type identifier, used as a discriminator."""
455
+
456
+
457
+ @dataclass
458
+ class PartDeltaEvent:
459
+ """An event indicating a delta update for an existing part."""
460
+
461
+ index: int
462
+ """The index of the part within the overall response parts list."""
463
+
464
+ delta: ModelResponsePartDelta
465
+ """The delta to apply to the specified part."""
466
+
467
+ event_kind: Literal['part_delta'] = 'part_delta'
468
+ """Event type identifier, used as a discriminator."""
469
+
470
+
471
+ ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
472
+ """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
pydantic_ai/models/__init__.py CHANGED
@@ -7,20 +7,22 @@ specific LLM being used.
7
7
  from __future__ import annotations as _annotations
8
8
 
9
9
  from abc import ABC, abstractmethod
10
- from collections.abc import AsyncIterator, Iterable, Iterator
10
+ from collections.abc import AsyncIterator, Iterator
11
11
  from contextlib import asynccontextmanager, contextmanager
12
+ from dataclasses import dataclass, field
12
13
  from datetime import datetime
13
14
  from functools import cache
14
- from typing import TYPE_CHECKING, Literal, Union
15
+ from typing import TYPE_CHECKING, Literal
15
16
 
16
17
  import httpx
17
18
 
19
+ from .._parts_manager import ModelResponsePartsManager
18
20
  from ..exceptions import UserError
19
- from ..messages import ModelMessage, ModelResponse
21
+ from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
20
22
  from ..settings import ModelSettings
23
+ from ..usage import Usage
21
24
 
22
25
  if TYPE_CHECKING:
23
- from ..result import Usage
24
26
  from ..tools import ToolDefinition
25
27
 
26
28
 
@@ -59,6 +61,7 @@ KnownModelName = Literal[
59
61
  'mistral:codestral-latest',
60
62
  'mistral:mistral-moderation-latest',
61
63
  'ollama:codellama',
64
+ 'ollama:deepseek-r1',
62
65
  'ollama:gemma',
63
66
  'ollama:gemma2',
64
67
  'ollama:llama3',
@@ -70,6 +73,7 @@ KnownModelName = Literal[
70
73
  'ollama:mistral-nemo',
71
74
  'ollama:mixtral',
72
75
  'ollama:phi3',
76
+ 'ollama:phi4',
73
77
  'ollama:qwq',
74
78
  'ollama:qwen',
75
79
  'ollama:qwen2',
@@ -78,6 +82,22 @@ KnownModelName = Literal[
78
82
  'anthropic:claude-3-5-haiku-latest',
79
83
  'anthropic:claude-3-5-sonnet-latest',
80
84
  'anthropic:claude-3-opus-latest',
85
+ 'claude-3-5-haiku-latest',
86
+ 'claude-3-5-sonnet-latest',
87
+ 'claude-3-opus-latest',
88
+ 'cohere:c4ai-aya-expanse-32b',
89
+ 'cohere:c4ai-aya-expanse-8b',
90
+ 'cohere:command',
91
+ 'cohere:command-light',
92
+ 'cohere:command-light-nightly',
93
+ 'cohere:command-nightly',
94
+ 'cohere:command-r',
95
+ 'cohere:command-r-03-2024',
96
+ 'cohere:command-r-08-2024',
97
+ 'cohere:command-r-plus',
98
+ 'cohere:command-r-plus-04-2024',
99
+ 'cohere:command-r-plus-08-2024',
100
+ 'cohere:command-r7b-12-2024',
81
101
  'test',
82
102
  ]
83
103
  """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -129,88 +149,54 @@ class AgentModel(ABC):
129
149
  @asynccontextmanager
130
150
  async def request_stream(
131
151
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
132
- ) -> AsyncIterator[EitherStreamedResponse]:
152
+ ) -> AsyncIterator[StreamedResponse]:
133
153
  """Make a request to the model and return a streaming response."""
154
+ # This method is not required, but you need to implement it if you want to support streamed responses
134
155
  raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
135
156
  # yield is required to make this a generator for type checking
136
157
  # noinspection PyUnreachableCode
137
158
  yield # pragma: no cover
138
159
 
139
160
 
140
- class StreamTextResponse(ABC):
141
- """Streamed response from an LLM when returning text."""
142
-
143
- def __aiter__(self) -> AsyncIterator[None]:
144
- """Stream the response as an async iterable, building up the text as it goes.
145
-
146
- This is an async iterator that yields `None` to avoid doing the work of validating the input and
147
- extracting the text field when it will often be thrown away.
148
- """
149
- return self
150
-
151
- @abstractmethod
152
- async def __anext__(self) -> None:
153
- """Process the next chunk of the response, see above for why this returns `None`."""
154
- raise NotImplementedError()
155
-
156
- @abstractmethod
157
- def get(self, *, final: bool = False) -> Iterable[str]:
158
- """Returns an iterable of text since the last call to `get()` — e.g. the text delta.
159
-
160
- Args:
161
- final: If True, this is the final call, after iteration is complete, the response should be fully validated
162
- and all text extracted.
163
- """
164
- raise NotImplementedError()
161
+ @dataclass
162
+ class StreamedResponse(ABC):
163
+ """Streamed response from an LLM when calling a tool."""
165
164
 
166
- @abstractmethod
167
- def usage(self) -> Usage:
168
- """Return the usage of the request.
165
+ _model_name: str
166
+ _usage: Usage = field(default_factory=Usage, init=False)
167
+ _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
168
+ _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
169
169
 
170
- NOTE: this won't return the full usage until the stream is finished.
171
- """
172
- raise NotImplementedError()
170
+ def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
171
+ """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
172
+ if self._event_iterator is None:
173
+ self._event_iterator = self._get_event_iterator()
174
+ return self._event_iterator
173
175
 
174
176
  @abstractmethod
175
- def timestamp(self) -> datetime:
176
- """Get the timestamp of the response."""
177
- raise NotImplementedError()
177
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
178
+ """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
178
179
 
179
-
180
- class StreamStructuredResponse(ABC):
181
- """Streamed response from an LLM when calling a tool."""
182
-
183
- def __aiter__(self) -> AsyncIterator[None]:
184
- """Stream the response as an async iterable, building up the tool call as it goes.
185
-
186
- This is an async iterator that yields `None` to avoid doing the work of building the final tool call when
187
- it will often be thrown away.
180
+ This method should be implemented by subclasses to translate the vendor-specific stream of events into
181
+ pydantic_ai-format events.
188
182
  """
189
- return self
190
-
191
- @abstractmethod
192
- async def __anext__(self) -> None:
193
- """Process the next chunk of the response, see above for why this returns `None`."""
194
183
  raise NotImplementedError()
184
+ # noinspection PyUnreachableCode
185
+ yield
195
186
 
196
- @abstractmethod
197
- def get(self, *, final: bool = False) -> ModelResponse:
198
- """Get the `ModelResponse` at this point.
199
-
200
- The `ModelResponse` may or may not be complete, depending on whether the stream is finished.
187
+ def get(self) -> ModelResponse:
188
+ """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
189
+ return ModelResponse(
190
+ parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp()
191
+ )
201
192
 
202
- Args:
203
- final: If True, this is the final call, after iteration is complete, the response should be fully validated.
204
- """
205
- raise NotImplementedError()
193
+ def model_name(self) -> str:
194
+ """Get the model name of the response."""
195
+ return self._model_name
206
196
 
207
- @abstractmethod
208
197
  def usage(self) -> Usage:
209
- """Get the usage of the request.
210
-
211
- NOTE: this won't return the full usage until the stream is finished.
212
- """
213
- raise NotImplementedError()
198
+ """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
199
+ return self._usage
214
200
 
215
201
  @abstractmethod
216
202
  def timestamp(self) -> datetime:
@@ -218,9 +204,6 @@ class StreamStructuredResponse(ABC):
218
204
  raise NotImplementedError()
219
205
 
220
206
 
221
- EitherStreamedResponse = Union[StreamTextResponse, StreamStructuredResponse]
222
-
223
-
224
207
  ALLOW_MODEL_REQUESTS = True
225
208
  """Whether to allow requests to models.
226
209
 
@@ -269,6 +252,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
269
252
  from .test import TestModel
270
253
 
271
254
  return TestModel()
255
+ elif model.startswith('cohere:'):
256
+ from .cohere import CohereModel
257
+
258
+ return CohereModel(model[7:])
272
259
  elif model.startswith('openai:'):
273
260
  from .openai import OpenAIModel
274
261