pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_griffe.py +10 -3
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +17 -3
- pydantic_ai/_result.py +26 -21
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +187 -159
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +217 -15
- pydantic_ai/models/__init__.py +58 -71
- pydantic_ai/models/anthropic.py +112 -48
- pydantic_ai/models/cohere.py +278 -0
- pydantic_ai/models/function.py +57 -85
- pydantic_ai/models/gemini.py +83 -129
- pydantic_ai/models/groq.py +60 -130
- pydantic_ai/models/mistral.py +86 -142
- pydantic_ai/models/ollama.py +4 -0
- pydantic_ai/models/openai.py +75 -136
- pydantic_ai/models/test.py +55 -80
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +132 -114
- pydantic_ai/settings.py +18 -1
- pydantic_ai/tools.py +42 -23
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/METADATA +7 -3
- pydantic_ai_slim-0.0.20.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.18.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/WHEEL +0 -0
pydantic_ai/format_as_xml.py
CHANGED
|
@@ -37,7 +37,8 @@ def format_as_xml(
|
|
|
37
37
|
none_str: String to use for `None` values.
|
|
38
38
|
indent: Indentation string to use for pretty printing.
|
|
39
39
|
|
|
40
|
-
Returns:
|
|
40
|
+
Returns:
|
|
41
|
+
XML representation of the object.
|
|
41
42
|
|
|
42
43
|
Example:
|
|
43
44
|
```python {title="format_as_xml_example.py" lint="skip"}
|
pydantic_ai/messages.py
CHANGED
|
@@ -1,14 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
from dataclasses import dataclass, field
|
|
3
|
+
from dataclasses import dataclass, field, replace
|
|
4
4
|
from datetime import datetime
|
|
5
|
-
from typing import Annotated, Any, Literal, Union, cast
|
|
5
|
+
from typing import Annotated, Any, Literal, Union, cast, overload
|
|
6
6
|
|
|
7
7
|
import pydantic
|
|
8
8
|
import pydantic_core
|
|
9
9
|
from typing_extensions import Self, assert_never
|
|
10
10
|
|
|
11
11
|
from ._utils import now_utc as _now_utc
|
|
12
|
+
from .exceptions import UnexpectedModelBehavior
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
@dataclass
|
|
@@ -72,12 +73,14 @@ class ToolReturnPart:
|
|
|
72
73
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
73
74
|
|
|
74
75
|
def model_response_str(self) -> str:
|
|
76
|
+
"""Return a string representation of the content for the model."""
|
|
75
77
|
if isinstance(self.content, str):
|
|
76
78
|
return self.content
|
|
77
79
|
else:
|
|
78
80
|
return tool_return_ta.dump_json(self.content).decode()
|
|
79
81
|
|
|
80
82
|
def model_response_object(self) -> dict[str, Any]:
|
|
83
|
+
"""Return a dictionary representation of the content, wrapping non-dict types appropriately."""
|
|
81
84
|
# gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
|
|
82
85
|
if isinstance(self.content, dict):
|
|
83
86
|
return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
|
|
@@ -124,6 +127,7 @@ class RetryPromptPart:
|
|
|
124
127
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
125
128
|
|
|
126
129
|
def model_response(self) -> str:
|
|
130
|
+
"""Return a string message describing why the retry is requested."""
|
|
127
131
|
if isinstance(self.content, str):
|
|
128
132
|
description = self.content
|
|
129
133
|
else:
|
|
@@ -159,6 +163,10 @@ class TextPart:
|
|
|
159
163
|
part_kind: Literal['text'] = 'text'
|
|
160
164
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
161
165
|
|
|
166
|
+
def has_content(self) -> bool:
|
|
167
|
+
"""Return `True` if the text content is non-empty."""
|
|
168
|
+
return bool(self.content)
|
|
169
|
+
|
|
162
170
|
|
|
163
171
|
@dataclass
|
|
164
172
|
class ArgsJson:
|
|
@@ -197,7 +205,7 @@ class ToolCallPart:
|
|
|
197
205
|
|
|
198
206
|
@classmethod
|
|
199
207
|
def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
|
|
200
|
-
"""Create a `ToolCallPart` from raw arguments
|
|
208
|
+
"""Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
|
|
201
209
|
if isinstance(args, str):
|
|
202
210
|
return cls(tool_name, ArgsJson(args), tool_call_id)
|
|
203
211
|
elif isinstance(args, dict):
|
|
@@ -226,6 +234,7 @@ class ToolCallPart:
|
|
|
226
234
|
return pydantic_core.to_json(self.args.args_dict).decode()
|
|
227
235
|
|
|
228
236
|
def has_content(self) -> bool:
|
|
237
|
+
"""Return `True` if the arguments contain any data."""
|
|
229
238
|
if isinstance(self.args, ArgsDict):
|
|
230
239
|
return any(self.args.args_dict.values())
|
|
231
240
|
else:
|
|
@@ -243,6 +252,9 @@ class ModelResponse:
|
|
|
243
252
|
parts: list[ModelResponsePart]
|
|
244
253
|
"""The parts of the model message."""
|
|
245
254
|
|
|
255
|
+
model_name: str | None = None
|
|
256
|
+
"""The name of the model that generated the response."""
|
|
257
|
+
|
|
246
258
|
timestamp: datetime = field(default_factory=_now_utc)
|
|
247
259
|
"""The timestamp of the response.
|
|
248
260
|
|
|
@@ -252,19 +264,209 @@ class ModelResponse:
|
|
|
252
264
|
kind: Literal['response'] = 'response'
|
|
253
265
|
"""Message type identifier, this is available on all parts as a discriminator."""
|
|
254
266
|
|
|
255
|
-
@classmethod
|
|
256
|
-
def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
|
|
257
|
-
return cls([TextPart(content)], timestamp=timestamp or _now_utc())
|
|
258
267
|
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
return cls([tool_call])
|
|
268
|
+
ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
|
|
269
|
+
"""Any message sent to or returned by a model."""
|
|
262
270
|
|
|
271
|
+
ModelMessagesTypeAdapter = pydantic.TypeAdapter(list[ModelMessage], config=pydantic.ConfigDict(defer_build=True))
|
|
272
|
+
"""Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
|
|
263
273
|
|
|
264
|
-
ModelMessage = Union[ModelRequest, ModelResponse]
|
|
265
|
-
"""Any message send to or returned by a model."""
|
|
266
274
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
)
|
|
270
|
-
|
|
275
|
+
@dataclass
|
|
276
|
+
class TextPartDelta:
|
|
277
|
+
"""A partial update (delta) for a `TextPart` to append new text content."""
|
|
278
|
+
|
|
279
|
+
content_delta: str
|
|
280
|
+
"""The incremental text content to add to the existing `TextPart` content."""
|
|
281
|
+
|
|
282
|
+
part_delta_kind: Literal['text'] = 'text'
|
|
283
|
+
"""Part delta type identifier, used as a discriminator."""
|
|
284
|
+
|
|
285
|
+
def apply(self, part: ModelResponsePart) -> TextPart:
|
|
286
|
+
"""Apply this text delta to an existing `TextPart`.
|
|
287
|
+
|
|
288
|
+
Args:
|
|
289
|
+
part: The existing model response part, which must be a `TextPart`.
|
|
290
|
+
|
|
291
|
+
Returns:
|
|
292
|
+
A new `TextPart` with updated text content.
|
|
293
|
+
|
|
294
|
+
Raises:
|
|
295
|
+
ValueError: If `part` is not a `TextPart`.
|
|
296
|
+
"""
|
|
297
|
+
if not isinstance(part, TextPart):
|
|
298
|
+
raise ValueError('Cannot apply TextPartDeltas to non-TextParts')
|
|
299
|
+
return replace(part, content=part.content + self.content_delta)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
@dataclass
|
|
303
|
+
class ToolCallPartDelta:
|
|
304
|
+
"""A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
|
|
305
|
+
|
|
306
|
+
tool_name_delta: str | None = None
|
|
307
|
+
"""Incremental text to add to the existing tool name, if any."""
|
|
308
|
+
|
|
309
|
+
args_delta: str | dict[str, Any] | None = None
|
|
310
|
+
"""Incremental data to add to the tool arguments.
|
|
311
|
+
|
|
312
|
+
If this is a string, it will be appended to existing JSON arguments.
|
|
313
|
+
If this is a dict, it will be merged with existing dict arguments.
|
|
314
|
+
"""
|
|
315
|
+
|
|
316
|
+
tool_call_id: str | None = None
|
|
317
|
+
"""Optional tool call identifier, this is used by some models including OpenAI.
|
|
318
|
+
|
|
319
|
+
Note this is never treated as a delta — it can replace None, but otherwise if a
|
|
320
|
+
non-matching value is provided an error will be raised."""
|
|
321
|
+
|
|
322
|
+
part_delta_kind: Literal['tool_call'] = 'tool_call'
|
|
323
|
+
"""Part delta type identifier, used as a discriminator."""
|
|
324
|
+
|
|
325
|
+
def as_part(self) -> ToolCallPart | None:
|
|
326
|
+
"""Convert this delta to a fully formed `ToolCallPart` if possible, otherwise return `None`.
|
|
327
|
+
|
|
328
|
+
Returns:
|
|
329
|
+
A `ToolCallPart` if both `tool_name_delta` and `args_delta` are set, otherwise `None`.
|
|
330
|
+
"""
|
|
331
|
+
if self.tool_name_delta is None or self.args_delta is None:
|
|
332
|
+
return None
|
|
333
|
+
|
|
334
|
+
return ToolCallPart.from_raw_args(
|
|
335
|
+
self.tool_name_delta,
|
|
336
|
+
self.args_delta,
|
|
337
|
+
self.tool_call_id,
|
|
338
|
+
)
|
|
339
|
+
|
|
340
|
+
@overload
|
|
341
|
+
def apply(self, part: ModelResponsePart) -> ToolCallPart: ...
|
|
342
|
+
|
|
343
|
+
@overload
|
|
344
|
+
def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: ...
|
|
345
|
+
|
|
346
|
+
def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
|
|
347
|
+
"""Apply this delta to a part or delta, returning a new part or delta with the changes applied.
|
|
348
|
+
|
|
349
|
+
Args:
|
|
350
|
+
part: The existing model response part or delta to update.
|
|
351
|
+
|
|
352
|
+
Returns:
|
|
353
|
+
Either a new `ToolCallPart` or an updated `ToolCallPartDelta`.
|
|
354
|
+
|
|
355
|
+
Raises:
|
|
356
|
+
ValueError: If `part` is neither a `ToolCallPart` nor a `ToolCallPartDelta`.
|
|
357
|
+
UnexpectedModelBehavior: If applying JSON deltas to dict arguments or vice versa.
|
|
358
|
+
"""
|
|
359
|
+
if isinstance(part, ToolCallPart):
|
|
360
|
+
return self._apply_to_part(part)
|
|
361
|
+
|
|
362
|
+
if isinstance(part, ToolCallPartDelta):
|
|
363
|
+
return self._apply_to_delta(part)
|
|
364
|
+
|
|
365
|
+
raise ValueError(f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}')
|
|
366
|
+
|
|
367
|
+
def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
|
|
368
|
+
"""Internal helper to apply this delta to another delta."""
|
|
369
|
+
if self.tool_name_delta:
|
|
370
|
+
# Append incremental text to the existing tool_name_delta
|
|
371
|
+
updated_tool_name_delta = (delta.tool_name_delta or '') + self.tool_name_delta
|
|
372
|
+
delta = replace(delta, tool_name_delta=updated_tool_name_delta)
|
|
373
|
+
|
|
374
|
+
if isinstance(self.args_delta, str):
|
|
375
|
+
if isinstance(delta.args_delta, dict):
|
|
376
|
+
raise UnexpectedModelBehavior(
|
|
377
|
+
f'Cannot apply JSON deltas to non-JSON tool arguments ({delta=}, {self=})'
|
|
378
|
+
)
|
|
379
|
+
updated_args_delta = (delta.args_delta or '') + self.args_delta
|
|
380
|
+
delta = replace(delta, args_delta=updated_args_delta)
|
|
381
|
+
elif isinstance(self.args_delta, dict):
|
|
382
|
+
if isinstance(delta.args_delta, str):
|
|
383
|
+
raise UnexpectedModelBehavior(
|
|
384
|
+
f'Cannot apply dict deltas to non-dict tool arguments ({delta=}, {self=})'
|
|
385
|
+
)
|
|
386
|
+
updated_args_delta = {**(delta.args_delta or {}), **self.args_delta}
|
|
387
|
+
delta = replace(delta, args_delta=updated_args_delta)
|
|
388
|
+
|
|
389
|
+
if self.tool_call_id:
|
|
390
|
+
# Set the tool_call_id if it wasn't present, otherwise error if it has changed
|
|
391
|
+
if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
|
|
392
|
+
raise UnexpectedModelBehavior(
|
|
393
|
+
f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
|
|
394
|
+
)
|
|
395
|
+
delta = replace(delta, tool_call_id=self.tool_call_id)
|
|
396
|
+
|
|
397
|
+
# If we now have enough data to create a full ToolCallPart, do so
|
|
398
|
+
if delta.tool_name_delta is not None and delta.args_delta is not None:
|
|
399
|
+
return ToolCallPart.from_raw_args(
|
|
400
|
+
delta.tool_name_delta,
|
|
401
|
+
delta.args_delta,
|
|
402
|
+
delta.tool_call_id,
|
|
403
|
+
)
|
|
404
|
+
|
|
405
|
+
return delta
|
|
406
|
+
|
|
407
|
+
def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
|
|
408
|
+
"""Internal helper to apply this delta directly to a `ToolCallPart`."""
|
|
409
|
+
if self.tool_name_delta:
|
|
410
|
+
# Append incremental text to the existing tool_name
|
|
411
|
+
tool_name = part.tool_name + self.tool_name_delta
|
|
412
|
+
part = replace(part, tool_name=tool_name)
|
|
413
|
+
|
|
414
|
+
if isinstance(self.args_delta, str):
|
|
415
|
+
if not isinstance(part.args, ArgsJson):
|
|
416
|
+
raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
|
|
417
|
+
updated_json = part.args.args_json + self.args_delta
|
|
418
|
+
part = replace(part, args=ArgsJson(updated_json))
|
|
419
|
+
elif isinstance(self.args_delta, dict):
|
|
420
|
+
if not isinstance(part.args, ArgsDict):
|
|
421
|
+
raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
|
|
422
|
+
updated_dict = {**(part.args.args_dict or {}), **self.args_delta}
|
|
423
|
+
part = replace(part, args=ArgsDict(updated_dict))
|
|
424
|
+
|
|
425
|
+
if self.tool_call_id:
|
|
426
|
+
# Replace the tool_call_id entirely if given
|
|
427
|
+
if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
|
|
428
|
+
raise UnexpectedModelBehavior(
|
|
429
|
+
f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({part=}, {self=})'
|
|
430
|
+
)
|
|
431
|
+
part = replace(part, tool_call_id=self.tool_call_id)
|
|
432
|
+
return part
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')]
|
|
436
|
+
"""A partial update (delta) for any model response part."""
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
@dataclass
|
|
440
|
+
class PartStartEvent:
|
|
441
|
+
"""An event indicating that a new part has started.
|
|
442
|
+
|
|
443
|
+
If multiple `PartStartEvent`s are received with the same index,
|
|
444
|
+
the new one should fully replace the old one.
|
|
445
|
+
"""
|
|
446
|
+
|
|
447
|
+
index: int
|
|
448
|
+
"""The index of the part within the overall response parts list."""
|
|
449
|
+
|
|
450
|
+
part: ModelResponsePart
|
|
451
|
+
"""The newly started `ModelResponsePart`."""
|
|
452
|
+
|
|
453
|
+
event_kind: Literal['part_start'] = 'part_start'
|
|
454
|
+
"""Event type identifier, used as a discriminator."""
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
@dataclass
|
|
458
|
+
class PartDeltaEvent:
|
|
459
|
+
"""An event indicating a delta update for an existing part."""
|
|
460
|
+
|
|
461
|
+
index: int
|
|
462
|
+
"""The index of the part within the overall response parts list."""
|
|
463
|
+
|
|
464
|
+
delta: ModelResponsePartDelta
|
|
465
|
+
"""The delta to apply to the specified part."""
|
|
466
|
+
|
|
467
|
+
event_kind: Literal['part_delta'] = 'part_delta'
|
|
468
|
+
"""Event type identifier, used as a discriminator."""
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
|
|
472
|
+
"""An event in the model response stream, either starting a new part or applying a delta to an existing one."""
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -7,20 +7,22 @@ specific LLM being used.
|
|
|
7
7
|
from __future__ import annotations as _annotations
|
|
8
8
|
|
|
9
9
|
from abc import ABC, abstractmethod
|
|
10
|
-
from collections.abc import AsyncIterator,
|
|
10
|
+
from collections.abc import AsyncIterator, Iterator
|
|
11
11
|
from contextlib import asynccontextmanager, contextmanager
|
|
12
|
+
from dataclasses import dataclass, field
|
|
12
13
|
from datetime import datetime
|
|
13
14
|
from functools import cache
|
|
14
|
-
from typing import TYPE_CHECKING, Literal
|
|
15
|
+
from typing import TYPE_CHECKING, Literal
|
|
15
16
|
|
|
16
17
|
import httpx
|
|
17
18
|
|
|
19
|
+
from .._parts_manager import ModelResponsePartsManager
|
|
18
20
|
from ..exceptions import UserError
|
|
19
|
-
from ..messages import ModelMessage, ModelResponse
|
|
21
|
+
from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
|
|
20
22
|
from ..settings import ModelSettings
|
|
23
|
+
from ..usage import Usage
|
|
21
24
|
|
|
22
25
|
if TYPE_CHECKING:
|
|
23
|
-
from ..result import Usage
|
|
24
26
|
from ..tools import ToolDefinition
|
|
25
27
|
|
|
26
28
|
|
|
@@ -59,6 +61,7 @@ KnownModelName = Literal[
|
|
|
59
61
|
'mistral:codestral-latest',
|
|
60
62
|
'mistral:mistral-moderation-latest',
|
|
61
63
|
'ollama:codellama',
|
|
64
|
+
'ollama:deepseek-r1',
|
|
62
65
|
'ollama:gemma',
|
|
63
66
|
'ollama:gemma2',
|
|
64
67
|
'ollama:llama3',
|
|
@@ -70,6 +73,7 @@ KnownModelName = Literal[
|
|
|
70
73
|
'ollama:mistral-nemo',
|
|
71
74
|
'ollama:mixtral',
|
|
72
75
|
'ollama:phi3',
|
|
76
|
+
'ollama:phi4',
|
|
73
77
|
'ollama:qwq',
|
|
74
78
|
'ollama:qwen',
|
|
75
79
|
'ollama:qwen2',
|
|
@@ -78,6 +82,22 @@ KnownModelName = Literal[
|
|
|
78
82
|
'anthropic:claude-3-5-haiku-latest',
|
|
79
83
|
'anthropic:claude-3-5-sonnet-latest',
|
|
80
84
|
'anthropic:claude-3-opus-latest',
|
|
85
|
+
'claude-3-5-haiku-latest',
|
|
86
|
+
'claude-3-5-sonnet-latest',
|
|
87
|
+
'claude-3-opus-latest',
|
|
88
|
+
'cohere:c4ai-aya-expanse-32b',
|
|
89
|
+
'cohere:c4ai-aya-expanse-8b',
|
|
90
|
+
'cohere:command',
|
|
91
|
+
'cohere:command-light',
|
|
92
|
+
'cohere:command-light-nightly',
|
|
93
|
+
'cohere:command-nightly',
|
|
94
|
+
'cohere:command-r',
|
|
95
|
+
'cohere:command-r-03-2024',
|
|
96
|
+
'cohere:command-r-08-2024',
|
|
97
|
+
'cohere:command-r-plus',
|
|
98
|
+
'cohere:command-r-plus-04-2024',
|
|
99
|
+
'cohere:command-r-plus-08-2024',
|
|
100
|
+
'cohere:command-r7b-12-2024',
|
|
81
101
|
'test',
|
|
82
102
|
]
|
|
83
103
|
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
|
|
@@ -129,88 +149,54 @@ class AgentModel(ABC):
|
|
|
129
149
|
@asynccontextmanager
|
|
130
150
|
async def request_stream(
|
|
131
151
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
132
|
-
) -> AsyncIterator[
|
|
152
|
+
) -> AsyncIterator[StreamedResponse]:
|
|
133
153
|
"""Make a request to the model and return a streaming response."""
|
|
154
|
+
# This method is not required, but you need to implement it if you want to support streamed responses
|
|
134
155
|
raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
|
|
135
156
|
# yield is required to make this a generator for type checking
|
|
136
157
|
# noinspection PyUnreachableCode
|
|
137
158
|
yield # pragma: no cover
|
|
138
159
|
|
|
139
160
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
def __aiter__(self) -> AsyncIterator[None]:
|
|
144
|
-
"""Stream the response as an async iterable, building up the text as it goes.
|
|
145
|
-
|
|
146
|
-
This is an async iterator that yields `None` to avoid doing the work of validating the input and
|
|
147
|
-
extracting the text field when it will often be thrown away.
|
|
148
|
-
"""
|
|
149
|
-
return self
|
|
150
|
-
|
|
151
|
-
@abstractmethod
|
|
152
|
-
async def __anext__(self) -> None:
|
|
153
|
-
"""Process the next chunk of the response, see above for why this returns `None`."""
|
|
154
|
-
raise NotImplementedError()
|
|
155
|
-
|
|
156
|
-
@abstractmethod
|
|
157
|
-
def get(self, *, final: bool = False) -> Iterable[str]:
|
|
158
|
-
"""Returns an iterable of text since the last call to `get()` — e.g. the text delta.
|
|
159
|
-
|
|
160
|
-
Args:
|
|
161
|
-
final: If True, this is the final call, after iteration is complete, the response should be fully validated
|
|
162
|
-
and all text extracted.
|
|
163
|
-
"""
|
|
164
|
-
raise NotImplementedError()
|
|
161
|
+
@dataclass
|
|
162
|
+
class StreamedResponse(ABC):
|
|
163
|
+
"""Streamed response from an LLM when calling a tool."""
|
|
165
164
|
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
165
|
+
_model_name: str
|
|
166
|
+
_usage: Usage = field(default_factory=Usage, init=False)
|
|
167
|
+
_parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
|
|
168
|
+
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
|
|
169
169
|
|
|
170
|
-
|
|
171
|
-
"""
|
|
172
|
-
|
|
170
|
+
def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
171
|
+
"""Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
|
|
172
|
+
if self._event_iterator is None:
|
|
173
|
+
self._event_iterator = self._get_event_iterator()
|
|
174
|
+
return self._event_iterator
|
|
173
175
|
|
|
174
176
|
@abstractmethod
|
|
175
|
-
def
|
|
176
|
-
"""
|
|
177
|
-
raise NotImplementedError()
|
|
177
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
178
|
+
"""Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
|
|
178
179
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
"""Streamed response from an LLM when calling a tool."""
|
|
182
|
-
|
|
183
|
-
def __aiter__(self) -> AsyncIterator[None]:
|
|
184
|
-
"""Stream the response as an async iterable, building up the tool call as it goes.
|
|
185
|
-
|
|
186
|
-
This is an async iterator that yields `None` to avoid doing the work of building the final tool call when
|
|
187
|
-
it will often be thrown away.
|
|
180
|
+
This method should be implemented by subclasses to translate the vendor-specific stream of events into
|
|
181
|
+
pydantic_ai-format events.
|
|
188
182
|
"""
|
|
189
|
-
return self
|
|
190
|
-
|
|
191
|
-
@abstractmethod
|
|
192
|
-
async def __anext__(self) -> None:
|
|
193
|
-
"""Process the next chunk of the response, see above for why this returns `None`."""
|
|
194
183
|
raise NotImplementedError()
|
|
184
|
+
# noinspection PyUnreachableCode
|
|
185
|
+
yield
|
|
195
186
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
187
|
+
def get(self) -> ModelResponse:
|
|
188
|
+
"""Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
|
|
189
|
+
return ModelResponse(
|
|
190
|
+
parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp()
|
|
191
|
+
)
|
|
201
192
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
raise NotImplementedError()
|
|
193
|
+
def model_name(self) -> str:
|
|
194
|
+
"""Get the model name of the response."""
|
|
195
|
+
return self._model_name
|
|
206
196
|
|
|
207
|
-
@abstractmethod
|
|
208
197
|
def usage(self) -> Usage:
|
|
209
|
-
"""Get the usage of the
|
|
210
|
-
|
|
211
|
-
NOTE: this won't return the full usage until the stream is finished.
|
|
212
|
-
"""
|
|
213
|
-
raise NotImplementedError()
|
|
198
|
+
"""Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
|
|
199
|
+
return self._usage
|
|
214
200
|
|
|
215
201
|
@abstractmethod
|
|
216
202
|
def timestamp(self) -> datetime:
|
|
@@ -218,9 +204,6 @@ class StreamStructuredResponse(ABC):
|
|
|
218
204
|
raise NotImplementedError()
|
|
219
205
|
|
|
220
206
|
|
|
221
|
-
EitherStreamedResponse = Union[StreamTextResponse, StreamStructuredResponse]
|
|
222
|
-
|
|
223
|
-
|
|
224
207
|
ALLOW_MODEL_REQUESTS = True
|
|
225
208
|
"""Whether to allow requests to models.
|
|
226
209
|
|
|
@@ -269,6 +252,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
|
|
269
252
|
from .test import TestModel
|
|
270
253
|
|
|
271
254
|
return TestModel()
|
|
255
|
+
elif model.startswith('cohere:'):
|
|
256
|
+
from .cohere import CohereModel
|
|
257
|
+
|
|
258
|
+
return CohereModel(model[7:])
|
|
272
259
|
elif model.startswith('openai:'):
|
|
273
260
|
from .openai import OpenAIModel
|
|
274
261
|
|