pydantic-ai-slim 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that public registry.
Potentially problematic release: this version of pydantic-ai-slim may be problematic.
- pydantic_ai/__init__.py +22 -4
- pydantic_ai/_agent_graph.py +15 -12
- pydantic_ai/agent.py +13 -13
- pydantic_ai/exceptions.py +42 -1
- pydantic_ai/messages.py +90 -1
- pydantic_ai/models/anthropic.py +58 -28
- pydantic_ai/models/cohere.py +22 -13
- pydantic_ai/models/fallback.py +116 -0
- pydantic_ai/models/function.py +28 -10
- pydantic_ai/models/gemini.py +78 -10
- pydantic_ai/models/groq.py +59 -27
- pydantic_ai/models/mistral.py +50 -15
- pydantic_ai/models/openai.py +84 -30
- pydantic_ai/tools.py +2 -2
- {pydantic_ai_slim-0.0.25.dist-info → pydantic_ai_slim-0.0.27.dist-info}/METADATA +3 -2
- pydantic_ai_slim-0.0.27.dist-info/RECORD +33 -0
- pydantic_ai_slim-0.0.25.dist-info/RECORD +0 -32
- {pydantic_ai_slim-0.0.25.dist-info → pydantic_ai_slim-0.0.27.dist-info}/WHEEL +0 -0
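Among the listed changes, the new `pydantic_ai/models/fallback.py` below imports `ModelHTTPError` and `FallbackExceptionGroup` from `pydantic_ai.exceptions` (changed here: +42 -1). A minimal sketch of catching the new HTTP error type; the model name is illustrative, and it is an assumption that the constructor arguments `status_code`, `model_name` and `body` are exposed as attributes:

from pydantic_ai import Agent
from pydantic_ai.exceptions import ModelHTTPError

agent = Agent('openai:gpt-4o')  # illustrative model name
try:
    result = agent.run_sync('Hello!')
except ModelHTTPError as exc:
    # 4xx/5xx provider responses are surfaced as ModelHTTPError in this release.
    print(exc.status_code, exc.model_name, exc.body)  # assumed attribute names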
pydantic_ai/models/fallback.py
ADDED
@@ -0,0 +1,116 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import AsyncIterator
+from contextlib import AsyncExitStack, asynccontextmanager
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Callable
+
+from ..exceptions import FallbackExceptionGroup, ModelHTTPError
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
+
+if TYPE_CHECKING:
+    from ..messages import ModelMessage, ModelResponse
+    from ..settings import ModelSettings
+    from ..usage import Usage
+
+
+@dataclass(init=False)
+class FallbackModel(Model):
+    """A model that uses one or more fallback models upon failure.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    models: list[Model]
+
+    _model_name: str = field(repr=False)
+    _fallback_on: Callable[[Exception], bool]
+
+    def __init__(
+        self,
+        default_model: Model | KnownModelName,
+        *fallback_models: Model | KnownModelName,
+        fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
+    ):
+        """Initialize a fallback model instance.
+
+        Args:
+            default_model: The name or instance of the default model to use.
+            fallback_models: The names or instances of the fallback models to use upon failure.
+            fallback_on: A callable or tuple of exceptions that should trigger a fallback.
+        """
+        self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]
+        self._model_name = f'FallBackModel[{", ".join(model.model_name for model in self.models)}]'
+
+        if isinstance(fallback_on, tuple):
+            self._fallback_on = _default_fallback_condition_factory(fallback_on)
+        else:
+            self._fallback_on = fallback_on
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        """Try each model in sequence until one succeeds.
+
+        In case of failure, raise a FallbackExceptionGroup with all exceptions.
+        """
+        exceptions: list[Exception] = []
+
+        for model in self.models:
+            try:
+                return await model.request(messages, model_settings, model_request_parameters)
+            except Exception as exc:
+                if self._fallback_on(exc):
+                    exceptions.append(exc)
+                    continue
+                raise exc
+
+        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        """Try each model in sequence until one succeeds."""
+        exceptions: list[Exception] = []
+
+        for model in self.models:
+            async with AsyncExitStack() as stack:
+                try:
+                    response = await stack.enter_async_context(
+                        model.request_stream(messages, model_settings, model_request_parameters)
+                    )
+                except Exception as exc:
+                    if self._fallback_on(exc):
+                        exceptions.append(exc)
+                        continue
+                    raise exc
+                yield response
+                return
+
+        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider, n/a for fallback models."""
+        return None
+
+
+def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
+    """Create a default fallback condition for the given exceptions."""
+
+    def fallback_condition(exception: Exception) -> bool:
+        return isinstance(exception, exceptions)
+
+    return fallback_condition
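For orientation, a minimal usage sketch of the new `FallbackModel`; the model names are illustrative, and it is assumed the `Agent` API otherwise behaves as in 0.0.25:

from pydantic_ai import Agent
from pydantic_ai.models.fallback import FallbackModel

# Use OpenAI by default; on ModelHTTPError (the default fallback condition) retry with Anthropic.
model = FallbackModel('openai:gpt-4o', 'anthropic:claude-3-5-sonnet-latest')
agent = Agent(model)

result = agent.run_sync('What is the capital of France?')
print(result.data)

Passing a callable as `fallback_on` (for example `fallback_on=lambda exc: True`) widens the condition beyond the default `(ModelHTTPError,)` tuple; if every model fails, a `FallbackExceptionGroup` carrying all collected exceptions is raised.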
pydantic_ai/models/function.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 import inspect
 import re
-from collections.abc import AsyncIterator, Awaitable, Iterable
+from collections.abc import AsyncIterator, Awaitable, Iterable, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -14,6 +14,9 @@ from typing_extensions import TypeAlias, assert_never, overload
 from .. import _utils, usage
 from .._utils import PeekableAsyncStream
 from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -23,6 +26,7 @@ from ..messages import (
     TextPart,
     ToolCallPart,
     ToolReturnPart,
+    UserContent,
     UserPromptPart,
 )
 from ..settings import ModelSettings
@@ -44,15 +48,23 @@ class FunctionModel(Model):
     _system: str | None = field(default=None, repr=False)
 
     @overload
-    def __init__(self, function: FunctionDef) -> None: ...
+    def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ...
 
     @overload
-    def __init__(self, *, stream_function: StreamFunctionDef) -> None: ...
+    def __init__(self, *, stream_function: StreamFunctionDef, model_name: str | None = None) -> None: ...
 
     @overload
-    def __init__(
+    def __init__(
+        self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None
+    ) -> None: ...
 
-    def __init__(
+    def __init__(
+        self,
+        function: FunctionDef | None = None,
+        *,
+        stream_function: StreamFunctionDef | None = None,
+        model_name: str | None = None,
+    ):
         """Initialize a `FunctionModel`.
 
         Either `function` or `stream_function` must be provided, providing both is allowed.
@@ -60,6 +72,7 @@ class FunctionModel(Model):
         Args:
            function: The function to call for non-streamed requests.
            stream_function: The function to call for streamed requests.
+            model_name: The name of the model. If not provided, a name is generated from the function names.
        """
        if function is None and stream_function is None:
            raise TypeError('Either `function` or `stream_function` must be provided')
@@ -68,7 +81,7 @@ class FunctionModel(Model):
 
        function_name = self.function.__name__ if self.function is not None else ''
        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
-        self._model_name = f'function:{function_name}:{stream_function_name}'
+        self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
 
    async def request(
        self,
@@ -91,7 +104,7 @@ class FunctionModel(Model):
        response_ = await _utils.run_in_executor(self.function, messages, agent_info)
        assert isinstance(response_, ModelResponse), response_
        response = response_
-        response.model_name =
+        response.model_name = self._model_name
        # TODO is `messages` right here? Should it just be new messages?
        return response, _estimate_usage(chain(messages, [response]))
 
@@ -119,7 +132,7 @@ class FunctionModel(Model):
        if isinstance(first, _utils.Unset):
            raise ValueError('Stream function must return at least one item')
 
-        yield FunctionStreamedResponse(_model_name=
+        yield FunctionStreamedResponse(_model_name=self._model_name, _iter=response_stream)
 
    @property
    def model_name(self) -> str:
@@ -262,7 +275,12 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
    )
 
 
-def _estimate_string_tokens(content: str) -> int:
+def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
    if not content:
        return 0
-
+    if isinstance(content, str):
+        return len(re.split(r'[\s",.:]+', content.strip()))
+    # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
+    else:  # pragma: no cover
+        assert isinstance(content, (AudioUrl, ImageUrl, BinaryContent))
+        return 0
pydantic_ai/models/gemini.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import base64
 import os
 import re
 from collections.abc import AsyncIterator, Sequence
@@ -14,8 +15,11 @@ import pydantic
 from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, assert_never
 
-from .. import UnexpectedModelBehavior,
+from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage
 from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -108,7 +112,7 @@ class GeminiModel(Model):
         if env_api_key := os.getenv('GEMINI_API_KEY'):
             api_key = env_api_key
         else:
-            raise
+            raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
         self.http_client = http_client or cached_async_http_client()
         self._auth = ApiKeyAuth(api_key)
         self._url = url_template.format(model=model_name)
@@ -185,7 +189,7 @@ class GeminiModel(Model):
     ) -> AsyncIterator[HTTPResponse]:
         tools = self._get_tools(model_request_parameters)
         tool_config = self._get_tool_config(model_request_parameters, tools)
-        sys_prompt_parts, contents = self._message_to_gemini_content(messages)
+        sys_prompt_parts, contents = await self._message_to_gemini_content(messages)
 
         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
@@ -229,9 +233,11 @@ class GeminiModel(Model):
             headers=headers,
             timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
         ) as r:
-            if r.status_code != 200:
+            if (status_code := r.status_code) != 200:
                 await r.aread()
-
+                if status_code >= 400:
+                    raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text)
+                raise UnexpectedModelBehavior(f'Unexpected response from gemini {status_code}', r.text)
             yield r
 
     def _process_response(self, response: _GeminiResponse) -> ModelResponse:
@@ -269,7 +275,7 @@ class GeminiModel(Model):
         return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
 
     @classmethod
-    def _message_to_gemini_content(
+    async def _message_to_gemini_content(
         cls, messages: list[ModelMessage]
     ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
         sys_prompt_parts: list[_GeminiTextPart] = []
@@ -282,7 +288,7 @@ class GeminiModel(Model):
                 if isinstance(part, SystemPromptPart):
                     sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                 elif isinstance(part, UserPromptPart):
-                    message_parts.
+                    message_parts.extend(await cls._map_user_prompt(part))
                 elif isinstance(part, ToolReturnPart):
                     message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                 elif isinstance(part, RetryPromptPart):
@@ -303,6 +309,40 @@ class GeminiModel(Model):
 
         return sys_prompt_parts, contents
 
+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
+        if isinstance(part.content, str):
+            return [{'text': part.content}]
+        else:
+            content: list[_GeminiPartUnion] = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append({'text': item})
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    content.append(
+                        _GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type})
+                    )
+                elif isinstance(item, (AudioUrl, ImageUrl)):
+                    try:
+                        content.append(
+                            _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
+                        )
+                    except ValueError:
+                        # Download the file if can't find the mime type.
+                        client = cached_async_http_client()
+                        response = await client.get(item.url, follow_redirects=True)
+                        response.raise_for_status()
+                        base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                        content.append(
+                            _GeminiInlineDataPart(
+                                inline_data={'data': base64_encoded, 'mime_type': response.headers['Content-Type']}
+                            )
+                        )
+                else:
+                    assert_never(item)
+            return content
+
 
 class AuthProtocol(Protocol):
     """Abstract definition for Gemini authentication."""
@@ -494,6 +534,28 @@ class _GeminiTextPart(TypedDict):
     text: str
 
 
+class _GeminiInlineData(TypedDict):
+    data: str
+    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+
+
+class _GeminiInlineDataPart(TypedDict):
+    """See <https://ai.google.dev/api/caching#Blob>."""
+
+    inline_data: Annotated[_GeminiInlineData, pydantic.Field(alias='inlineData')]
+
+
+class _GeminiFileData(TypedDict):
+    """See <https://ai.google.dev/api/caching#FileData>."""
+
+    file_uri: Annotated[str, pydantic.Field(alias='fileUri')]
+    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+
+
+class _GeminiFileDataPart(TypedDict):
+    file_data: Annotated[_GeminiFileData, pydantic.Field(alias='fileData')]
+
+
 class _GeminiFunctionCallPart(TypedDict):
     function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
 
@@ -517,7 +579,7 @@ def _process_response_from_parts(
                 )
             )
         elif 'function_response' in part:
-            raise
+            raise UnexpectedModelBehavior(
                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
            )
    return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())
@@ -549,6 +611,10 @@ def _part_discriminator(v: Any) -> str:
    if isinstance(v, dict):
        if 'text' in v:
            return 'text'
+        elif 'inlineData' in v:
+            return 'inline_data'
+        elif 'fileData' in v:
+            return 'file_data'
        elif 'functionCall' in v or 'function_call' in v:
            return 'function_call'
        elif 'functionResponse' in v or 'function_response' in v:
@@ -564,6 +630,8 @@ _GeminiPartUnion = Annotated[
        Annotated[_GeminiTextPart, pydantic.Tag('text')],
        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
+        Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
+        Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')],
    ],
    pydantic.Discriminator(_part_discriminator),
 ]
@@ -726,7 +794,7 @@ class _GeminiJsonSchema:
        # noinspection PyTypeChecker
        key = re.sub(r'^#/\$defs/', '', ref)
        if key in refs_stack:
-            raise
+            raise UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
        refs_stack += (key,)
        schema_def = self.defs[key]
        self._simplify(schema_def, refs_stack)
@@ -760,7 +828,7 @@ class _GeminiJsonSchema:
    def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
        ad_props = schema.pop('additionalProperties', None)
        if ad_props:
-            raise
+            raise UserError('Additional properties in JSON Schema are not supported by Gemini')
 
        if properties := schema.get('properties'):  # pragma: no branch
            for value in properties.values():
pydantic_ai/models/groq.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import base64
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
@@ -10,9 +11,11 @@ from typing import Literal, Union, cast, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -36,9 +39,9 @@ from . import (
 )
 
 try:
-    from groq import NOT_GIVEN, AsyncGroq, AsyncStream
+    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
-    from groq.types.chat import
+    from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
@@ -163,7 +166,7 @@ class GroqModel(Model):
         stream: Literal[True],
         model_settings: GroqModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> AsyncStream[ChatCompletionChunk]:
+    ) -> AsyncStream[chat.ChatCompletionChunk]:
         pass
 
     @overload
@@ -182,7 +185,7 @@ class GroqModel(Model):
         stream: bool,
         model_settings: GroqModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+    ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]:
         tools = self._get_tools(model_request_parameters)
         # standalone function to make it easier to override
         if not tools:
@@ -194,23 +197,28 @@ class GroqModel(Model):
 
         groq_messages = list(chain(*(self._map_message(m) for m in messages)))
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            return await self.client.chat.completions.create(
+                model=str(self._model_name),
+                messages=groq_messages,
+                n=1,
+                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
+                tools=tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                stream=stream,
+                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                seed=model_settings.get('seed', NOT_GIVEN),
+                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+            )
+        except APIStatusError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
@@ -224,7 +232,7 @@ class GroqModel(Model):
             items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(items, model_name=response.model, timestamp=timestamp)
 
-    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -293,7 +301,7 @@ class GroqModel(Model):
            if isinstance(part, SystemPromptPart):
                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
-                yield
+                yield cls._map_user_prompt(part)
            elif isinstance(part, ToolReturnPart):
                yield chat.ChatCompletionToolMessageParam(
                    role='tool',
@@ -310,13 +318,37 @@ class GroqModel(Model):
                    content=part.model_response(),
                )
 
+    @staticmethod
+    def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
+        content: str | list[chat.ChatCompletionContentPartParam]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(chat.ChatCompletionContentPartTextParam(text=item, type='text'))
+                elif isinstance(item, ImageUrl):
+                    image_url = ImageURL(url=item.url)
+                    content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                        content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                    else:
+                        raise RuntimeError('Only images are supported for binary content in Groq.')
+                else:  # pragma: no cover
+                    raise RuntimeError(f'Unsupported content type: {type(item)}')
+        return chat.ChatCompletionUserMessageParam(role='user', content=content)
+
 
 @dataclass
 class GroqStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Groq models."""
 
     _model_name: GroqModelName
-    _response: AsyncIterable[ChatCompletionChunk]
+    _response: AsyncIterable[chat.ChatCompletionChunk]
     _timestamp: datetime
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
@@ -355,9 +387,9 @@ class GroqStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
+def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
     response_usage = None
-    if isinstance(completion, ChatCompletion):
+    if isinstance(completion, chat.ChatCompletion):
         response_usage = completion.usage
     elif completion.x_groq is not None:
         response_usage = completion.x_groq.usage
pydantic_ai/models/mistral.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import base64
 import os
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
@@ -12,9 +13,11 @@ import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
 from .._utils import now_utc as _now_utc
 from ..messages import (
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -45,6 +48,8 @@ try:
         Content as MistralContent,
         ContentChunk as MistralContentChunk,
         FunctionCall as MistralFunctionCall,
+        ImageURL as MistralImageURL,
+        ImageURLChunk as MistralImageURLChunk,
         Mistral,
         OptionalNullable as MistralOptionalNullable,
         TextChunk as MistralTextChunk,
@@ -54,6 +59,7 @@ try:
         ChatCompletionResponse as MistralChatCompletionResponse,
         CompletionEvent as MistralCompletionEvent,
         Messages as MistralMessages,
+        SDKError,
         Tool as MistralTool,
         ToolCall as MistralToolCall,
     )
@@ -179,19 +185,25 @@ class MistralModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            response = await self.client.chat.complete_async(
+                model=str(self._model_name),
+                messages=list(chain(*(self._map_message(m) for m in messages))),
+                n=1,
+                tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+                tool_choice=self._get_tool_choice(model_request_parameters),
+                stream=False,
+                max_tokens=model_settings.get('max_tokens', UNSET),
+                temperature=model_settings.get('temperature', UNSET),
+                top_p=model_settings.get('top_p', 1),
+                timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+                random_seed=model_settings.get('seed', UNSET),
+            )
+        except SDKError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
+
 
         assert response, 'A unexpected empty response from Mistral.'
         return response
 
@@ -423,7 +435,7 @@ class MistralModel(Model):
            if isinstance(part, SystemPromptPart):
                yield MistralSystemMessage(content=part.content)
            elif isinstance(part, UserPromptPart):
-                yield
+                yield cls._map_user_prompt(part)
            elif isinstance(part, ToolReturnPart):
                yield MistralToolMessage(
                    tool_call_id=part.tool_call_id,
@@ -460,6 +472,29 @@ class MistralModel(Model):
        else:
            assert_never(message)
 
+    @staticmethod
+    def _map_user_prompt(part: UserPromptPart) -> MistralUserMessage:
+        content: str | list[MistralContentChunk]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(MistralTextChunk(text=item))
+                elif isinstance(item, ImageUrl):
+                    content.append(MistralImageURLChunk(image_url=MistralImageURL(url=item.url)))
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        image_url = MistralImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                        content.append(MistralImageURLChunk(image_url=image_url, type='image_url'))
+                    else:
+                        raise RuntimeError('Only image binary content is supported for Mistral.')
+                else:  # pragma: no cover
+                    raise RuntimeError(f'Unsupported content type: {type(item)}')
+        return MistralUserMessage(content=content)
+
 
 MistralToolCallId = Union[str, None]
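Mistral user prompts follow the same pattern: `ImageUrl` maps to a `MistralImageURLChunk`, binary images to a base64 data URL, and anything else raises. A sketch; the model id is illustrative and MISTRAL_API_KEY is assumed to be set:

from pydantic_ai import Agent
from pydantic_ai.messages import ImageUrl
from pydantic_ai.models.mistral import MistralModel

agent = Agent(MistralModel('pixtral-12b-latest'))  # illustrative model id
result = agent.run_sync(['What is in this image?', ImageUrl(url='https://example.com/photo.png')])
print(result.data)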