pydantic-ai-slim 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Note: this version of pydantic-ai-slim has been flagged as a potentially problematic release.
- pydantic_ai/__init__.py +16 -4
- pydantic_ai/_agent_graph.py +264 -351
- pydantic_ai/_utils.py +1 -1
- pydantic_ai/agent.py +581 -156
- pydantic_ai/messages.py +121 -1
- pydantic_ai/models/__init__.py +12 -1
- pydantic_ai/models/anthropic.py +67 -50
- pydantic_ai/models/cohere.py +5 -2
- pydantic_ai/models/function.py +15 -6
- pydantic_ai/models/gemini.py +73 -5
- pydantic_ai/models/groq.py +35 -8
- pydantic_ai/models/instrumented.py +225 -0
- pydantic_ai/models/mistral.py +29 -4
- pydantic_ai/models/openai.py +59 -13
- pydantic_ai/models/test.py +6 -6
- pydantic_ai/models/wrapper.py +45 -0
- pydantic_ai/result.py +106 -144
- pydantic_ai/tools.py +2 -2
- {pydantic_ai_slim-0.0.24.dist-info → pydantic_ai_slim-0.0.26.dist-info}/METADATA +2 -2
- pydantic_ai_slim-0.0.26.dist-info/RECORD +32 -0
- pydantic_ai_slim-0.0.24.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.24.dist-info → pydantic_ai_slim-0.0.26.dist-info}/WHEEL +0 -0
pydantic_ai/messages.py
CHANGED
@@ -1,11 +1,14 @@
 from __future__ import annotations as _annotations
 
+import uuid
+from collections.abc import Sequence
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Annotated, Any, Literal, Union, cast, overload
 
 import pydantic
 import pydantic_core
+from typing_extensions import TypeAlias
 
 from ._utils import now_utc as _now_utc
 from .exceptions import UnexpectedModelBehavior
@@ -31,6 +34,93 @@ class SystemPromptPart:
     """Part type identifier, this is available on all parts as a discriminator."""
 
 
+@dataclass
+class AudioUrl:
+    """A URL to an audio file."""
+
+    url: str
+    """The URL of the audio file."""
+
+    kind: Literal['audio-url'] = 'audio-url'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+    @property
+    def media_type(self) -> AudioMediaType:
+        """Return the media type of the audio file, based on the url."""
+        if self.url.endswith('.mp3'):
+            return 'audio/mpeg'
+        elif self.url.endswith('.wav'):
+            return 'audio/wav'
+        else:
+            raise ValueError(f'Unknown audio file extension: {self.url}')
+
+
+@dataclass
+class ImageUrl:
+    """A URL to an image."""
+
+    url: str
+    """The URL of the image."""
+
+    kind: Literal['image-url'] = 'image-url'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+    @property
+    def media_type(self) -> ImageMediaType:
+        """Return the media type of the image, based on the url."""
+        if self.url.endswith(('.jpg', '.jpeg')):
+            return 'image/jpeg'
+        elif self.url.endswith('.png'):
+            return 'image/png'
+        elif self.url.endswith('.gif'):
+            return 'image/gif'
+        elif self.url.endswith('.webp'):
+            return 'image/webp'
+        else:
+            raise ValueError(f'Unknown image file extension: {self.url}')
+
+
+AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
+ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
+
+
+@dataclass
+class BinaryContent:
+    """Binary content, e.g. an audio or image file."""
+
+    data: bytes
+    """The binary data."""
+
+    media_type: AudioMediaType | ImageMediaType | str
+    """The media type of the binary data."""
+
+    kind: Literal['binary'] = 'binary'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+    @property
+    def is_audio(self) -> bool:
+        """Return `True` if the media type is an audio type."""
+        return self.media_type.startswith('audio/')
+
+    @property
+    def is_image(self) -> bool:
+        """Return `True` if the media type is an image type."""
+        return self.media_type.startswith('image/')
+
+    @property
+    def audio_format(self) -> Literal['mp3', 'wav']:
+        """Return the audio format given the media type."""
+        if self.media_type == 'audio/mpeg':
+            return 'mp3'
+        elif self.media_type == 'audio/wav':
+            return 'wav'
+        else:
+            raise ValueError(f'Unknown audio media type: {self.media_type}')
+
+
+UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | BinaryContent'
+
+
 @dataclass
 class UserPromptPart:
     """A user prompt, generally written by the end user.
@@ -39,7 +129,7 @@ class UserPromptPart:
     [`Agent.run_sync`][pydantic_ai.Agent.run_sync], and [`Agent.run_stream`][pydantic_ai.Agent.run_stream].
     """
 
-    content: str
+    content: str | Sequence[UserContent]
     """The content of the prompt."""
 
     timestamp: datetime = field(default_factory=_now_utc)
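Taken together, the new types let `UserPromptPart.content` carry a mix of text and media. A minimal sketch of how they compose (the URL and bytes are placeholders):

```python
from pydantic_ai.messages import BinaryContent, ImageUrl, UserPromptPart

# ImageUrl/AudioUrl infer their media type from the URL's file extension;
# unknown extensions raise ValueError.
image = ImageUrl(url='https://example.com/chart.png')
assert image.media_type == 'image/png'

# BinaryContent carries an explicit media type and exposes helper properties.
audio = BinaryContent(data=b'...', media_type='audio/mpeg')
assert audio.is_audio and not audio.is_image
assert audio.audio_format == 'mp3'

# content may now be a sequence mixing strings and media objects.
part = UserPromptPart(content=['What does this chart show?', image, audio])
```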
@@ -445,3 +535,33 @@ class PartDeltaEvent:
 
 ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
 """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
+
+
+@dataclass
+class FunctionToolCallEvent:
+    """An event indicating the start to a call to a function tool."""
+
+    part: ToolCallPart
+    """The (function) tool call to make."""
+    call_id: str = field(init=False)
+    """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
+    event_kind: Literal['function_tool_call'] = 'function_tool_call'
+    """Event type identifier, used as a discriminator."""
+
+    def __post_init__(self):
+        self.call_id = self.part.tool_call_id or str(uuid.uuid4())
+
+
+@dataclass
+class FunctionToolResultEvent:
+    """An event indicating the result of a function tool call."""
+
+    result: ToolReturnPart | RetryPromptPart
+    """The result of the call to the function tool."""
+    call_id: str
+    """An ID used to match the result to its original call."""
+    event_kind: Literal['function_tool_result'] = 'function_tool_result'
+    """Event type identifier, used as a discriminator."""
+
+
+HandleResponseEvent = Annotated[Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('kind')]
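The two event classes pair up through `call_id`: `FunctionToolCallEvent` derives it in `__post_init__` (the part's `tool_call_id`, or a fresh UUID), and `FunctionToolResultEvent` echoes it back. A sketch of correlating them, where `events` stands for whatever async iterable of `HandleResponseEvent`s the caller already has:

```python
from pydantic_ai.messages import FunctionToolCallEvent, FunctionToolResultEvent

async def log_tool_round_trips(events) -> None:
    # Key pending calls on call_id; each result pops its originating call.
    pending: dict[str, FunctionToolCallEvent] = {}
    async for event in events:
        if isinstance(event, FunctionToolCallEvent):
            pending[event.call_id] = event
        elif isinstance(event, FunctionToolResultEvent):
            call = pending.pop(event.call_id)
            print(f'{call.part.tool_name}: {event.result}')
```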
pydantic_ai/models/__init__.py
CHANGED
@@ -234,6 +234,8 @@ class StreamedResponse(ABC):
 
         This method should be implemented by subclasses to translate the vendor-specific stream of events into
         pydantic_ai-format events.
+
+        It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes.
         """
         raise NotImplementedError()
         # noinspection PyUnreachableCode
@@ -362,7 +364,6 @@ def infer_model(model: Model | KnownModelName) -> Model:
         raise UserError(f'Unknown model: {model}')
 
 
-@cache
 def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     """Cached HTTPX async client so multiple agents and calls can share the same client.
 
@@ -373,6 +374,16 @@ def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     The default timeouts match those of OpenAI,
     see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.
     """
+    client = _cached_async_http_client(timeout=timeout, connect=connect)
+    if client.is_closed:
+        # This happens if the context manager is used, so we need to create a new client.
+        _cached_async_http_client.cache_clear()
+        client = _cached_async_http_client(timeout=timeout, connect=connect)
+    return client
+
+
+@cache
+def _cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     return httpx.AsyncClient(
         timeout=httpx.Timeout(timeout=timeout, connect=connect),
         headers={'User-Agent': get_user_agent()},
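Splitting `cached_async_http_client` into a public wrapper plus a private `@cache`d factory means a shared client that someone closed (e.g. by using it as an async context manager) is transparently replaced rather than failing on the next request. A sketch of the observable behaviour:

```python
import asyncio

from pydantic_ai.models import cached_async_http_client

async def main() -> None:
    a = cached_async_http_client()
    assert a is cached_async_http_client()  # one shared client per (timeout, connect)

    await a.aclose()  # simulate someone closing the shared client

    b = cached_async_http_client()  # cache is cleared; a fresh, open client is returned
    assert b is not a and not b.is_closed

asyncio.run(main())
```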
pydantic_ai/models/anthropic.py
CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterable, AsyncIterator
+import io
+from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -13,6 +14,8 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -38,6 +41,7 @@ from . import (
 try:
     from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
     from anthropic.types import (
+        ImageBlockParam,
         Message as AnthropicMessage,
         MessageParam,
         MetadataParam,
@@ -214,7 +218,7 @@ class AnthropicModel(Model):
         if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
             tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
 
-        system_prompt, anthropic_messages = self._map_message(messages)
+        system_prompt, anthropic_messages = await self._map_message(messages)
 
         return await self.client.messages.create(
             max_tokens=model_settings.get('max_tokens', 1024),
@@ -266,69 +270,82 @@ class AnthropicModel(Model):
         tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
         return tools
 
-    def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
         system_prompt: str = ''
        anthropic_messages: list[MessageParam] = []
         for m in messages:
             if isinstance(m, ModelRequest):
-                for part in m.parts:
-                    if isinstance(part, SystemPromptPart):
-                        system_prompt += part.content
-                    elif isinstance(part, UserPromptPart):
-                        anthropic_messages.append(MessageParam(role='user', content=part.content))
-                    elif isinstance(part, ToolReturnPart):
-                        anthropic_messages.append(
-                            MessageParam(
-                                role='user',
-                                content=[
-                                    ToolResultBlockParam(
-                                        tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
-                                        type='tool_result',
-                                        content=part.model_response_str(),
-                                        is_error=False,
-                                    )
-                                ],
-                            )
+                user_content_params: list[ToolResultBlockParam | TextBlockParam | ImageBlockParam] = []
+                for request_part in m.parts:
+                    if isinstance(request_part, SystemPromptPart):
+                        system_prompt += request_part.content
+                    elif isinstance(request_part, UserPromptPart):
+                        async for content in self._map_user_prompt(request_part):
+                            user_content_params.append(content)
+                    elif isinstance(request_part, ToolReturnPart):
+                        tool_result_block_param = ToolResultBlockParam(
+                            tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
+                            type='tool_result',
+                            content=request_part.model_response_str(),
+                            is_error=False,
                         )
-                    elif isinstance(part, RetryPromptPart):
-                        if part.tool_name is None:
-                            anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
+                        user_content_params.append(tool_result_block_param)
+                    elif isinstance(request_part, RetryPromptPart):
+                        if request_part.tool_name is None:
+                            retry_param = TextBlockParam(type='text', text=request_part.model_response())
                         else:
-                            anthropic_messages.append(
-                                MessageParam(
-                                    role='user',
-                                    content=[
-                                        ToolResultBlockParam(
-                                            tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
-                                            type='tool_result',
-                                            content=part.model_response(),
-                                            is_error=True,
-                                        ),
-                                    ],
-                                )
+                            retry_param = ToolResultBlockParam(
+                                tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
+                                type='tool_result',
+                                content=request_part.model_response(),
+                                is_error=True,
                             )
+                        user_content_params.append(retry_param)
+                anthropic_messages.append(MessageParam(role='user', content=user_content_params))
             elif isinstance(m, ModelResponse):
-                content: list[TextBlockParam | ToolUseBlockParam] = []
-                for item in m.parts:
-                    if isinstance(item, TextPart):
-                        content.append(TextBlockParam(text=item.content, type='text'))
+                assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
+                for response_part in m.parts:
+                    if isinstance(response_part, TextPart):
+                        assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
                     else:
-                        assert isinstance(item, ToolCallPart)
-                        content.append(self._map_tool_call(item))
-                anthropic_messages.append(MessageParam(role='assistant', content=content))
+                        tool_use_block_param = ToolUseBlockParam(
+                            id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),
+                            type='tool_use',
+                            name=response_part.tool_name,
+                            input=response_part.args_as_dict(),
+                        )
+                        assistant_content_params.append(tool_use_block_param)
+                anthropic_messages.append(MessageParam(role='assistant', content=assistant_content_params))
             else:
                 assert_never(m)
         return system_prompt, anthropic_messages
 
     @staticmethod
-    def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-        return ToolUseBlockParam(
-            id=_guard_tool_call_id(t=t, model_source='Anthropic'),
-            type='tool_use',
-            name=t.tool_name,
-            input=t.args_as_dict(),
-        )
+    async def _map_user_prompt(part: UserPromptPart) -> AsyncGenerator[ImageBlockParam | TextBlockParam]:
+        if isinstance(part.content, str):
+            yield TextBlockParam(text=part.content, type='text')
+        else:
+            for item in part.content:
+                if isinstance(item, str):
+                    yield TextBlockParam(text=item, type='text')
+                elif isinstance(item, BinaryContent):
+                    if item.is_image:
+                        yield ImageBlockParam(
+                            source={'data': io.BytesIO(item.data), 'media_type': item.media_type, 'type': 'base64'},  # type: ignore
+                            type='image',
+                        )
+                    else:
+                        raise RuntimeError('Only images are supported for binary content')
+                elif isinstance(item, ImageUrl):
+                    response = await cached_async_http_client().get(item.url)
+                    response.raise_for_status()
+                    yield ImageBlockParam(
+                        source={'data': io.BytesIO(response.content), 'media_type': 'image/jpeg', 'type': 'base64'},
+                        type='image',
+                    )
+                else:
+                    raise RuntimeError(f'Unsupported content type: {type(item)}')
 
     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolParam:
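`_map_message` and `_map_user_prompt` became async because `ImageUrl` content has to be downloaded before it can be embedded as a base64 image block. A self-contained sketch of that download-then-yield pattern — the class, `fetch` stub, and dict shapes below are stand-ins for `ImageUrl`, the shared HTTP client, and Anthropic's block params, not the real API:

```python
import asyncio
import base64
from dataclasses import dataclass


@dataclass
class FakeImageUrl:  # stand-in for pydantic_ai.messages.ImageUrl
    url: str


async def fetch(url: str) -> bytes:
    await asyncio.sleep(0)  # stand-in for cached_async_http_client().get(url)
    return b'\x89PNG...'  # placeholder image bytes


async def map_user_content(items):
    # Strings yield immediately; URL items need an await, hence an async generator.
    for item in items:
        if isinstance(item, str):
            yield {'type': 'text', 'text': item}
        else:
            data = base64.b64encode(await fetch(item.url)).decode()
            yield {'type': 'image', 'source': {'type': 'base64', 'data': data}}


async def main() -> None:
    async for block in map_user_content(['Describe this:', FakeImageUrl('https://example.com/a.png')]):
        print(block['type'])  # text, then image


asyncio.run(main())
```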
pydantic_ai/models/cohere.py
CHANGED
@@ -124,7 +124,7 @@ class CohereModel(Model):
             assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
             self.client = cohere_client
         else:
-            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)
+            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)
 
     async def request(
         self,
@@ -242,7 +242,10 @@ class CohereModel(Model):
             if isinstance(part, SystemPromptPart):
                 yield SystemChatMessageV2(role='system', content=part.content)
             elif isinstance(part, UserPromptPart):
-                yield UserChatMessageV2(role='user', content=part.content)
+                if isinstance(part.content, str):
+                    yield UserChatMessageV2(role='user', content=part.content)
+                else:
+                    raise RuntimeError('Cohere does not yet support multi-modal inputs.')
             elif isinstance(part, ToolReturnPart):
                 yield ToolChatMessageV2(
                     role='tool',
pydantic_ai/models/function.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 import inspect
 import re
-from collections.abc import AsyncIterator, Awaitable, Iterable
+from collections.abc import AsyncIterator, Awaitable, Iterable, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -14,6 +14,9 @@ from typing_extensions import TypeAlias, assert_never, overload
 from .. import _utils, usage
 from .._utils import PeekableAsyncStream
 from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -23,6 +26,7 @@ from ..messages import (
     TextPart,
     ToolCallPart,
     ToolReturnPart,
+    UserContent,
     UserPromptPart,
 )
 from ..settings import ModelSettings
@@ -109,9 +113,9 @@ class FunctionModel(Model):
             model_settings,
         )
 
-        assert (
-            self.stream_function is not None
-        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+        assert self.stream_function is not None, (
+            'FunctionModel must receive a `stream_function` to support streamed requests'
+        )
 
         response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
 
@@ -262,7 +266,12 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
     )
 
 
-def _estimate_string_tokens(content: str) -> int:
+def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
     if not content:
         return 0
-    return len(re.split(r'[\s",.:]+', content.strip()))
+    if isinstance(content, str):
+        return len(re.split(r'[\s",.:]+', content.strip()))
+    # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
+    else:  # pragma: no cover
+        assert isinstance(content, (AudioUrl, ImageUrl, BinaryContent))
+        return 0
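The string branch of `_estimate_string_tokens` is just a whitespace-and-punctuation split, not a provider tokenizer. A quick illustration of the heuristic:

```python
import re


def estimate_string_tokens(content: str) -> int:
    # Mirrors the string branch of _estimate_string_tokens.
    if not content:
        return 0
    return len(re.split(r'[\s",.:]+', content.strip()))


print(estimate_string_tokens('What is the weather in Paris?'))  # 6
```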
pydantic_ai/models/gemini.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import base64
 import os
 import re
 from collections.abc import AsyncIterator, Sequence
@@ -16,6 +17,9 @@ from typing_extensions import NotRequired, TypedDict, assert_never
 
 from .. import UnexpectedModelBehavior, _utils, exceptions, usage
 from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -185,7 +189,7 @@ class GeminiModel(Model):
     ) -> AsyncIterator[HTTPResponse]:
         tools = self._get_tools(model_request_parameters)
         tool_config = self._get_tool_config(model_request_parameters, tools)
-        sys_prompt_parts, contents = self._message_to_gemini_content(messages)
+        sys_prompt_parts, contents = await self._message_to_gemini_content(messages)
 
         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
@@ -254,7 +258,7 @@ class GeminiModel(Model):
         async for chunk in aiter_bytes:
             content.extend(chunk)
             responses = _gemini_streamed_response_ta.validate_json(
-                content,
+                _ensure_decodeable(content),
                 experimental_allow_partial='trailing-strings',
             )
             if responses:
@@ -269,7 +273,7 @@ class GeminiModel(Model):
         return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
 
     @classmethod
-    def _message_to_gemini_content(
+    async def _message_to_gemini_content(
         cls, messages: list[ModelMessage]
     ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
         sys_prompt_parts: list[_GeminiTextPart] = []
@@ -282,7 +286,7 @@ class GeminiModel(Model):
                     if isinstance(part, SystemPromptPart):
                         sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                     elif isinstance(part, UserPromptPart):
-                        message_parts.append(_GeminiTextPart(text=part.content))
+                        message_parts.extend(await cls._map_user_prompt(part))
                     elif isinstance(part, ToolReturnPart):
                         message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                     elif isinstance(part, RetryPromptPart):
@@ -303,6 +307,34 @@ class GeminiModel(Model):
 
         return sys_prompt_parts, contents
 
+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
+        if isinstance(part.content, str):
+            return [{'text': part.content}]
+        else:
+            content: list[_GeminiPartUnion] = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append({'text': item})
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    content.append(_GeminiInlineDataPart(data=base64_encoded, mime_type=item.media_type))
+                elif isinstance(item, (AudioUrl, ImageUrl)):
+                    try:
+                        content.append(_GeminiFileDataData(file_uri=item.url, mime_type=item.media_type))
+                    except ValueError:
+                        # Download the file if can't find the mime type.
+                        client = cached_async_http_client()
+                        response = await client.get(item.url, follow_redirects=True)
+                        response.raise_for_status()
+                        base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                        content.append(
+                            _GeminiInlineDataPart(data=base64_encoded, mime_type=response.headers['Content-Type'])
+                        )
+                else:
+                    assert_never(item)
+            return content
+
 
 class AuthProtocol(Protocol):
     """Abstract definition for Gemini authentication."""
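For `BinaryContent`, `_map_user_prompt` base64-encodes the bytes into an inline-data part; the camelCase keys Gemini expects come from the pydantic field aliases (`mimeType`, `fileUri`). A sketch of the encoding with placeholder bytes:

```python
import base64

png_bytes = b'\x89PNG\r\n\x1a\n'  # placeholder image header bytes

# Roughly what a serialized inline-data part carries, using the alias keys:
inline_data_part = {
    'data': base64.b64encode(png_bytes).decode('utf-8'),
    'mimeType': 'image/png',
}
print(inline_data_part['data'])  # iVBORw0KGgo=
```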
@@ -370,7 +402,7 @@ class GeminiStreamedResponse(StreamedResponse):
             self._content.extend(chunk)
 
             gemini_responses = _gemini_streamed_response_ta.validate_json(
-                self._content,
+                _ensure_decodeable(self._content),
                 experimental_allow_partial='trailing-strings',
             )
 
@@ -494,6 +526,20 @@ class _GeminiTextPart(TypedDict):
     text: str
 
 
+class _GeminiInlineDataPart(TypedDict):
+    """See <https://ai.google.dev/api/caching#Blob>."""
+
+    data: str
+    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+
+
+class _GeminiFileDataData(TypedDict):
+    """See <https://ai.google.dev/api/caching#FileData>."""
+
+    file_uri: Annotated[str, pydantic.Field(alias='fileUri')]
+    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+
+
 class _GeminiFunctionCallPart(TypedDict):
     function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
 
@@ -549,6 +595,10 @@ def _part_discriminator(v: Any) -> str:
     if isinstance(v, dict):
         if 'text' in v:
             return 'text'
+        elif 'inlineData' in v:
+            return 'inline_data'
+        elif 'fileData' in v:
+            return 'file_data'
         elif 'functionCall' in v or 'function_call' in v:
             return 'function_call'
         elif 'functionResponse' in v or 'function_response' in v:
@@ -564,6 +614,8 @@ _GeminiPartUnion = Annotated[
         Annotated[_GeminiTextPart, pydantic.Tag('text')],
         Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
         Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
+        Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
+        Annotated[_GeminiFileDataData, pydantic.Tag('file_data')],
     ],
     pydantic.Discriminator(_part_discriminator),
 ]
@@ -774,3 +826,19 @@ class _GeminiJsonSchema:
 
         if items_schema := schema.get('items'):  # pragma: no branch
             self._simplify(items_schema, refs_stack)
+
+
+def _ensure_decodeable(content: bytearray) -> bytearray:
+    """Trim any invalid unicode point bytes off the end of a bytearray.
+
+    This is necessary before attempting to parse streaming JSON bytes.
+
+    This is a temporary workaround until https://github.com/pydantic/pydantic-core/issues/1633 is resolved
+    """
+    while True:
+        try:
+            content.decode()
+        except UnicodeDecodeError:
+            content = content[:-1]  # this will definitely succeed before we run out of bytes
+        else:
+            return content
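`_ensure_decodeable` exists because a streamed chunk can end mid-way through a multi-byte UTF-8 sequence, and the buffer must decode cleanly before being handed to the partial-JSON validator. A quick demonstration of the same logic:

```python
def ensure_decodeable(content: bytearray) -> bytearray:
    # Same loop as the new _ensure_decodeable helper in gemini.py.
    while True:
        try:
            content.decode()
        except UnicodeDecodeError:
            content = content[:-1]
        else:
            return content


# '{"text": "hé' encoded as UTF-8, then cut mid-way through the 2-byte 'é'.
buf = bytearray('{"text": "hé'.encode())[:-1]
print(ensure_decodeable(buf).decode())  # {"text": "h
```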