pydantic-ai-slim 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_function_schema.py +13 -4
- pydantic_ai/_output.py +41 -25
- pydantic_ai/_parts_manager.py +31 -5
- pydantic_ai/ag_ui.py +68 -78
- pydantic_ai/agent.py +9 -29
- pydantic_ai/mcp.py +79 -19
- pydantic_ai/messages.py +74 -16
- pydantic_ai/models/__init__.py +12 -1
- pydantic_ai/models/anthropic.py +11 -3
- pydantic_ai/models/bedrock.py +4 -2
- pydantic_ai/models/cohere.py +6 -6
- pydantic_ai/models/function.py +19 -18
- pydantic_ai/models/gemini.py +5 -1
- pydantic_ai/models/google.py +9 -2
- pydantic_ai/models/groq.py +6 -2
- pydantic_ai/models/huggingface.py +6 -2
- pydantic_ai/models/mistral.py +15 -3
- pydantic_ai/models/openai.py +34 -7
- pydantic_ai/models/test.py +6 -2
- pydantic_ai/profiles/openai.py +8 -0
- pydantic_ai/providers/__init__.py +8 -0
- pydantic_ai/providers/moonshotai.py +97 -0
- pydantic_ai/providers/vercel.py +107 -0
- pydantic_ai/result.py +115 -151
- {pydantic_ai_slim-0.4.5.dist-info → pydantic_ai_slim-0.4.7.dist-info}/METADATA +7 -7
- {pydantic_ai_slim-0.4.5.dist-info → pydantic_ai_slim-0.4.7.dist-info}/RECORD +29 -27
- {pydantic_ai_slim-0.4.5.dist-info → pydantic_ai_slim-0.4.7.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.5.dist-info → pydantic_ai_slim-0.4.7.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.5.dist-info → pydantic_ai_slim-0.4.7.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/mcp.py
CHANGED
|
@@ -2,11 +2,13 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
4
|
import functools
|
|
5
|
+
import warnings
|
|
5
6
|
from abc import ABC, abstractmethod
|
|
6
7
|
from asyncio import Lock
|
|
7
8
|
from collections.abc import AsyncIterator, Awaitable, Sequence
|
|
8
9
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
|
9
10
|
from dataclasses import dataclass, field, replace
|
|
11
|
+
from datetime import timedelta
|
|
10
12
|
from pathlib import Path
|
|
11
13
|
from typing import Any, Callable
|
|
12
14
|
|
|
@@ -37,7 +39,7 @@ except ImportError as _import_error:
|
|
|
37
39
|
) from _import_error
|
|
38
40
|
|
|
39
41
|
# after mcp imports so any import error maps to this file, not _mcp.py
|
|
40
|
-
from . import _mcp, exceptions, messages, models
|
|
42
|
+
from . import _mcp, _utils, exceptions, messages, models
|
|
41
43
|
|
|
42
44
|
__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
|
|
43
45
|
|
|
@@ -59,6 +61,7 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
59
61
|
log_level: mcp_types.LoggingLevel | None = None
|
|
60
62
|
log_handler: LoggingFnT | None = None
|
|
61
63
|
timeout: float = 5
|
|
64
|
+
read_timeout: float = 5 * 60
|
|
62
65
|
process_tool_call: ProcessToolCallback | None = None
|
|
63
66
|
allow_sampling: bool = True
|
|
64
67
|
max_retries: int = 1
|
|
@@ -148,7 +151,7 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
148
151
|
except McpError as e:
|
|
149
152
|
raise exceptions.ModelRetry(e.error.message)
|
|
150
153
|
|
|
151
|
-
content = [self._map_tool_result_part(part) for part in result.content]
|
|
154
|
+
content = [await self._map_tool_result_part(part) for part in result.content]
|
|
152
155
|
|
|
153
156
|
if result.isError:
|
|
154
157
|
text = '\n'.join(str(part) for part in content)
|
|
@@ -208,6 +211,7 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
208
211
|
write_stream=self._write_stream,
|
|
209
212
|
sampling_callback=self._sampling_callback if self.allow_sampling else None,
|
|
210
213
|
logging_callback=self.log_handler,
|
|
214
|
+
read_timeout_seconds=timedelta(seconds=self.read_timeout),
|
|
211
215
|
)
|
|
212
216
|
self._client = await self._exit_stack.enter_async_context(client)
|
|
213
217
|
|
|
@@ -258,8 +262,8 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
258
262
|
model=self.sampling_model.model_name,
|
|
259
263
|
)
|
|
260
264
|
|
|
261
|
-
def _map_tool_result_part(
|
|
262
|
-
self, part: mcp_types.
|
|
265
|
+
async def _map_tool_result_part(
|
|
266
|
+
self, part: mcp_types.ContentBlock
|
|
263
267
|
) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
|
|
264
268
|
# See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values
|
|
265
269
|
|
|
@@ -281,18 +285,29 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
281
285
|
) # pragma: no cover
|
|
282
286
|
elif isinstance(part, mcp_types.EmbeddedResource):
|
|
283
287
|
resource = part.resource
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
)
|
|
291
|
-
|
|
292
|
-
assert_never(resource)
|
|
288
|
+
return self._get_content(resource)
|
|
289
|
+
elif isinstance(part, mcp_types.ResourceLink):
|
|
290
|
+
resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri)
|
|
291
|
+
return (
|
|
292
|
+
self._get_content(resource_result.contents[0])
|
|
293
|
+
if len(resource_result.contents) == 1
|
|
294
|
+
else [self._get_content(resource) for resource in resource_result.contents]
|
|
295
|
+
)
|
|
293
296
|
else:
|
|
294
297
|
assert_never(part)
|
|
295
298
|
|
|
299
|
+
def _get_content(
|
|
300
|
+
self, resource: mcp_types.TextResourceContents | mcp_types.BlobResourceContents
|
|
301
|
+
) -> str | messages.BinaryContent:
|
|
302
|
+
if isinstance(resource, mcp_types.TextResourceContents):
|
|
303
|
+
return resource.text
|
|
304
|
+
elif isinstance(resource, mcp_types.BlobResourceContents):
|
|
305
|
+
return messages.BinaryContent(
|
|
306
|
+
data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream'
|
|
307
|
+
)
|
|
308
|
+
else:
|
|
309
|
+
assert_never(resource)
|
|
310
|
+
|
|
296
311
|
|
|
297
312
|
@dataclass
|
|
298
313
|
class MCPServerStdio(MCPServer):
|
|
@@ -401,7 +416,7 @@ class MCPServerStdio(MCPServer):
|
|
|
401
416
|
return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'
|
|
402
417
|
|
|
403
418
|
|
|
404
|
-
@dataclass
|
|
419
|
+
@dataclass(init=False)
|
|
405
420
|
class _MCPServerHTTP(MCPServer):
|
|
406
421
|
url: str
|
|
407
422
|
"""The URL of the endpoint on the MCP server."""
|
|
@@ -438,10 +453,10 @@ class _MCPServerHTTP(MCPServer):
|
|
|
438
453
|
```
|
|
439
454
|
"""
|
|
440
455
|
|
|
441
|
-
|
|
442
|
-
"""Maximum time in seconds to wait for new
|
|
456
|
+
read_timeout: float = 5 * 60
|
|
457
|
+
"""Maximum time in seconds to wait for new messages before timing out.
|
|
443
458
|
|
|
444
|
-
This timeout applies to the long-lived
|
|
459
|
+
This timeout applies to the long-lived connection after it's established.
|
|
445
460
|
If no new messages are received within this time, the connection will be considered stale
|
|
446
461
|
and may be closed. Defaults to 5 minutes (300 seconds).
|
|
447
462
|
"""
|
|
@@ -485,6 +500,51 @@ class _MCPServerHTTP(MCPServer):
|
|
|
485
500
|
sampling_model: models.Model | None = None
|
|
486
501
|
"""The model to use for sampling."""
|
|
487
502
|
|
|
503
|
+
def __init__(
|
|
504
|
+
self,
|
|
505
|
+
*,
|
|
506
|
+
url: str,
|
|
507
|
+
headers: dict[str, str] | None = None,
|
|
508
|
+
http_client: httpx.AsyncClient | None = None,
|
|
509
|
+
read_timeout: float | None = None,
|
|
510
|
+
tool_prefix: str | None = None,
|
|
511
|
+
log_level: mcp_types.LoggingLevel | None = None,
|
|
512
|
+
log_handler: LoggingFnT | None = None,
|
|
513
|
+
timeout: float = 5,
|
|
514
|
+
process_tool_call: ProcessToolCallback | None = None,
|
|
515
|
+
allow_sampling: bool = True,
|
|
516
|
+
max_retries: int = 1,
|
|
517
|
+
sampling_model: models.Model | None = None,
|
|
518
|
+
**kwargs: Any,
|
|
519
|
+
):
|
|
520
|
+
# Handle deprecated sse_read_timeout parameter
|
|
521
|
+
if 'sse_read_timeout' in kwargs:
|
|
522
|
+
if read_timeout is not None:
|
|
523
|
+
raise TypeError("'read_timeout' and 'sse_read_timeout' cannot be set at the same time.")
|
|
524
|
+
|
|
525
|
+
warnings.warn(
|
|
526
|
+
"'sse_read_timeout' is deprecated, use 'read_timeout' instead.", DeprecationWarning, stacklevel=2
|
|
527
|
+
)
|
|
528
|
+
read_timeout = kwargs.pop('sse_read_timeout')
|
|
529
|
+
|
|
530
|
+
_utils.validate_empty_kwargs(kwargs)
|
|
531
|
+
|
|
532
|
+
if read_timeout is None:
|
|
533
|
+
read_timeout = 5 * 60
|
|
534
|
+
|
|
535
|
+
self.url = url
|
|
536
|
+
self.headers = headers
|
|
537
|
+
self.http_client = http_client
|
|
538
|
+
self.tool_prefix = tool_prefix
|
|
539
|
+
self.log_level = log_level
|
|
540
|
+
self.log_handler = log_handler
|
|
541
|
+
self.timeout = timeout
|
|
542
|
+
self.process_tool_call = process_tool_call
|
|
543
|
+
self.allow_sampling = allow_sampling
|
|
544
|
+
self.max_retries = max_retries
|
|
545
|
+
self.sampling_model = sampling_model
|
|
546
|
+
self.read_timeout = read_timeout
|
|
547
|
+
|
|
488
548
|
@property
|
|
489
549
|
@abstractmethod
|
|
490
550
|
def _transport_client(
|
|
@@ -522,7 +582,7 @@ class _MCPServerHTTP(MCPServer):
|
|
|
522
582
|
self._transport_client,
|
|
523
583
|
url=self.url,
|
|
524
584
|
timeout=self.timeout,
|
|
525
|
-
sse_read_timeout=self.
|
|
585
|
+
sse_read_timeout=self.read_timeout,
|
|
526
586
|
)
|
|
527
587
|
|
|
528
588
|
if self.http_client is not None:
|
|
@@ -549,7 +609,7 @@ class _MCPServerHTTP(MCPServer):
|
|
|
549
609
|
return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
|
|
550
610
|
|
|
551
611
|
|
|
552
|
-
@dataclass
|
|
612
|
+
@dataclass(init=False)
|
|
553
613
|
class MCPServerSSE(_MCPServerHTTP):
|
|
554
614
|
"""An MCP server that connects over streamable HTTP connections.
|
|
555
615
|
|
pydantic_ai/messages.py
CHANGED
|
@@ -85,7 +85,7 @@ class SystemPromptPart:
|
|
|
85
85
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
86
86
|
|
|
87
87
|
|
|
88
|
-
@dataclass(repr=False)
|
|
88
|
+
@dataclass(init=False, repr=False)
|
|
89
89
|
class FileUrl(ABC):
|
|
90
90
|
"""Abstract base class for any URL-based file."""
|
|
91
91
|
|
|
@@ -106,11 +106,29 @@ class FileUrl(ABC):
|
|
|
106
106
|
- `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing
|
|
107
107
|
"""
|
|
108
108
|
|
|
109
|
-
|
|
109
|
+
_media_type: str | None = field(init=False, repr=False)
|
|
110
|
+
|
|
111
|
+
def __init__(
|
|
112
|
+
self,
|
|
113
|
+
url: str,
|
|
114
|
+
force_download: bool = False,
|
|
115
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
116
|
+
media_type: str | None = None,
|
|
117
|
+
) -> None:
|
|
118
|
+
self.url = url
|
|
119
|
+
self.vendor_metadata = vendor_metadata
|
|
120
|
+
self.force_download = force_download
|
|
121
|
+
self._media_type = media_type
|
|
122
|
+
|
|
110
123
|
@abstractmethod
|
|
111
|
-
def
|
|
124
|
+
def _infer_media_type(self) -> str:
|
|
112
125
|
"""Return the media type of the file, based on the url."""
|
|
113
126
|
|
|
127
|
+
@property
|
|
128
|
+
def media_type(self) -> str:
|
|
129
|
+
"""Return the media type of the file, based on the url or the provided `_media_type`."""
|
|
130
|
+
return self._media_type or self._infer_media_type()
|
|
131
|
+
|
|
114
132
|
@property
|
|
115
133
|
@abstractmethod
|
|
116
134
|
def format(self) -> str:
|
|
@@ -119,7 +137,7 @@ class FileUrl(ABC):
|
|
|
119
137
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
120
138
|
|
|
121
139
|
|
|
122
|
-
@dataclass(repr=False)
|
|
140
|
+
@dataclass(init=False, repr=False)
|
|
123
141
|
class VideoUrl(FileUrl):
|
|
124
142
|
"""A URL to a video."""
|
|
125
143
|
|
|
@@ -129,8 +147,18 @@ class VideoUrl(FileUrl):
|
|
|
129
147
|
kind: Literal['video-url'] = 'video-url'
|
|
130
148
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
131
149
|
|
|
132
|
-
|
|
133
|
-
|
|
150
|
+
def __init__(
|
|
151
|
+
self,
|
|
152
|
+
url: str,
|
|
153
|
+
force_download: bool = False,
|
|
154
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
155
|
+
media_type: str | None = None,
|
|
156
|
+
kind: Literal['video-url'] = 'video-url',
|
|
157
|
+
) -> None:
|
|
158
|
+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
|
|
159
|
+
self.kind = kind
|
|
160
|
+
|
|
161
|
+
def _infer_media_type(self) -> VideoMediaType:
|
|
134
162
|
"""Return the media type of the video, based on the url."""
|
|
135
163
|
if self.url.endswith('.mkv'):
|
|
136
164
|
return 'video/x-matroska'
|
|
@@ -170,7 +198,7 @@ class VideoUrl(FileUrl):
|
|
|
170
198
|
return _video_format_lookup[self.media_type]
|
|
171
199
|
|
|
172
200
|
|
|
173
|
-
@dataclass(repr=False)
|
|
201
|
+
@dataclass(init=False, repr=False)
|
|
174
202
|
class AudioUrl(FileUrl):
|
|
175
203
|
"""A URL to an audio file."""
|
|
176
204
|
|
|
@@ -180,8 +208,18 @@ class AudioUrl(FileUrl):
|
|
|
180
208
|
kind: Literal['audio-url'] = 'audio-url'
|
|
181
209
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
182
210
|
|
|
183
|
-
|
|
184
|
-
|
|
211
|
+
def __init__(
|
|
212
|
+
self,
|
|
213
|
+
url: str,
|
|
214
|
+
force_download: bool = False,
|
|
215
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
216
|
+
media_type: str | None = None,
|
|
217
|
+
kind: Literal['audio-url'] = 'audio-url',
|
|
218
|
+
) -> None:
|
|
219
|
+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
|
|
220
|
+
self.kind = kind
|
|
221
|
+
|
|
222
|
+
def _infer_media_type(self) -> AudioMediaType:
|
|
185
223
|
"""Return the media type of the audio file, based on the url.
|
|
186
224
|
|
|
187
225
|
References:
|
|
@@ -208,7 +246,7 @@ class AudioUrl(FileUrl):
|
|
|
208
246
|
return _audio_format_lookup[self.media_type]
|
|
209
247
|
|
|
210
248
|
|
|
211
|
-
@dataclass(repr=False)
|
|
249
|
+
@dataclass(init=False, repr=False)
|
|
212
250
|
class ImageUrl(FileUrl):
|
|
213
251
|
"""A URL to an image."""
|
|
214
252
|
|
|
@@ -218,8 +256,18 @@ class ImageUrl(FileUrl):
|
|
|
218
256
|
kind: Literal['image-url'] = 'image-url'
|
|
219
257
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
220
258
|
|
|
221
|
-
|
|
222
|
-
|
|
259
|
+
def __init__(
|
|
260
|
+
self,
|
|
261
|
+
url: str,
|
|
262
|
+
force_download: bool = False,
|
|
263
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
264
|
+
media_type: str | None = None,
|
|
265
|
+
kind: Literal['image-url'] = 'image-url',
|
|
266
|
+
) -> None:
|
|
267
|
+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
|
|
268
|
+
self.kind = kind
|
|
269
|
+
|
|
270
|
+
def _infer_media_type(self) -> ImageMediaType:
|
|
223
271
|
"""Return the media type of the image, based on the url."""
|
|
224
272
|
if self.url.endswith(('.jpg', '.jpeg')):
|
|
225
273
|
return 'image/jpeg'
|
|
@@ -241,7 +289,7 @@ class ImageUrl(FileUrl):
|
|
|
241
289
|
return _image_format_lookup[self.media_type]
|
|
242
290
|
|
|
243
291
|
|
|
244
|
-
@dataclass(repr=False)
|
|
292
|
+
@dataclass(init=False, repr=False)
|
|
245
293
|
class DocumentUrl(FileUrl):
|
|
246
294
|
"""The URL of the document."""
|
|
247
295
|
|
|
@@ -251,8 +299,18 @@ class DocumentUrl(FileUrl):
|
|
|
251
299
|
kind: Literal['document-url'] = 'document-url'
|
|
252
300
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
253
301
|
|
|
254
|
-
|
|
255
|
-
|
|
302
|
+
def __init__(
|
|
303
|
+
self,
|
|
304
|
+
url: str,
|
|
305
|
+
force_download: bool = False,
|
|
306
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
307
|
+
media_type: str | None = None,
|
|
308
|
+
kind: Literal['document-url'] = 'document-url',
|
|
309
|
+
) -> None:
|
|
310
|
+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
|
|
311
|
+
self.kind = kind
|
|
312
|
+
|
|
313
|
+
def _infer_media_type(self) -> str:
|
|
256
314
|
"""Return the media type of the document, based on the url."""
|
|
257
315
|
type_, _ = guess_type(self.url)
|
|
258
316
|
if type_ is None:
|
|
@@ -632,7 +690,7 @@ class ThinkingPart:
|
|
|
632
690
|
|
|
633
691
|
def has_content(self) -> bool:
|
|
634
692
|
"""Return `True` if the thinking content is non-empty."""
|
|
635
|
-
return bool(self.content)
|
|
693
|
+
return bool(self.content)
|
|
636
694
|
|
|
637
695
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
638
696
|
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -233,6 +233,15 @@ KnownModelName = TypeAliasType(
|
|
|
233
233
|
'mistral:mistral-large-latest',
|
|
234
234
|
'mistral:mistral-moderation-latest',
|
|
235
235
|
'mistral:mistral-small-latest',
|
|
236
|
+
'moonshotai:moonshot-v1-8k',
|
|
237
|
+
'moonshotai:moonshot-v1-32k',
|
|
238
|
+
'moonshotai:moonshot-v1-128k',
|
|
239
|
+
'moonshotai:moonshot-v1-8k-vision-preview',
|
|
240
|
+
'moonshotai:moonshot-v1-32k-vision-preview',
|
|
241
|
+
'moonshotai:moonshot-v1-128k-vision-preview',
|
|
242
|
+
'moonshotai:kimi-latest',
|
|
243
|
+
'moonshotai:kimi-thinking-preview',
|
|
244
|
+
'moonshotai:kimi-k2-0711-preview',
|
|
236
245
|
'o1',
|
|
237
246
|
'o1-2024-12-17',
|
|
238
247
|
'o1-mini',
|
|
@@ -615,7 +624,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
|
|
|
615
624
|
'deepseek',
|
|
616
625
|
'azure',
|
|
617
626
|
'openrouter',
|
|
627
|
+
'vercel',
|
|
618
628
|
'grok',
|
|
629
|
+
'moonshotai',
|
|
619
630
|
'fireworks',
|
|
620
631
|
'together',
|
|
621
632
|
'heroku',
|
|
@@ -758,7 +769,7 @@ async def download_item(
|
|
|
758
769
|
|
|
759
770
|
data_type = media_type
|
|
760
771
|
if type_format == 'extension':
|
|
761
|
-
data_type =
|
|
772
|
+
data_type = item.format
|
|
762
773
|
|
|
763
774
|
data = response.content
|
|
764
775
|
if data_format in ('base64', 'base64_uri'):
|
pydantic_ai/models/anthropic.py
CHANGED
|
@@ -470,7 +470,7 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
470
470
|
_response: AsyncIterable[BetaRawMessageStreamEvent]
|
|
471
471
|
_timestamp: datetime
|
|
472
472
|
|
|
473
|
-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
473
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
|
|
474
474
|
current_block: BetaContentBlock | None = None
|
|
475
475
|
|
|
476
476
|
async for event in self._response:
|
|
@@ -479,7 +479,11 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
479
479
|
if isinstance(event, BetaRawContentBlockStartEvent):
|
|
480
480
|
current_block = event.content_block
|
|
481
481
|
if isinstance(current_block, BetaTextBlock) and current_block.text:
|
|
482
|
-
|
|
482
|
+
maybe_event = self._parts_manager.handle_text_delta(
|
|
483
|
+
vendor_part_id='content', content=current_block.text
|
|
484
|
+
)
|
|
485
|
+
if maybe_event is not None: # pragma: no branch
|
|
486
|
+
yield maybe_event
|
|
483
487
|
elif isinstance(current_block, BetaThinkingBlock):
|
|
484
488
|
yield self._parts_manager.handle_thinking_delta(
|
|
485
489
|
vendor_part_id='thinking',
|
|
@@ -498,7 +502,11 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
498
502
|
|
|
499
503
|
elif isinstance(event, BetaRawContentBlockDeltaEvent):
|
|
500
504
|
if isinstance(event.delta, BetaTextDelta):
|
|
501
|
-
|
|
505
|
+
maybe_event = self._parts_manager.handle_text_delta(
|
|
506
|
+
vendor_part_id='content', content=event.delta.text
|
|
507
|
+
)
|
|
508
|
+
if maybe_event is not None: # pragma: no branch
|
|
509
|
+
yield maybe_event
|
|
502
510
|
elif isinstance(event.delta, BetaThinkingDelta):
|
|
503
511
|
yield self._parts_manager.handle_thinking_delta(
|
|
504
512
|
vendor_part_id='thinking', content=event.delta.thinking
|
pydantic_ai/models/bedrock.py
CHANGED
|
@@ -572,7 +572,7 @@ class BedrockStreamedResponse(StreamedResponse):
|
|
|
572
572
|
_event_stream: EventStream[ConverseStreamOutputTypeDef]
|
|
573
573
|
_timestamp: datetime = field(default_factory=_utils.now_utc)
|
|
574
574
|
|
|
575
|
-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
575
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
|
|
576
576
|
"""Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
|
|
577
577
|
|
|
578
578
|
This method should be implemented by subclasses to translate the vendor-specific stream of events into
|
|
@@ -618,7 +618,9 @@ class BedrockStreamedResponse(StreamedResponse):
|
|
|
618
618
|
UserWarning,
|
|
619
619
|
)
|
|
620
620
|
if 'text' in delta:
|
|
621
|
-
|
|
621
|
+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
|
|
622
|
+
if maybe_event is not None:
|
|
623
|
+
yield maybe_event
|
|
622
624
|
if 'toolUse' in delta:
|
|
623
625
|
tool_use = delta['toolUse']
|
|
624
626
|
maybe_event = self._parts_manager.handle_tool_call_delta(
|
pydantic_ai/models/cohere.py
CHANGED
|
@@ -38,15 +38,15 @@ try:
|
|
|
38
38
|
AssistantChatMessageV2,
|
|
39
39
|
AsyncClientV2,
|
|
40
40
|
ChatMessageV2,
|
|
41
|
-
ChatResponse,
|
|
42
41
|
SystemChatMessageV2,
|
|
43
|
-
|
|
42
|
+
TextAssistantMessageV2ContentItem,
|
|
44
43
|
ToolCallV2,
|
|
45
44
|
ToolCallV2Function,
|
|
46
45
|
ToolChatMessageV2,
|
|
47
46
|
ToolV2,
|
|
48
47
|
ToolV2Function,
|
|
49
48
|
UserChatMessageV2,
|
|
49
|
+
V2ChatResponse,
|
|
50
50
|
)
|
|
51
51
|
from cohere.core.api_error import ApiError
|
|
52
52
|
from cohere.v2.client import OMIT
|
|
@@ -164,7 +164,7 @@ class CohereModel(Model):
|
|
|
164
164
|
messages: list[ModelMessage],
|
|
165
165
|
model_settings: CohereModelSettings,
|
|
166
166
|
model_request_parameters: ModelRequestParameters,
|
|
167
|
-
) ->
|
|
167
|
+
) -> V2ChatResponse:
|
|
168
168
|
tools = self._get_tools(model_request_parameters)
|
|
169
169
|
cohere_messages = self._map_messages(messages)
|
|
170
170
|
try:
|
|
@@ -185,7 +185,7 @@ class CohereModel(Model):
|
|
|
185
185
|
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
|
|
186
186
|
raise # pragma: no cover
|
|
187
187
|
|
|
188
|
-
def _process_response(self, response:
|
|
188
|
+
def _process_response(self, response: V2ChatResponse) -> ModelResponse:
|
|
189
189
|
"""Process a non-streamed response, and prepare a message to return."""
|
|
190
190
|
parts: list[ModelResponsePart] = []
|
|
191
191
|
if response.message.content is not None and len(response.message.content) > 0:
|
|
@@ -227,7 +227,7 @@ class CohereModel(Model):
|
|
|
227
227
|
assert_never(item)
|
|
228
228
|
message_param = AssistantChatMessageV2(role='assistant')
|
|
229
229
|
if texts:
|
|
230
|
-
message_param.content = [
|
|
230
|
+
message_param.content = [TextAssistantMessageV2ContentItem(text='\n\n'.join(texts))]
|
|
231
231
|
if tool_calls:
|
|
232
232
|
message_param.tool_calls = tool_calls
|
|
233
233
|
cohere_messages.append(message_param)
|
|
@@ -294,7 +294,7 @@ class CohereModel(Model):
|
|
|
294
294
|
assert_never(part)
|
|
295
295
|
|
|
296
296
|
|
|
297
|
-
def _map_usage(response:
|
|
297
|
+
def _map_usage(response: V2ChatResponse) -> usage.Usage:
|
|
298
298
|
u = response.usage
|
|
299
299
|
if u is None:
|
|
300
300
|
return usage.Usage()
|
pydantic_ai/models/function.py
CHANGED
|
@@ -16,9 +16,7 @@ from pydantic_ai.profiles import ModelProfileSpec
|
|
|
16
16
|
from .. import _utils, usage
|
|
17
17
|
from .._utils import PeekableAsyncStream
|
|
18
18
|
from ..messages import (
|
|
19
|
-
AudioUrl,
|
|
20
19
|
BinaryContent,
|
|
21
|
-
ImageUrl,
|
|
22
20
|
ModelMessage,
|
|
23
21
|
ModelRequest,
|
|
24
22
|
ModelResponse,
|
|
@@ -266,7 +264,9 @@ class FunctionStreamedResponse(StreamedResponse):
|
|
|
266
264
|
if isinstance(item, str):
|
|
267
265
|
response_tokens = _estimate_string_tokens(item)
|
|
268
266
|
self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
|
|
269
|
-
|
|
267
|
+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
|
|
268
|
+
if maybe_event is not None: # pragma: no branch
|
|
269
|
+
yield maybe_event
|
|
270
270
|
elif isinstance(item, dict) and item:
|
|
271
271
|
for dtc_index, delta in item.items():
|
|
272
272
|
if isinstance(delta, DeltaThinkingPart):
|
|
@@ -288,7 +288,7 @@ class FunctionStreamedResponse(StreamedResponse):
|
|
|
288
288
|
args=delta.json_args,
|
|
289
289
|
tool_call_id=delta.tool_call_id,
|
|
290
290
|
)
|
|
291
|
-
if maybe_event is not None:
|
|
291
|
+
if maybe_event is not None: # pragma: no branch
|
|
292
292
|
yield maybe_event
|
|
293
293
|
else:
|
|
294
294
|
assert_never(delta)
|
|
@@ -345,18 +345,19 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
|
|
|
345
345
|
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
|
|
346
346
|
if not content:
|
|
347
347
|
return 0
|
|
348
|
+
|
|
348
349
|
if isinstance(content, str):
|
|
349
|
-
return len(
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
350
|
+
return len(_TOKEN_SPLIT_RE.split(content.strip()))
|
|
351
|
+
|
|
352
|
+
tokens = 0
|
|
353
|
+
for part in content:
|
|
354
|
+
if isinstance(part, str):
|
|
355
|
+
tokens += len(_TOKEN_SPLIT_RE.split(part.strip()))
|
|
356
|
+
elif isinstance(part, BinaryContent):
|
|
357
|
+
tokens += len(part.data)
|
|
358
|
+
# TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.
|
|
359
|
+
|
|
360
|
+
return tokens
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
_TOKEN_SPLIT_RE = re.compile(r'[\s",.:]+')
|
pydantic_ai/models/gemini.py
CHANGED
|
@@ -438,7 +438,11 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
438
438
|
if 'text' in gemini_part:
|
|
439
439
|
# Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
|
|
440
440
|
# amongst the tool call deltas
|
|
441
|
-
|
|
441
|
+
maybe_event = self._parts_manager.handle_text_delta(
|
|
442
|
+
vendor_part_id=None, content=gemini_part['text']
|
|
443
|
+
)
|
|
444
|
+
if maybe_event is not None: # pragma: no branch
|
|
445
|
+
yield maybe_event
|
|
442
446
|
|
|
443
447
|
elif 'function_call' in gemini_part:
|
|
444
448
|
# Here, we assume all function_call parts are complete and don't have deltas.
|
pydantic_ai/models/google.py
CHANGED
|
@@ -411,7 +411,12 @@ class GoogleModel(Model):
|
|
|
411
411
|
file_data_dict['video_metadata'] = item.vendor_metadata
|
|
412
412
|
content.append(file_data_dict) # type: ignore
|
|
413
413
|
elif isinstance(item, FileUrl):
|
|
414
|
-
if
|
|
414
|
+
if item.force_download or (
|
|
415
|
+
# google-gla does not support passing file urls directly, except for youtube videos
|
|
416
|
+
# (see above) and files uploaded to the file API (which cannot be downloaded anyway)
|
|
417
|
+
self.system == 'google-gla'
|
|
418
|
+
and not item.url.startswith(r'https://generativelanguage.googleapis.com/v1beta/files')
|
|
419
|
+
):
|
|
415
420
|
downloaded_item = await download_item(item, data_format='base64')
|
|
416
421
|
inline_data = {'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']}
|
|
417
422
|
content.append({'inline_data': inline_data}) # type: ignore
|
|
@@ -453,7 +458,9 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
453
458
|
if part.thought:
|
|
454
459
|
yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
|
|
455
460
|
else:
|
|
456
|
-
|
|
461
|
+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
|
|
462
|
+
if maybe_event is not None: # pragma: no branch
|
|
463
|
+
yield maybe_event
|
|
457
464
|
elif part.function_call:
|
|
458
465
|
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
459
466
|
vendor_part_id=uuid4(),
|
pydantic_ai/models/groq.py
CHANGED
|
@@ -415,7 +415,11 @@ class GroqStreamedResponse(StreamedResponse):
|
|
|
415
415
|
# Handle the text part of the response
|
|
416
416
|
content = choice.delta.content
|
|
417
417
|
if content is not None:
|
|
418
|
-
|
|
418
|
+
maybe_event = self._parts_manager.handle_text_delta(
|
|
419
|
+
vendor_part_id='content', content=content, extract_think_tags=True
|
|
420
|
+
)
|
|
421
|
+
if maybe_event is not None: # pragma: no branch
|
|
422
|
+
yield maybe_event
|
|
419
423
|
|
|
420
424
|
# Handle the tool calls
|
|
421
425
|
for dtc in choice.delta.tool_calls or []:
|
|
@@ -444,7 +448,7 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
|
|
|
444
448
|
if isinstance(completion, chat.ChatCompletion):
|
|
445
449
|
response_usage = completion.usage
|
|
446
450
|
elif completion.x_groq is not None:
|
|
447
|
-
response_usage = completion.x_groq.usage
|
|
451
|
+
response_usage = completion.x_groq.usage
|
|
448
452
|
|
|
449
453
|
if response_usage is None:
|
|
450
454
|
return usage.Usage()
|
pydantic_ai/models/huggingface.py
CHANGED
|
@@ -426,8 +426,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
|
|
|
426
426
|
|
|
427
427
|
# Handle the text part of the response
|
|
428
428
|
content = choice.delta.content
|
|
429
|
-
if content
|
|
430
|
-
|
|
429
|
+
if content:
|
|
430
|
+
maybe_event = self._parts_manager.handle_text_delta(
|
|
431
|
+
vendor_part_id='content', content=content, extract_think_tags=True
|
|
432
|
+
)
|
|
433
|
+
if maybe_event is not None: # pragma: no branch
|
|
434
|
+
yield maybe_event
|
|
431
435
|
|
|
432
436
|
for dtc in choice.delta.tool_calls or []:
|
|
433
437
|
maybe_event = self._parts_manager.handle_tool_call_delta(
|