pydantic-ai-slim 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim has been flagged as potentially problematic; see the registry's advisory for this release for more details.
- pydantic_ai/_griffe.py +23 -4
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +16 -3
- pydantic_ai/_system_prompt.py +1 -0
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +332 -124
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +224 -9
- pydantic_ai/models/__init__.py +59 -82
- pydantic_ai/models/anthropic.py +22 -22
- pydantic_ai/models/function.py +47 -79
- pydantic_ai/models/gemini.py +86 -125
- pydantic_ai/models/groq.py +53 -125
- pydantic_ai/models/mistral.py +75 -137
- pydantic_ai/models/ollama.py +1 -0
- pydantic_ai/models/openai.py +50 -125
- pydantic_ai/models/test.py +40 -73
- pydantic_ai/models/vertexai.py +1 -1
- pydantic_ai/result.py +91 -92
- pydantic_ai/tools.py +24 -5
- {pydantic_ai_slim-0.0.17.dist-info → pydantic_ai_slim-0.0.19.dist-info}/METADATA +3 -1
- pydantic_ai_slim-0.0.19.dist-info/RECORD +29 -0
- pydantic_ai_slim-0.0.17.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.17.dist-info → pydantic_ai_slim-0.0.19.dist-info}/WHEEL +0 -0
pydantic_ai/format_as_xml.py
CHANGED
|
@@ -37,7 +37,8 @@ def format_as_xml(
|
|
|
37
37
|
none_str: String to use for `None` values.
|
|
38
38
|
indent: Indentation string to use for pretty printing.
|
|
39
39
|
|
|
40
|
-
Returns:
|
|
40
|
+
Returns:
|
|
41
|
+
XML representation of the object.
|
|
41
42
|
|
|
42
43
|
Example:
|
|
43
44
|
```python {title="format_as_xml_example.py" lint="skip"}
|
pydantic_ai/messages.py
CHANGED
|
@@ -1,14 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
from dataclasses import dataclass, field
|
|
3
|
+
from dataclasses import dataclass, field, replace
|
|
4
4
|
from datetime import datetime
|
|
5
|
-
from typing import Annotated, Any, Literal, Union, cast
|
|
5
|
+
from typing import Annotated, Any, Literal, Union, cast, overload
|
|
6
6
|
|
|
7
7
|
import pydantic
|
|
8
8
|
import pydantic_core
|
|
9
9
|
from typing_extensions import Self, assert_never
|
|
10
10
|
|
|
11
11
|
from ._utils import now_utc as _now_utc
|
|
12
|
+
from .exceptions import UnexpectedModelBehavior
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
@dataclass
|
|
@@ -21,6 +22,12 @@ class SystemPromptPart:
|
|
|
21
22
|
content: str
|
|
22
23
|
"""The content of the prompt."""
|
|
23
24
|
|
|
25
|
+
dynamic_ref: str | None = None
|
|
26
|
+
"""The ref of the dynamic system prompt function that generated this part.
|
|
27
|
+
|
|
28
|
+
Only set if system prompt is dynamic, see [`system_prompt`][pydantic_ai.Agent.system_prompt] for more information.
|
|
29
|
+
"""
|
|
30
|
+
|
|
24
31
|
part_kind: Literal['system-prompt'] = 'system-prompt'
|
|
25
32
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
26
33
|
|
|
@@ -66,12 +73,14 @@ class ToolReturnPart:
|
|
|
66
73
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
67
74
|
|
|
68
75
|
def model_response_str(self) -> str:
|
|
76
|
+
"""Return a string representation of the content for the model."""
|
|
69
77
|
if isinstance(self.content, str):
|
|
70
78
|
return self.content
|
|
71
79
|
else:
|
|
72
80
|
return tool_return_ta.dump_json(self.content).decode()
|
|
73
81
|
|
|
74
82
|
def model_response_object(self) -> dict[str, Any]:
|
|
83
|
+
"""Return a dictionary representation of the content, wrapping non-dict types appropriately."""
|
|
75
84
|
# gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
|
|
76
85
|
if isinstance(self.content, dict):
|
|
77
86
|
return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
|
|
@@ -118,6 +127,7 @@ class RetryPromptPart:
|
|
|
118
127
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
119
128
|
|
|
120
129
|
def model_response(self) -> str:
|
|
130
|
+
"""Return a string message describing why the retry is requested."""
|
|
121
131
|
if isinstance(self.content, str):
|
|
122
132
|
description = self.content
|
|
123
133
|
else:
|
|
@@ -153,6 +163,10 @@ class TextPart:
|
|
|
153
163
|
part_kind: Literal['text'] = 'text'
|
|
154
164
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
155
165
|
|
|
166
|
+
def has_content(self) -> bool:
|
|
167
|
+
"""Return `True` if the text content is non-empty."""
|
|
168
|
+
return bool(self.content)
|
|
169
|
+
|
|
156
170
|
|
|
157
171
|
@dataclass
|
|
158
172
|
class ArgsJson:
|
|
@@ -191,7 +205,7 @@ class ToolCallPart:
|
|
|
191
205
|
|
|
192
206
|
@classmethod
|
|
193
207
|
def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
|
|
194
|
-
"""Create a `ToolCallPart` from raw arguments
|
|
208
|
+
"""Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
|
|
195
209
|
if isinstance(args, str):
|
|
196
210
|
return cls(tool_name, ArgsJson(args), tool_call_id)
|
|
197
211
|
elif isinstance(args, dict):
|
|
@@ -220,6 +234,7 @@ class ToolCallPart:
|
|
|
220
234
|
return pydantic_core.to_json(self.args.args_dict).decode()
|
|
221
235
|
|
|
222
236
|
def has_content(self) -> bool:
|
|
237
|
+
"""Return `True` if the arguments contain any data."""
|
|
223
238
|
if isinstance(self.args, ArgsDict):
|
|
224
239
|
return any(self.args.args_dict.values())
|
|
225
240
|
else:
|
|
@@ -248,17 +263,217 @@ class ModelResponse:
|
|
|
248
263
|
|
|
249
264
|
@classmethod
|
|
250
265
|
def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
|
|
251
|
-
|
|
266
|
+
"""Create a `ModelResponse` containing a single `TextPart`."""
|
|
267
|
+
return cls([TextPart(content=content)], timestamp=timestamp or _now_utc())
|
|
252
268
|
|
|
253
269
|
@classmethod
|
|
254
270
|
def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
|
|
271
|
+
"""Create a `ModelResponse` containing a single `ToolCallPart`."""
|
|
255
272
|
return cls([tool_call])
|
|
256
273
|
|
|
257
274
|
|
|
258
|
-
ModelMessage = Union[ModelRequest, ModelResponse]
|
|
259
|
-
"""Any message
|
|
275
|
+
ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
|
|
276
|
+
"""Any message sent to or returned by a model."""
|
|
260
277
|
|
|
261
|
-
ModelMessagesTypeAdapter = pydantic.TypeAdapter(
|
|
262
|
-
list[Annotated[ModelMessage, pydantic.Discriminator('kind')]], config=pydantic.ConfigDict(defer_build=True)
|
|
263
|
-
)
|
|
278
|
+
ModelMessagesTypeAdapter = pydantic.TypeAdapter(list[ModelMessage], config=pydantic.ConfigDict(defer_build=True))
|
|
264
279
|
"""Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@dataclass
|
|
283
|
+
class TextPartDelta:
|
|
284
|
+
"""A partial update (delta) for a `TextPart` to append new text content."""
|
|
285
|
+
|
|
286
|
+
content_delta: str
|
|
287
|
+
"""The incremental text content to add to the existing `TextPart` content."""
|
|
288
|
+
|
|
289
|
+
part_delta_kind: Literal['text'] = 'text'
|
|
290
|
+
"""Part delta type identifier, used as a discriminator."""
|
|
291
|
+
|
|
292
|
+
def apply(self, part: ModelResponsePart) -> TextPart:
|
|
293
|
+
"""Apply this text delta to an existing `TextPart`.
|
|
294
|
+
|
|
295
|
+
Args:
|
|
296
|
+
part: The existing model response part, which must be a `TextPart`.
|
|
297
|
+
|
|
298
|
+
Returns:
|
|
299
|
+
A new `TextPart` with updated text content.
|
|
300
|
+
|
|
301
|
+
Raises:
|
|
302
|
+
ValueError: If `part` is not a `TextPart`.
|
|
303
|
+
"""
|
|
304
|
+
if not isinstance(part, TextPart):
|
|
305
|
+
raise ValueError('Cannot apply TextPartDeltas to non-TextParts')
|
|
306
|
+
return replace(part, content=part.content + self.content_delta)
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
@dataclass
|
|
310
|
+
class ToolCallPartDelta:
|
|
311
|
+
"""A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
|
|
312
|
+
|
|
313
|
+
tool_name_delta: str | None = None
|
|
314
|
+
"""Incremental text to add to the existing tool name, if any."""
|
|
315
|
+
|
|
316
|
+
args_delta: str | dict[str, Any] | None = None
|
|
317
|
+
"""Incremental data to add to the tool arguments.
|
|
318
|
+
|
|
319
|
+
If this is a string, it will be appended to existing JSON arguments.
|
|
320
|
+
If this is a dict, it will be merged with existing dict arguments.
|
|
321
|
+
"""
|
|
322
|
+
|
|
323
|
+
tool_call_id: str | None = None
|
|
324
|
+
"""Optional tool call identifier, this is used by some models including OpenAI.
|
|
325
|
+
|
|
326
|
+
Note this is never treated as a delta — it can replace None, but otherwise if a
|
|
327
|
+
non-matching value is provided an error will be raised."""
|
|
328
|
+
|
|
329
|
+
part_delta_kind: Literal['tool_call'] = 'tool_call'
|
|
330
|
+
"""Part delta type identifier, used as a discriminator."""
|
|
331
|
+
|
|
332
|
+
def as_part(self) -> ToolCallPart | None:
|
|
333
|
+
"""Convert this delta to a fully formed `ToolCallPart` if possible, otherwise return `None`.
|
|
334
|
+
|
|
335
|
+
Returns:
|
|
336
|
+
A `ToolCallPart` if both `tool_name_delta` and `args_delta` are set, otherwise `None`.
|
|
337
|
+
"""
|
|
338
|
+
if self.tool_name_delta is None or self.args_delta is None:
|
|
339
|
+
return None
|
|
340
|
+
|
|
341
|
+
return ToolCallPart.from_raw_args(
|
|
342
|
+
self.tool_name_delta,
|
|
343
|
+
self.args_delta,
|
|
344
|
+
self.tool_call_id,
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
@overload
|
|
348
|
+
def apply(self, part: ModelResponsePart) -> ToolCallPart: ...
|
|
349
|
+
|
|
350
|
+
@overload
|
|
351
|
+
def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: ...
|
|
352
|
+
|
|
353
|
+
def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
|
|
354
|
+
"""Apply this delta to a part or delta, returning a new part or delta with the changes applied.
|
|
355
|
+
|
|
356
|
+
Args:
|
|
357
|
+
part: The existing model response part or delta to update.
|
|
358
|
+
|
|
359
|
+
Returns:
|
|
360
|
+
Either a new `ToolCallPart` or an updated `ToolCallPartDelta`.
|
|
361
|
+
|
|
362
|
+
Raises:
|
|
363
|
+
ValueError: If `part` is neither a `ToolCallPart` nor a `ToolCallPartDelta`.
|
|
364
|
+
UnexpectedModelBehavior: If applying JSON deltas to dict arguments or vice versa.
|
|
365
|
+
"""
|
|
366
|
+
if isinstance(part, ToolCallPart):
|
|
367
|
+
return self._apply_to_part(part)
|
|
368
|
+
|
|
369
|
+
if isinstance(part, ToolCallPartDelta):
|
|
370
|
+
return self._apply_to_delta(part)
|
|
371
|
+
|
|
372
|
+
raise ValueError(f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}')
|
|
373
|
+
|
|
374
|
+
def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
|
|
375
|
+
"""Internal helper to apply this delta to another delta."""
|
|
376
|
+
if self.tool_name_delta:
|
|
377
|
+
# Append incremental text to the existing tool_name_delta
|
|
378
|
+
updated_tool_name_delta = (delta.tool_name_delta or '') + self.tool_name_delta
|
|
379
|
+
delta = replace(delta, tool_name_delta=updated_tool_name_delta)
|
|
380
|
+
|
|
381
|
+
if isinstance(self.args_delta, str):
|
|
382
|
+
if isinstance(delta.args_delta, dict):
|
|
383
|
+
raise UnexpectedModelBehavior(
|
|
384
|
+
f'Cannot apply JSON deltas to non-JSON tool arguments ({delta=}, {self=})'
|
|
385
|
+
)
|
|
386
|
+
updated_args_delta = (delta.args_delta or '') + self.args_delta
|
|
387
|
+
delta = replace(delta, args_delta=updated_args_delta)
|
|
388
|
+
elif isinstance(self.args_delta, dict):
|
|
389
|
+
if isinstance(delta.args_delta, str):
|
|
390
|
+
raise UnexpectedModelBehavior(
|
|
391
|
+
f'Cannot apply dict deltas to non-dict tool arguments ({delta=}, {self=})'
|
|
392
|
+
)
|
|
393
|
+
updated_args_delta = {**(delta.args_delta or {}), **self.args_delta}
|
|
394
|
+
delta = replace(delta, args_delta=updated_args_delta)
|
|
395
|
+
|
|
396
|
+
if self.tool_call_id:
|
|
397
|
+
# Set the tool_call_id if it wasn't present, otherwise error if it has changed
|
|
398
|
+
if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
|
|
399
|
+
raise UnexpectedModelBehavior(
|
|
400
|
+
f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
|
|
401
|
+
)
|
|
402
|
+
delta = replace(delta, tool_call_id=self.tool_call_id)
|
|
403
|
+
|
|
404
|
+
# If we now have enough data to create a full ToolCallPart, do so
|
|
405
|
+
if delta.tool_name_delta is not None and delta.args_delta is not None:
|
|
406
|
+
return ToolCallPart.from_raw_args(
|
|
407
|
+
delta.tool_name_delta,
|
|
408
|
+
delta.args_delta,
|
|
409
|
+
delta.tool_call_id,
|
|
410
|
+
)
|
|
411
|
+
|
|
412
|
+
return delta
|
|
413
|
+
|
|
414
|
+
def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
|
|
415
|
+
"""Internal helper to apply this delta directly to a `ToolCallPart`."""
|
|
416
|
+
if self.tool_name_delta:
|
|
417
|
+
# Append incremental text to the existing tool_name
|
|
418
|
+
tool_name = part.tool_name + self.tool_name_delta
|
|
419
|
+
part = replace(part, tool_name=tool_name)
|
|
420
|
+
|
|
421
|
+
if isinstance(self.args_delta, str):
|
|
422
|
+
if not isinstance(part.args, ArgsJson):
|
|
423
|
+
raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
|
|
424
|
+
updated_json = part.args.args_json + self.args_delta
|
|
425
|
+
part = replace(part, args=ArgsJson(updated_json))
|
|
426
|
+
elif isinstance(self.args_delta, dict):
|
|
427
|
+
if not isinstance(part.args, ArgsDict):
|
|
428
|
+
raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
|
|
429
|
+
updated_dict = {**(part.args.args_dict or {}), **self.args_delta}
|
|
430
|
+
part = replace(part, args=ArgsDict(updated_dict))
|
|
431
|
+
|
|
432
|
+
if self.tool_call_id:
|
|
433
|
+
# Replace the tool_call_id entirely if given
|
|
434
|
+
if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
|
|
435
|
+
raise UnexpectedModelBehavior(
|
|
436
|
+
f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({part=}, {self=})'
|
|
437
|
+
)
|
|
438
|
+
part = replace(part, tool_call_id=self.tool_call_id)
|
|
439
|
+
return part
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')]
|
|
443
|
+
"""A partial update (delta) for any model response part."""
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
@dataclass
|
|
447
|
+
class PartStartEvent:
|
|
448
|
+
"""An event indicating that a new part has started.
|
|
449
|
+
|
|
450
|
+
If multiple `PartStartEvent`s are received with the same index,
|
|
451
|
+
the new one should fully replace the old one.
|
|
452
|
+
"""
|
|
453
|
+
|
|
454
|
+
index: int
|
|
455
|
+
"""The index of the part within the overall response parts list."""
|
|
456
|
+
|
|
457
|
+
part: ModelResponsePart
|
|
458
|
+
"""The newly started `ModelResponsePart`."""
|
|
459
|
+
|
|
460
|
+
event_kind: Literal['part_start'] = 'part_start'
|
|
461
|
+
"""Event type identifier, used as a discriminator."""
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
@dataclass
|
|
465
|
+
class PartDeltaEvent:
|
|
466
|
+
"""An event indicating a delta update for an existing part."""
|
|
467
|
+
|
|
468
|
+
index: int
|
|
469
|
+
"""The index of the part within the overall response parts list."""
|
|
470
|
+
|
|
471
|
+
delta: ModelResponsePartDelta
|
|
472
|
+
"""The delta to apply to the specified part."""
|
|
473
|
+
|
|
474
|
+
event_kind: Literal['part_delta'] = 'part_delta'
|
|
475
|
+
"""Event type identifier, used as a discriminator."""
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
|
|
479
|
+
"""An event in the model response stream, either starting a new part or applying a delta to an existing one."""
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -7,20 +7,22 @@ specific LLM being used.
|
|
|
7
7
|
from __future__ import annotations as _annotations
|
|
8
8
|
|
|
9
9
|
from abc import ABC, abstractmethod
|
|
10
|
-
from collections.abc import AsyncIterator,
|
|
10
|
+
from collections.abc import AsyncIterator, Iterator
|
|
11
11
|
from contextlib import asynccontextmanager, contextmanager
|
|
12
|
+
from dataclasses import dataclass, field
|
|
12
13
|
from datetime import datetime
|
|
13
14
|
from functools import cache
|
|
14
|
-
from typing import TYPE_CHECKING, Literal
|
|
15
|
+
from typing import TYPE_CHECKING, Literal
|
|
15
16
|
|
|
16
17
|
import httpx
|
|
17
18
|
|
|
19
|
+
from .._parts_manager import ModelResponsePartsManager
|
|
18
20
|
from ..exceptions import UserError
|
|
19
|
-
from ..messages import ModelMessage, ModelResponse
|
|
21
|
+
from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
|
|
20
22
|
from ..settings import ModelSettings
|
|
23
|
+
from ..usage import Usage
|
|
21
24
|
|
|
22
25
|
if TYPE_CHECKING:
|
|
23
|
-
from ..result import Usage
|
|
24
26
|
from ..tools import ToolDefinition
|
|
25
27
|
|
|
26
28
|
|
|
@@ -48,13 +50,12 @@ KnownModelName = Literal[
|
|
|
48
50
|
'groq:mixtral-8x7b-32768',
|
|
49
51
|
'groq:gemma2-9b-it',
|
|
50
52
|
'groq:gemma-7b-it',
|
|
51
|
-
'gemini-1.5-flash',
|
|
52
|
-
'gemini-1.5-pro',
|
|
53
|
-
'gemini-2.0-flash-exp',
|
|
54
|
-
'
|
|
55
|
-
'
|
|
56
|
-
|
|
57
|
-
# don't start with "mistral", we add the "mistral:" prefix to all to be explicit
|
|
53
|
+
'google-gla:gemini-1.5-flash',
|
|
54
|
+
'google-gla:gemini-1.5-pro',
|
|
55
|
+
'google-gla:gemini-2.0-flash-exp',
|
|
56
|
+
'google-vertex:gemini-1.5-flash',
|
|
57
|
+
'google-vertex:gemini-1.5-pro',
|
|
58
|
+
'google-vertex:gemini-2.0-flash-exp',
|
|
58
59
|
'mistral:mistral-small-latest',
|
|
59
60
|
'mistral:mistral-large-latest',
|
|
60
61
|
'mistral:codestral-latest',
|
|
@@ -71,14 +72,15 @@ KnownModelName = Literal[
|
|
|
71
72
|
'ollama:mistral-nemo',
|
|
72
73
|
'ollama:mixtral',
|
|
73
74
|
'ollama:phi3',
|
|
75
|
+
'ollama:phi4',
|
|
74
76
|
'ollama:qwq',
|
|
75
77
|
'ollama:qwen',
|
|
76
78
|
'ollama:qwen2',
|
|
77
79
|
'ollama:qwen2.5',
|
|
78
80
|
'ollama:starcoder2',
|
|
79
|
-
'claude-3-5-haiku-latest',
|
|
80
|
-
'claude-3-5-sonnet-latest',
|
|
81
|
-
'claude-3-opus-latest',
|
|
81
|
+
'anthropic:claude-3-5-haiku-latest',
|
|
82
|
+
'anthropic:claude-3-5-sonnet-latest',
|
|
83
|
+
'anthropic:claude-3-opus-latest',
|
|
82
84
|
'test',
|
|
83
85
|
]
|
|
84
86
|
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
|
|
@@ -130,88 +132,47 @@ class AgentModel(ABC):
|
|
|
130
132
|
@asynccontextmanager
|
|
131
133
|
async def request_stream(
|
|
132
134
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
133
|
-
) -> AsyncIterator[
|
|
135
|
+
) -> AsyncIterator[StreamedResponse]:
|
|
134
136
|
"""Make a request to the model and return a streaming response."""
|
|
137
|
+
# This method is not required, but you need to implement it if you want to support streamed responses
|
|
135
138
|
raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
|
|
136
139
|
# yield is required to make this a generator for type checking
|
|
137
140
|
# noinspection PyUnreachableCode
|
|
138
141
|
yield # pragma: no cover
|
|
139
142
|
|
|
140
143
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
def __aiter__(self) -> AsyncIterator[None]:
|
|
145
|
-
"""Stream the response as an async iterable, building up the text as it goes.
|
|
146
|
-
|
|
147
|
-
This is an async iterator that yields `None` to avoid doing the work of validating the input and
|
|
148
|
-
extracting the text field when it will often be thrown away.
|
|
149
|
-
"""
|
|
150
|
-
return self
|
|
151
|
-
|
|
152
|
-
@abstractmethod
|
|
153
|
-
async def __anext__(self) -> None:
|
|
154
|
-
"""Process the next chunk of the response, see above for why this returns `None`."""
|
|
155
|
-
raise NotImplementedError()
|
|
156
|
-
|
|
157
|
-
@abstractmethod
|
|
158
|
-
def get(self, *, final: bool = False) -> Iterable[str]:
|
|
159
|
-
"""Returns an iterable of text since the last call to `get()` — e.g. the text delta.
|
|
160
|
-
|
|
161
|
-
Args:
|
|
162
|
-
final: If True, this is the final call, after iteration is complete, the response should be fully validated
|
|
163
|
-
and all text extracted.
|
|
164
|
-
"""
|
|
165
|
-
raise NotImplementedError()
|
|
166
|
-
|
|
167
|
-
@abstractmethod
|
|
168
|
-
def usage(self) -> Usage:
|
|
169
|
-
"""Return the usage of the request.
|
|
170
|
-
|
|
171
|
-
NOTE: this won't return the full usage until the stream is finished.
|
|
172
|
-
"""
|
|
173
|
-
raise NotImplementedError()
|
|
174
|
-
|
|
175
|
-
@abstractmethod
|
|
176
|
-
def timestamp(self) -> datetime:
|
|
177
|
-
"""Get the timestamp of the response."""
|
|
178
|
-
raise NotImplementedError()
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
class StreamStructuredResponse(ABC):
|
|
144
|
+
@dataclass
|
|
145
|
+
class StreamedResponse(ABC):
|
|
182
146
|
"""Streamed response from an LLM when calling a tool."""
|
|
183
147
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
This is an async iterator that yields `None` to avoid doing the work of building the final tool call when
|
|
188
|
-
it will often be thrown away.
|
|
189
|
-
"""
|
|
190
|
-
return self
|
|
148
|
+
_usage: Usage = field(default_factory=Usage, init=False)
|
|
149
|
+
_parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
|
|
150
|
+
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
|
|
191
151
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
152
|
+
def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
153
|
+
"""Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
|
|
154
|
+
if self._event_iterator is None:
|
|
155
|
+
self._event_iterator = self._get_event_iterator()
|
|
156
|
+
return self._event_iterator
|
|
196
157
|
|
|
197
158
|
@abstractmethod
|
|
198
|
-
def
|
|
199
|
-
"""
|
|
200
|
-
|
|
201
|
-
The `ModelResponse` may or may not be complete, depending on whether the stream is finished.
|
|
159
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
160
|
+
"""Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
|
|
202
161
|
|
|
203
|
-
|
|
204
|
-
|
|
162
|
+
This method should be implemented by subclasses to translate the vendor-specific stream of events into
|
|
163
|
+
pydantic_ai-format events.
|
|
205
164
|
"""
|
|
206
165
|
raise NotImplementedError()
|
|
166
|
+
# noinspection PyUnreachableCode
|
|
167
|
+
yield
|
|
207
168
|
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
169
|
+
def get(self) -> ModelResponse:
|
|
170
|
+
"""Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
|
|
171
|
+
return ModelResponse(parts=self._parts_manager.get_parts(), timestamp=self.timestamp())
|
|
211
172
|
|
|
212
|
-
|
|
213
|
-
"""
|
|
214
|
-
|
|
173
|
+
def usage(self) -> Usage:
|
|
174
|
+
"""Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
|
|
175
|
+
return self._usage
|
|
215
176
|
|
|
216
177
|
@abstractmethod
|
|
217
178
|
def timestamp(self) -> datetime:
|
|
@@ -219,9 +180,6 @@ class StreamStructuredResponse(ABC):
|
|
|
219
180
|
raise NotImplementedError()
|
|
220
181
|
|
|
221
182
|
|
|
222
|
-
EitherStreamedResponse = Union[StreamTextResponse, StreamStructuredResponse]
|
|
223
|
-
|
|
224
|
-
|
|
225
183
|
ALLOW_MODEL_REQUESTS = True
|
|
226
184
|
"""Whether to allow requests to models.
|
|
227
185
|
|
|
@@ -274,6 +232,15 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
|
|
274
232
|
from .openai import OpenAIModel
|
|
275
233
|
|
|
276
234
|
return OpenAIModel(model[7:])
|
|
235
|
+
elif model.startswith(('gpt', 'o1')):
|
|
236
|
+
from .openai import OpenAIModel
|
|
237
|
+
|
|
238
|
+
return OpenAIModel(model)
|
|
239
|
+
elif model.startswith('google-gla'):
|
|
240
|
+
from .gemini import GeminiModel
|
|
241
|
+
|
|
242
|
+
return GeminiModel(model[11:]) # pyright: ignore[reportArgumentType]
|
|
243
|
+
# backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
|
|
277
244
|
elif model.startswith('gemini'):
|
|
278
245
|
from .gemini import GeminiModel
|
|
279
246
|
|
|
@@ -283,6 +250,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
|
|
283
250
|
from .groq import GroqModel
|
|
284
251
|
|
|
285
252
|
return GroqModel(model[5:]) # pyright: ignore[reportArgumentType]
|
|
253
|
+
elif model.startswith('google-vertex'):
|
|
254
|
+
from .vertexai import VertexAIModel
|
|
255
|
+
|
|
256
|
+
return VertexAIModel(model[14:]) # pyright: ignore[reportArgumentType]
|
|
257
|
+
# backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
|
|
286
258
|
elif model.startswith('vertexai:'):
|
|
287
259
|
from .vertexai import VertexAIModel
|
|
288
260
|
|
|
@@ -295,6 +267,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
|
|
295
267
|
from .ollama import OllamaModel
|
|
296
268
|
|
|
297
269
|
return OllamaModel(model[7:])
|
|
270
|
+
elif model.startswith('anthropic'):
|
|
271
|
+
from .anthropic import AnthropicModel
|
|
272
|
+
|
|
273
|
+
return AnthropicModel(model[10:])
|
|
274
|
+
# backwards compatibility with old model names (ex, claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
|
|
298
275
|
elif model.startswith('claude'):
|
|
299
276
|
from .anthropic import AnthropicModel
|
|
300
277
|
|
pydantic_ai/models/anthropic.py
CHANGED
|
@@ -8,7 +8,7 @@ from typing import Any, Literal, Union, cast, overload
|
|
|
8
8
|
from httpx import AsyncClient as AsyncHTTPClient
|
|
9
9
|
from typing_extensions import assert_never
|
|
10
10
|
|
|
11
|
-
from .. import
|
|
11
|
+
from .. import usage
|
|
12
12
|
from .._utils import guard_tool_call_id as _guard_tool_call_id
|
|
13
13
|
from ..messages import (
|
|
14
14
|
ArgsDict,
|
|
@@ -27,8 +27,8 @@ from ..settings import ModelSettings
|
|
|
27
27
|
from ..tools import ToolDefinition
|
|
28
28
|
from . import (
|
|
29
29
|
AgentModel,
|
|
30
|
-
EitherStreamedResponse,
|
|
31
30
|
Model,
|
|
31
|
+
StreamedResponse,
|
|
32
32
|
cached_async_http_client,
|
|
33
33
|
check_allow_model_requests,
|
|
34
34
|
)
|
|
@@ -136,7 +136,7 @@ class AnthropicModel(Model):
|
|
|
136
136
|
)
|
|
137
137
|
|
|
138
138
|
def name(self) -> str:
|
|
139
|
-
return self.model_name
|
|
139
|
+
return f'anthropic:{self.model_name}'
|
|
140
140
|
|
|
141
141
|
@staticmethod
|
|
142
142
|
def _map_tool_definition(f: ToolDefinition) -> ToolParam:
|
|
@@ -158,14 +158,14 @@ class AnthropicAgentModel(AgentModel):
|
|
|
158
158
|
|
|
159
159
|
async def request(
|
|
160
160
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
161
|
-
) -> tuple[ModelResponse,
|
|
161
|
+
) -> tuple[ModelResponse, usage.Usage]:
|
|
162
162
|
response = await self._messages_create(messages, False, model_settings)
|
|
163
163
|
return self._process_response(response), _map_usage(response)
|
|
164
164
|
|
|
165
165
|
@asynccontextmanager
|
|
166
166
|
async def request_stream(
|
|
167
167
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
168
|
-
) -> AsyncIterator[
|
|
168
|
+
) -> AsyncIterator[StreamedResponse]:
|
|
169
169
|
response = await self._messages_create(messages, True, model_settings)
|
|
170
170
|
async with response:
|
|
171
171
|
yield await self._process_streamed_response(response)
|
|
@@ -216,28 +216,28 @@ class AnthropicAgentModel(AgentModel):
|
|
|
216
216
|
items: list[ModelResponsePart] = []
|
|
217
217
|
for item in response.content:
|
|
218
218
|
if isinstance(item, TextBlock):
|
|
219
|
-
items.append(TextPart(item.text))
|
|
219
|
+
items.append(TextPart(content=item.text))
|
|
220
220
|
else:
|
|
221
221
|
assert isinstance(item, ToolUseBlock), 'unexpected item type'
|
|
222
222
|
items.append(
|
|
223
223
|
ToolCallPart.from_raw_args(
|
|
224
|
-
item.name,
|
|
225
|
-
cast(dict[str, Any], item.input),
|
|
226
|
-
item.id,
|
|
224
|
+
tool_name=item.name,
|
|
225
|
+
args=cast(dict[str, Any], item.input),
|
|
226
|
+
tool_call_id=item.id,
|
|
227
227
|
)
|
|
228
228
|
)
|
|
229
229
|
|
|
230
230
|
return ModelResponse(items)
|
|
231
231
|
|
|
232
232
|
@staticmethod
|
|
233
|
-
async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) ->
|
|
233
|
+
async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
|
|
234
234
|
"""TODO: Process a streamed response, and prepare a streaming response to return."""
|
|
235
235
|
# We don't yet support streamed responses from Anthropic, so we raise an error here for now.
|
|
236
236
|
# Streamed responses will be supported in a future release.
|
|
237
237
|
|
|
238
238
|
raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
|
|
239
239
|
|
|
240
|
-
# Should be returning some sort of AnthropicStreamTextResponse or
|
|
240
|
+
# Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamedResponse
|
|
241
241
|
# depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
|
|
242
242
|
# RawMessageStartEvent
|
|
243
243
|
# RawMessageDeltaEvent
|
|
@@ -315,30 +315,30 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
|
|
|
315
315
|
)
|
|
316
316
|
|
|
317
317
|
|
|
318
|
-
def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) ->
|
|
318
|
+
def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
|
|
319
319
|
if isinstance(message, AnthropicMessage):
|
|
320
|
-
|
|
320
|
+
response_usage = message.usage
|
|
321
321
|
else:
|
|
322
322
|
if isinstance(message, RawMessageStartEvent):
|
|
323
|
-
|
|
323
|
+
response_usage = message.message.usage
|
|
324
324
|
elif isinstance(message, RawMessageDeltaEvent):
|
|
325
|
-
|
|
325
|
+
response_usage = message.usage
|
|
326
326
|
else:
|
|
327
327
|
# No usage information provided in:
|
|
328
328
|
# - RawMessageStopEvent
|
|
329
329
|
# - RawContentBlockStartEvent
|
|
330
330
|
# - RawContentBlockDeltaEvent
|
|
331
331
|
# - RawContentBlockStopEvent
|
|
332
|
-
|
|
332
|
+
response_usage = None
|
|
333
333
|
|
|
334
|
-
if
|
|
335
|
-
return
|
|
334
|
+
if response_usage is None:
|
|
335
|
+
return usage.Usage()
|
|
336
336
|
|
|
337
|
-
request_tokens = getattr(
|
|
337
|
+
request_tokens = getattr(response_usage, 'input_tokens', None)
|
|
338
338
|
|
|
339
|
-
return
|
|
339
|
+
return usage.Usage(
|
|
340
340
|
# Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
|
|
341
341
|
request_tokens=request_tokens,
|
|
342
|
-
response_tokens=
|
|
343
|
-
total_tokens=(request_tokens or 0) +
|
|
342
|
+
response_tokens=response_usage.output_tokens,
|
|
343
|
+
total_tokens=(request_tokens or 0) + response_usage.output_tokens,
|
|
344
344
|
)
|