pydantic-ai-slim 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +16 -7
- pydantic_ai/_cli.py +11 -12
- pydantic_ai/_output.py +7 -7
- pydantic_ai/_parts_manager.py +1 -1
- pydantic_ai/agent.py +30 -18
- pydantic_ai/direct.py +2 -0
- pydantic_ai/exceptions.py +2 -2
- pydantic_ai/messages.py +29 -11
- pydantic_ai/models/__init__.py +43 -6
- pydantic_ai/models/anthropic.py +17 -12
- pydantic_ai/models/bedrock.py +10 -9
- pydantic_ai/models/cohere.py +4 -4
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +1 -1
- pydantic_ai/models/gemini.py +26 -22
- pydantic_ai/models/google.py +569 -0
- pydantic_ai/models/groq.py +12 -6
- pydantic_ai/models/instrumented.py +43 -33
- pydantic_ai/models/mistral.py +15 -9
- pydantic_ai/models/openai.py +46 -8
- pydantic_ai/models/test.py +1 -1
- pydantic_ai/models/wrapper.py +1 -1
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/google.py +143 -0
- pydantic_ai/providers/google_vertex.py +3 -3
- pydantic_ai/providers/openrouter.py +69 -0
- pydantic_ai/result.py +13 -21
- pydantic_ai/tools.py +34 -2
- pydantic_ai/usage.py +1 -1
- {pydantic_ai_slim-0.2.4.dist-info → pydantic_ai_slim-0.2.6.dist-info}/METADATA +7 -4
- pydantic_ai_slim-0.2.6.dist-info/RECORD +59 -0
- pydantic_ai_slim-0.2.6.dist-info/licenses/LICENSE +21 -0
- pydantic_ai_slim-0.2.4.dist-info/RECORD +0 -55
- {pydantic_ai_slim-0.2.4.dist-info → pydantic_ai_slim-0.2.6.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.2.4.dist-info → pydantic_ai_slim-0.2.6.dist-info}/entry_points.txt +0 -0
pydantic_ai/models/google.py
ADDED
@@ -0,0 +1,569 @@
+from __future__ import annotations as _annotations
+
+import base64
+import warnings
+from collections.abc import AsyncIterator, Awaitable
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field, replace
+from datetime import datetime
+from typing import Literal, Union, cast, overload
+from uuid import uuid4
+
+from typing_extensions import assert_never
+
+from pydantic_ai.providers import Provider
+
+from .. import UnexpectedModelBehavior, UserError, _utils, usage
+from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    DocumentUrl,
+    ImageUrl,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    ModelResponseStreamEvent,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+    VideoUrl,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    cached_async_http_client,
+    check_allow_model_requests,
+    get_user_agent,
+)
+from ._json_schema import JsonSchema, WalkJsonSchema
+
+try:
+    from google import genai
+    from google.genai.types import (
+        ContentDict,
+        ContentUnionDict,
+        FunctionCallDict,
+        FunctionCallingConfigDict,
+        FunctionCallingConfigMode,
+        FunctionDeclarationDict,
+        GenerateContentConfigDict,
+        GenerateContentResponse,
+        Part,
+        PartDict,
+        SafetySettingDict,
+        ThinkingConfigDict,
+        ToolConfigDict,
+        ToolDict,
+        ToolListUnionDict,
+    )
+
+    from ..providers.google import GoogleProvider
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `google-genai` to use the Google model, '
+        'you can use the `google` optional group — `pip install "pydantic-ai-slim[google]"`'
+    ) from _import_error
+
+LatestGoogleModelNames = Literal[
+    'gemini-1.5-flash',
+    'gemini-1.5-flash-8b',
+    'gemini-1.5-pro',
+    'gemini-1.0-pro',
+    'gemini-2.0-flash-exp',
+    'gemini-2.0-flash-thinking-exp-01-21',
+    'gemini-exp-1206',
+    'gemini-2.0-flash',
+    'gemini-2.0-flash-lite-preview-02-05',
+    'gemini-2.0-pro-exp-02-05',
+    'gemini-2.5-flash-preview-04-17',
+    'gemini-2.5-pro-exp-03-25',
+    'gemini-2.5-pro-preview-03-25',
+]
+"""Latest Gemini models."""
+
+GoogleModelName = Union[str, LatestGoogleModelNames]
+"""Possible Gemini model names.
+
+Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
+See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
+"""
+
+
+class GoogleModelSettings(ModelSettings, total=False):
+    """Settings used for a Gemini model request.
+
+    ALL FIELDS MUST BE `google_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    """
+
+    google_safety_settings: list[SafetySettingDict]
+    """The safety settings to use for the model.
+
+    See <https://ai.google.dev/gemini-api/docs/safety-settings> for more information.
+    """
+
+    google_thinking_config: ThinkingConfigDict
+    """The thinking configuration to use for the model.
+
+    See <https://ai.google.dev/gemini-api/docs/thinking> for more information.
+    """
+
+
+@dataclass(init=False)
+class GoogleModel(Model):
+    """A model that uses Gemini via `generativelanguage.googleapis.com` API.
+
+    This is implemented on top of the official `google-genai` SDK; good API documentation is
+    available [here](https://ai.google.dev/api).
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    client: genai.Client = field(repr=False)
+
+    _model_name: GoogleModelName = field(repr=False)
+    _provider: Provider[genai.Client] = field(repr=False)
+    _url: str | None = field(repr=False)
+    _system: str = field(default='google', repr=False)
+
+    def __init__(
+        self,
+        model_name: GoogleModelName,
+        *,
+        provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla',
+    ):
+        """Initialize a Gemini model.
+
+        Args:
+            model_name: The name of the model to use.
+            provider: The provider to use for authentication and API access. Can be either the string
+                'google-gla' or 'google-vertex' or an instance of `Provider[genai.Client]`.
+                Defaults to 'google-gla'.
+        """
+        self._model_name = model_name
+
+        if isinstance(provider, str):
+            provider = GoogleProvider(vertexai=provider == 'google-vertex')  # pragma: lax no cover
+
+        self._provider = provider
+        self._system = provider.name
+        self.client = provider.client
+
+    @property
+    def base_url(self) -> str:
+        return self._provider.base_url
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        check_allow_model_requests()
+        model_settings = cast(GoogleModelSettings, model_settings or {})
+        response = await self._generate_content(messages, False, model_settings, model_request_parameters)
+        return self._process_response(response)
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        check_allow_model_requests()
+        model_settings = cast(GoogleModelSettings, model_settings or {})
+        response = await self._generate_content(messages, True, model_settings, model_request_parameters)
+        yield await self._process_streamed_response(response)  # type: ignore
+
+    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
+        def _customize_tool_def(t: ToolDefinition):
+            return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).walk())
+
+        return ModelRequestParameters(
+            function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
+            allow_text_output=model_request_parameters.allow_text_output,
+            output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
+        )
+
+    @property
+    def model_name(self) -> GoogleModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return self._system
+
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
+        tools: list[ToolDict] = [
+            ToolDict(function_declarations=[_function_declaration_from_tool(t)])
+            for t in model_request_parameters.function_tools
+        ]
+        if model_request_parameters.output_tools:
+            tools += [
+                ToolDict(function_declarations=[_function_declaration_from_tool(t)])
+                for t in model_request_parameters.output_tools
+            ]
+        return tools or None
+
+    def _get_tool_config(
+        self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None
+    ) -> ToolConfigDict | None:
+        if model_request_parameters.allow_text_output:
+            return None
+        elif tools:
+            names: list[str] = []
+            for tool in tools:
+                for function_declaration in tool.get('function_declarations') or []:
+                    if name := function_declaration.get('name'):  # pragma: no branch
+                        names.append(name)
+            return _tool_config(names)
+        else:
+            return _tool_config([])  # pragma: no cover
+
+    @overload
+    async def _generate_content(
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: GoogleModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> GenerateContentResponse: ...
+
+    @overload
+    async def _generate_content(
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: GoogleModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> Awaitable[AsyncIterator[GenerateContentResponse]]: ...
+
+    async def _generate_content(
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: GoogleModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]:
+        tools = self._get_tools(model_request_parameters)
+        tool_config = self._get_tool_config(model_request_parameters, tools)
+        system_instruction, contents = await self._map_messages(messages)
+
+        config = GenerateContentConfigDict(
+            http_options={'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}},
+            system_instruction=system_instruction,
+            temperature=model_settings.get('temperature'),
+            top_p=model_settings.get('top_p'),
+            max_output_tokens=model_settings.get('max_tokens'),
+            presence_penalty=model_settings.get('presence_penalty'),
+            frequency_penalty=model_settings.get('frequency_penalty'),
+            safety_settings=model_settings.get('google_safety_settings'),
+            thinking_config=model_settings.get('google_thinking_config'),
+            tools=cast(ToolListUnionDict, tools),
+            tool_config=tool_config,
+        )
+
+        func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
+        return await func(model=self._model_name, contents=contents, config=config)  # type: ignore
+
+    def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
+        if not response.candidates or len(response.candidates) != 1:
+            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')  # pragma: no cover
+        if response.candidates[0].content is None or response.candidates[0].content.parts is None:
+            if response.candidates[0].finish_reason == 'SAFETY':
+                raise UnexpectedModelBehavior('Safety settings triggered', str(response))
+            else:
+                raise UnexpectedModelBehavior(
+                    'Content field missing from Gemini response', str(response)
+                )  # pragma: no cover
+        parts = response.candidates[0].content.parts or []
+        usage = _metadata_as_usage(response)
+        usage.requests = 1
+        return _process_response_from_parts(parts, response.model_version or self._model_name, usage)
+
+    async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
+        """Process a streamed response, and prepare a streaming response to return."""
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')  # pragma: no cover
+
+        return GeminiStreamedResponse(
+            _model_name=self._model_name,
+            _response=peekable_response,
+            _timestamp=first_chunk.create_time or _utils.now_utc(),
+        )
+
+    async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
+        contents: list[ContentUnionDict] = []
+        system_parts: list[PartDict] = []
+
+        for m in messages:
+            if isinstance(m, ModelRequest):
+                message_parts: list[PartDict] = []
+
+                for part in m.parts:
+                    if isinstance(part, SystemPromptPart):
+                        system_parts.append({'text': part.content})
+                    elif isinstance(part, UserPromptPart):
+                        message_parts.extend(await self._map_user_prompt(part))
+                    elif isinstance(part, ToolReturnPart):
+                        message_parts.append(
+                            {
+                                'function_response': {
+                                    'name': part.tool_name,
+                                    'response': part.model_response_object(),
+                                    'id': part.tool_call_id,
+                                }
+                            }
+                        )
+                    elif isinstance(part, RetryPromptPart):
+                        if part.tool_name is None:
+                            message_parts.append({'text': part.model_response()})  # pragma: no cover
+                        else:
+                            message_parts.append(
+                                {
+                                    'function_response': {
+                                        'name': part.tool_name,
+                                        'response': {'call_error': part.model_response()},
+                                        'id': part.tool_call_id,
+                                    }
+                                }
+                            )
+                    else:
+                        assert_never(part)
+
+                if message_parts:  # pragma: no branch
+                    contents.append({'role': 'user', 'parts': message_parts})
+            elif isinstance(m, ModelResponse):
+                contents.append(_content_model_response(m))
+            else:
+                assert_never(m)
+        if instructions := self._get_instructions(messages):
+            system_parts.insert(0, {'text': instructions})
+        system_instruction = ContentDict(role='user', parts=system_parts) if system_parts else None
+        return system_instruction, contents
+
+    async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]:
+        if isinstance(part.content, str):
+            return [{'text': part.content}]
+        else:
+            content: list[PartDict] = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append({'text': item})
+                elif isinstance(item, BinaryContent):
+                    # NOTE: The type from Google GenAI is incorrect, it should be `str`, not `bytes`.
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    content.append({'inline_data': {'data': base64_encoded, 'mime_type': item.media_type}})  # type: ignore
+                elif isinstance(item, (AudioUrl, ImageUrl, DocumentUrl, VideoUrl)):
+                    client = cached_async_http_client()
+                    response = await client.get(item.url, follow_redirects=True)
+                    response.raise_for_status()
+                    # NOTE: The type from Google GenAI is incorrect, it should be `str`, not `bytes`.
+                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                    content.append({'inline_data': {'data': base64_encoded, 'mime_type': item.media_type}})  # type: ignore
+                else:
+                    assert_never(item)
+            return content
+
+
+@dataclass
+class GeminiStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for the Gemini model."""
+
+    _model_name: GoogleModelName
+    _response: AsyncIterator[GenerateContentResponse]
+    _timestamp: datetime
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for chunk in self._response:
+            self._usage += _metadata_as_usage(chunk)
+
+            assert chunk.candidates is not None
+            candidate = chunk.candidates[0]
+            if candidate.content is None:
+                raise UnexpectedModelBehavior('Streamed response has no content field')  # pragma: no cover
+            assert candidate.content.parts is not None
+            for part in candidate.content.parts:
+                if part.text:
+                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
+                elif part.function_call:
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=uuid4(),
+                        tool_name=part.function_call.name,
+                        args=part.function_call.args,
+                        tool_call_id=part.function_call.id,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+                else:
+                    assert part.function_response is not None, f'Unexpected part: {part}'  # pragma: no cover
+
+    @property
+    def model_name(self) -> GoogleModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
+    def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
+        return self._timestamp
+
+
+def _content_model_response(m: ModelResponse) -> ContentDict:
+    parts: list[PartDict] = []
+    for item in m.parts:
+        if isinstance(item, ToolCallPart):
+            function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id)
+            parts.append({'function_call': function_call})
+        elif isinstance(item, TextPart):
+            if item.content:  # pragma: no branch
+                parts.append({'text': item.content})
+        else:
+            assert_never(item)
+    return ContentDict(role='model', parts=parts)
+
+
+def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName, usage: usage.Usage) -> ModelResponse:
+    items: list[ModelResponsePart] = []
+    for part in parts:
+        if part.text:
+            items.append(TextPart(content=part.text))
+        elif part.function_call:
+            assert part.function_call.name is not None
+            tool_call_part = ToolCallPart(tool_name=part.function_call.name, args=part.function_call.args or {})
+            if part.function_call.id is not None:
+                tool_call_part.tool_call_id = part.function_call.id  # pragma: no cover
+            items.append(tool_call_part)
+        elif part.function_response:  # pragma: no cover
+            raise UnexpectedModelBehavior(
+                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
+            )
+    return ModelResponse(parts=items, model_name=model_name, usage=usage)
+
+
+def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:
+    json_schema = tool.parameters_json_schema
+    f = FunctionDeclarationDict(name=tool.name, description=tool.description)
+    if json_schema.get('properties'):  # pragma: no branch
+        f['parameters'] = json_schema  # type: ignore
+    return f
+
+
+def _tool_config(function_names: list[str]) -> ToolConfigDict:
+    mode = FunctionCallingConfigMode.ANY
+    function_calling_config = FunctionCallingConfigDict(mode=mode, allowed_function_names=function_names)
+    return ToolConfigDict(function_calling_config=function_calling_config)
+
+
+def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
+    metadata = response.usage_metadata
+    if metadata is None:
+        return usage.Usage()  # pragma: no cover
+    # TODO(Marcelo): We exclude the `prompt_tokens_details` and `candidate_token_details` fields because on
+    # `usage.Usage.incr`, it will try to sum non-integer values with integers, which will fail. We should probably
+    # handle this in the `Usage` class.
+    details = metadata.model_dump(
+        exclude={'prompt_tokens_details', 'candidates_tokens_details', 'traffic_type'},
+        exclude_defaults=True,
+    )
+    return usage.Usage(
+        request_tokens=details.pop('prompt_token_count', 0),
+        response_tokens=details.pop('candidates_token_count', 0),
+        total_tokens=details.pop('total_token_count', 0),
+        details=details,
+    )
+
+
+class _GeminiJsonSchema(WalkJsonSchema):
+    """Transforms the JSON Schema from Pydantic to be suitable for Gemini.
+
+    Gemini [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
+    a subset of OpenAPI v3.0.3.
+
+    Specifically:
+    * gemini doesn't allow the `title` keyword to be set
+    * gemini doesn't allow `$defs` — we need to inline the definitions where possible
+    """
+
+    def __init__(self, schema: JsonSchema):
+        super().__init__(schema, prefer_inlined_defs=True, simplify_nullable_unions=True)
+
+    def transform(self, schema: JsonSchema) -> JsonSchema:
+        # Note: we need to remove `additionalProperties: False` since it is currently mishandled by Gemini
+        additional_properties = schema.pop(
+            'additionalProperties', None
+        )  # popped here, then re-added to `original_schema` below so the warning can include it
+        if additional_properties:  # pragma: no cover
+            original_schema = {**schema, 'additionalProperties': additional_properties}
+            warnings.warn(
+                '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
+                f' Full schema: {self.schema}\n\n'
+                f'Source of additionalProperties within the full schema: {original_schema}\n\n'
+                'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
+                "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
+                ' and we will fix this behavior.',
+                UserWarning,
+            )
+
+        schema.pop('title', None)
+        schema.pop('default', None)
+        schema.pop('$schema', None)
+        if (const := schema.pop('const', None)) is not None:  # pragma: no cover
+            # Gemini doesn't support const, but it does support enum with a single value
+            schema['enum'] = [const]
+        schema.pop('discriminator', None)
+        schema.pop('examples', None)
+
+        # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
+        # where we add notes about these properties to the field description?
+        schema.pop('exclusiveMaximum', None)
+        schema.pop('exclusiveMinimum', None)
+
+        type_ = schema.get('type')
+        if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
+            # This gets hit when we have a discriminated union
+            # Gemini returns an API error in this case even though it says in its error message it shouldn't...
+            # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
+            schema['anyOf'] = schema.pop('oneOf')
+
+        if type_ == 'string' and (fmt := schema.pop('format', None)):
+            description = schema.get('description')
+            if description:
+                schema['description'] = f'{description} (format: {fmt})'
+            else:
+                schema['description'] = f'Format: {fmt}'
+
+        if '$ref' in schema:
+            raise UserError(  # pragma: no cover
+                f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}'
+            )
+
+        if 'prefixItems' in schema:  # pragma: lax no cover
+            # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
+            prefix_items = schema.pop('prefixItems')
+            items = schema.get('items')
+            unique_items = [items] if items is not None else []
+            for item in prefix_items:
+                if item not in unique_items:
+                    unique_items.append(item)
+            if len(unique_items) > 1:  # pragma: no cover
+                schema['items'] = {'anyOf': unique_items}
+            elif len(unique_items) == 1:
+                schema['items'] = unique_items[0]
+            schema.setdefault('minItems', len(prefix_items))
+            if items is None:
+                schema.setdefault('maxItems', len(prefix_items))
+
+        return schema
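The new `pydantic_ai/models/google.py` module above, together with `pydantic_ai/providers/google.py`, is the headline addition in this release. As orientation, here is a minimal usage sketch based on the `GoogleModel.__init__` signature and `GoogleModelSettings` fields shown in the diff; the model name, prompt, and `thinking_budget` value are illustrative assumptions, and the `Agent` calls follow pydantic-ai's standard API rather than anything specific to this file.

from pydantic_ai import Agent
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings

# 'google-gla' is the default provider; this assumes Gemini API credentials
# are already configured for it. Model name is illustrative.
model = GoogleModel('gemini-2.0-flash', provider='google-gla')

# GoogleModelSettings is a TypedDict, so this simply builds a dict of the
# google_-prefixed settings defined above; thinking_budget is an assumption.
settings = GoogleModelSettings(google_thinking_config={'thinking_budget': 1024})

agent = Agent(model, model_settings=settings)
result = agent.run_sync('Tell me a joke.')
print(result.output)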
pydantic_ai/models/groq.py
CHANGED
@@ -227,7 +227,7 @@ class GroqModel(Model):
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise
+            raise  # pragma: lax no cover

     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
@@ -239,14 +239,18 @@ class GroqModel(Model):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(
+        return ModelResponse(
+            items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
+        )

     async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
         if isinstance(first_chunk, _utils.Unset):
-            raise UnexpectedModelBehavior(
+            raise UnexpectedModelBehavior(  # pragma: no cover
+                'Streamed response ended without content or tool calls'
+            )

         return GroqStreamedResponse(
             _response=peekable_response,
@@ -322,9 +326,11 @@ class GroqModel(Model):
                     tool_call_id=_guard_tool_call_id(t=part),
                     content=part.model_response_str(),
                 )
-            elif isinstance(part, RetryPromptPart):
+            elif isinstance(part, RetryPromptPart):  # pragma: no branch
                 if part.tool_name is None:
-                    yield chat.ChatCompletionUserMessageParam(
+                    yield chat.ChatCompletionUserMessageParam(  # pragma: no cover
+                        role='user', content=part.model_response()
+                    )
                 else:
                     yield chat.ChatCompletionToolMessageParam(
                         role='tool',
@@ -409,7 +415,7 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
     if isinstance(completion, chat.ChatCompletion):
         response_usage = completion.usage
     elif completion.x_groq is not None:
-        response_usage = completion.x_groq.usage
+        response_usage = completion.x_groq.usage  # pragma: no cover

     if response_usage is None:
         return usage.Usage()
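Most of the `groq.py` hunks only add coverage pragmas, but the first one sits on the error path where HTTP responses with status >= 400 are re-raised as `ModelHTTPError`. A minimal sketch of handling that from calling code, assuming Groq credentials are configured; the model name and prompt are illustrative:

from pydantic_ai import Agent
from pydantic_ai.exceptions import ModelHTTPError

agent = Agent('groq:llama-3.3-70b-versatile')  # illustrative model name
try:
    result = agent.run_sync('Hello!')
    print(result.output)
except ModelHTTPError as exc:
    # status_code, model_name and body are the fields set in the first hunk above
    print(f'Groq request failed with {exc.status_code} from {exc.model_name}')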