pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/__init__.py +12 -2
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +33 -18
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +366 -171
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +111 -50
- pydantic_ai/models/__init__.py +39 -14
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +62 -40
- pydantic_ai/models/gemini.py +164 -124
- pydantic_ai/models/groq.py +112 -94
- pydantic_ai/models/mistral.py +668 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +120 -96
- pydantic_ai/models/test.py +78 -61
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +96 -68
- pydantic_ai/settings.py +137 -0
- pydantic_ai/tools.py +46 -26
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.14.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
pydantic_ai/models/gemini.py
CHANGED
@@ -2,29 +2,32 @@ from __future__ import annotations as _annotations
 
 import os
 import re
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterator, Iterable, Sequence
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union
 
+import pydantic
 import pydantic_core
-from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from pydantic import Discriminator, Field, Tag
+from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
 
-from .. import UnexpectedModelBehavior,
+from .. import UnexpectedModelBehavior, _utils, exceptions, result
 from ..messages import (
-    …
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -37,7 +40,9 @@ from . import (
     get_user_agent,
 )
 
-GeminiModelName = Literal[
+GeminiModelName = Literal[
+    'gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro', 'gemini-2.0-flash-exp'
+]
 """Named Gemini models.
 
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
@@ -164,26 +169,25 @@ class GeminiAgentModel(AgentModel):
         self.tool_config = tool_config
         self.url = url
 
-    async def request(
-        …
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
-            return self._process_response(response),
+            return self._process_response(response), _metadata_as_usage(response)
 
     @asynccontextmanager
-    async def request_stream(
-        …
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        async with self._make_request(messages, True, model_settings) as http_response:
             yield await self._process_streamed_response(http_response)
 
     @asynccontextmanager
-    async def _make_request(
-        …
-            either_content = self._message_to_gemini(m)
-            if left := either_content.left:
-                sys_prompt_parts.append(left.value)
-            else:
-                contents.append(either_content.right)
+    async def _make_request(
+        self, messages: list[ModelMessage], streamed: bool, model_settings: ModelSettings | None
+    ) -> AsyncIterator[HTTPResponse]:
+        sys_prompt_parts, contents = self._message_to_gemini_content(messages)
 
         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
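The reworked `request` and `request_stream` signatures take an optional `ModelSettings` alongside the message history, so per-call settings flow from the agent down into `_make_request`. A minimal sketch of how a caller might exercise this, assuming the `model_settings` keyword that the new `settings.py` module adds to the agent run methods in this release:

```python
from pydantic_ai import Agent

agent = Agent('gemini-1.5-flash')

# ModelSettings is a TypedDict, so a plain dict literal works; settings a
# given model doesn't understand are simply not forwarded to its API.
result = agent.run_sync(
    'What is the capital of France?',
    model_settings={'temperature': 0.0, 'max_tokens': 100},
)
print(result.data)
```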
@@ -193,6 +197,17 @@ class GeminiAgentModel(AgentModel):
         if self.tool_config is not None:
             request_data['tool_config'] = self.tool_config
 
+        generation_config: _GeminiGenerationConfig = {}
+        if model_settings:
+            if (max_tokens := model_settings.get('max_tokens')) is not None:
+                generation_config['max_output_tokens'] = max_tokens
+            if (temperature := model_settings.get('temperature')) is not None:
+                generation_config['temperature'] = temperature
+            if (top_p := model_settings.get('top_p')) is not None:
+                generation_config['top_p'] = top_p
+        if generation_config:
+            request_data['generation_config'] = generation_config
+
         url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
 
         headers = {
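Only settings the caller actually provided end up in `generation_config`, with `max_tokens` renamed to Gemini's `max_output_tokens`; an empty config is omitted from the request body entirely. The same pattern in isolation:

```python
def build_generation_config(settings: dict | None) -> dict:
    """Sketch of the mapping above: copy only the keys the caller set."""
    config: dict = {}
    settings = settings or {}
    if (max_tokens := settings.get('max_tokens')) is not None:
        config['max_output_tokens'] = max_tokens  # note the key rename
    if (temperature := settings.get('temperature')) is not None:
        config['temperature'] = temperature
    if (top_p := settings.get('top_p')) is not None:
        config['top_p'] = top_p
    return config


assert build_generation_config(None) == {}
assert build_generation_config({'max_tokens': 100}) == {'max_output_tokens': 100}
```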
@@ -203,19 +218,24 @@ class GeminiAgentModel(AgentModel):
 
         request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
 
-        async with self.http_client.stream(
+        async with self.http_client.stream(
+            'POST',
+            url,
+            content=request_json,
+            headers=headers,
+            timeout=(model_settings or {}).get('timeout', USE_CLIENT_DEFAULT),
+        ) as r:
             if r.status_code != 200:
                 await r.aread()
                 raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
             yield r
 
     @staticmethod
-    def _process_response(response: _GeminiResponse) ->
-        …
-        return ModelTextResponse(content=''.join(part['text'] for part in either.right))
+    def _process_response(response: _GeminiResponse) -> ModelResponse:
+        if len(response['candidates']) != 1:
+            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
+        parts = response['candidates'][0]['content']['parts']
+        return _process_response_from_parts(parts)
 
     @staticmethod
     async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse:
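`USE_CLIENT_DEFAULT` is httpx's sentinel for "fall back to the client's configured timeout"; passing `None` would instead disable the timeout altogether. A self-contained sketch of the same pattern (the 30-second default is illustrative):

```python
import httpx
from httpx import USE_CLIENT_DEFAULT


async def post_stream(url: str, body: bytes, timeout: float | None = None) -> int:
    async with httpx.AsyncClient(timeout=30) as client:
        # timeout=None means "no timeout" in httpx, so use the sentinel to
        # keep the client's 30s default when the caller doesn't set one.
        async with client.stream(
            'POST',
            url,
            content=body,
            timeout=USE_CLIENT_DEFAULT if timeout is None else timeout,
        ) as response:
            await response.aread()
            return response.status_code
```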
@@ -239,34 +259,37 @@ class GeminiAgentModel(AgentModel):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
+        # TODO: Update this once we rework stream responses to be more flexible
         if _extract_response_parts(start_response).is_left():
             return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes)
         else:
             return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)
 
-    @
-    def
-        …
+    @classmethod
+    def _message_to_gemini_content(
+        cls, messages: list[ModelMessage]
+    ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
+        sys_prompt_parts: list[_GeminiTextPart] = []
+        contents: list[_GeminiContent] = []
+        for m in messages:
+            if isinstance(m, ModelRequest):
+                for part in m.parts:
+                    if isinstance(part, SystemPromptPart):
+                        sys_prompt_parts.append(_GeminiTextPart(text=part.content))
+                    elif isinstance(part, UserPromptPart):
+                        contents.append(_content_user_prompt(part))
+                    elif isinstance(part, ToolReturnPart):
+                        contents.append(_content_tool_return(part))
+                    elif isinstance(part, RetryPromptPart):
+                        contents.append(_content_retry_prompt(part))
+                    else:
+                        assert_never(part)
+            elif isinstance(m, ModelResponse):
+                contents.append(_content_model_response(m))
+            else:
+                assert_never(m)
+
+        return sys_prompt_parts, contents
 
 
 @dataclass
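In the 0.0.14 message model a conversation is a list of `ModelRequest` and `ModelResponse` objects whose `parts` carry the role-specific content, and `_message_to_gemini_content` flattens that into Gemini's `contents` array while routing system prompts into a separate list. A hypothetical illustration of the output shape, using plain dicts in place of the private `_Gemini*` TypedDicts:

```python
# Roughly what a ModelRequest and a ModelResponse carry in their .parts.
request_parts = [
    {'part_kind': 'system-prompt', 'content': 'Be concise.'},
    {'part_kind': 'user-prompt', 'content': 'Hello!'},
]
response_parts = [{'part_kind': 'text', 'content': 'Hi.'}]

# System prompts are split out; everything else becomes a contents entry.
sys_prompt_parts = [
    {'text': p['content']} for p in request_parts if p['part_kind'] == 'system-prompt'
]
contents = [
    {'role': 'user', 'parts': [{'text': p['content']}]}
    for p in request_parts
    if p['part_kind'] == 'user-prompt'
] + [{'role': 'model', 'parts': [{'text': p['content']} for p in response_parts]}]

assert sys_prompt_parts == [{'text': 'Be concise.'}]
assert [c['role'] for c in contents] == ['user', 'model']
```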
@@ -277,7 +300,7 @@ class GeminiStreamTextResponse(StreamTextResponse):
     _stream: AsyncIterator[bytes]
     _position: int = 0
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    …
+    _usage: result.Usage = field(default_factory=result.Usage, init=False)
 
     async def __anext__(self) -> None:
         chunk = await self._stream.__anext__()
@@ -297,7 +320,7 @@ class GeminiStreamTextResponse(StreamTextResponse):
             new_items, experimental_allow_partial='trailing-strings'
         )
         for r in new_responses:
-            self.
+            self._usage += _metadata_as_usage(r)
             parts = r['candidates'][0]['content']['parts']
             if _all_text_parts(parts):
                 for part in parts:
@@ -307,8 +330,8 @@ class GeminiStreamTextResponse(StreamTextResponse):
                     'Streamed response with unexpected content, expected all parts to be text'
                 )
 
-    def
-        return self.
+    def usage(self) -> result.Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -321,14 +344,14 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
     _content: bytearray
     _stream: AsyncIterator[bytes]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    …
+    _usage: result.Usage = field(default_factory=result.Usage, init=False)
 
     async def __anext__(self) -> None:
         chunk = await self._stream.__anext__()
         self._content.extend(chunk)
 
-    def get(self, *, final: bool = False) ->
-        """Get the `
+    def get(self, *, final: bool = False) -> ModelResponse:
+        """Get the `ModelResponse` at this point.
 
         NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always
         reply with a single response, when returning a structured data.
@@ -340,23 +363,16 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
             self._content,
             experimental_allow_partial='off' if final else 'trailing-strings',
         )
-        combined_parts: list[
-        self.
+        combined_parts: list[_GeminiPartUnion] = []
+        self._usage = result.Usage()
         for r in responses:
-            self.
+            self._usage += _metadata_as_usage(r)
             candidate = r['candidates'][0]
-            …
-                combined_parts.extend(parts)
-            elif not candidate.get('finish_reason'):
-                # you can get an empty text part along with the finish_reason, so we ignore that case
-                raise UnexpectedModelBehavior(
-                    'Streamed response with unexpected content, expected all parts to be function calls'
-                )
-        return _structured_response_from_parts(combined_parts, timestamp=self._timestamp)
+            combined_parts.extend(candidate['content']['parts'])
+        return _process_response_from_parts(combined_parts, timestamp=self._timestamp)
 
-    def
-        return self.
+    def usage(self) -> result.Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
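Both stream classes now fold per-chunk `usageMetadata` into a running total with `+=`, which works because `result.Usage` supports addition. A simplified stand-in showing the pattern (the real `Usage` also carries a `details` dict and optional fields):

```python
from dataclasses import dataclass


@dataclass
class Usage:
    """Simplified stand-in for result.Usage."""

    request_tokens: int = 0
    response_tokens: int = 0
    total_tokens: int = 0

    def __add__(self, other: 'Usage') -> 'Usage':
        # Summing per-chunk usage yields the totals for the whole stream.
        return Usage(
            self.request_tokens + other.request_tokens,
            self.response_tokens + other.response_tokens,
            self.total_tokens + other.total_tokens,
        )


usage = Usage()
for chunk_usage in [Usage(10, 2, 12), Usage(0, 5, 5)]:
    usage += chunk_usage
assert usage.total_tokens == 17
```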
@@ -367,6 +383,7 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
 # TypeAdapters take care of validation and serialization
 
 
+@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
 class _GeminiRequest(TypedDict):
     """Schema for an API request to the Gemini API.
 
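`@pydantic.with_config(pydantic.ConfigDict(defer_build=True))` postpones building the TypedDict's validation core schema until first use, which trims import time for a module that defines this many request/response schemas. A minimal sketch (depending on the pydantic version, an adapter over an undecorated type needs the config passed to `TypeAdapter` itself, as `_gemini_streamed_response_ta` does below):

```python
import pydantic
from typing_extensions import TypedDict


@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class Point(TypedDict):
    x: int
    y: int


# The core schema is built lazily, on first validation rather than at import.
adapter = pydantic.TypeAdapter(Point)
assert adapter.validate_json('{"x": 1, "y": 2}') == {'x': 1, 'y': 2}
```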
@@ -382,32 +399,37 @@ class _GeminiRequest(TypedDict):
     Developer generated system instructions, see
     <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
     """
+    generation_config: NotRequired[_GeminiGenerationConfig]
 
 
-class
-    …
-    parts: list[_GeminiPartUnion]
+class _GeminiGenerationConfig(TypedDict, total=False):
+    """Schema for an API request to the Gemini API.
 
+    Note there are many additional fields available that have not been added yet.
 
-    …
+    See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
+    """
+
+    max_output_tokens: int
+    temperature: float
+    top_p: float
 
 
-…
+class _GeminiContent(TypedDict):
+    role: Literal['user', 'model']
+    parts: list[_GeminiPartUnion]
 
 
-def
-    …
-    return _GeminiContent(role='model', parts=parts)
+def _content_user_prompt(m: UserPromptPart) -> _GeminiContent:
+    return _GeminiContent(role='user', parts=[_GeminiTextPart(text=m.content)])
 
 
-def
+def _content_tool_return(m: ToolReturnPart) -> _GeminiContent:
     f_response = _response_part_from_response(m.tool_name, m.model_response_object())
     return _GeminiContent(role='user', parts=[f_response])
 
 
-def
+def _content_retry_prompt(m: RetryPromptPart) -> _GeminiContent:
     if m.tool_name is None:
         part = _GeminiTextPart(text=m.model_response())
     else:
@@ -416,26 +438,42 @@ def _content_function_retry(m: RetryPrompt) -> _GeminiContent:
     return _GeminiContent(role='user', parts=[part])
 
 
+def _content_model_response(m: ModelResponse) -> _GeminiContent:
+    parts: list[_GeminiPartUnion] = []
+    for item in m.parts:
+        if isinstance(item, ToolCallPart):
+            parts.append(_function_call_part_from_call(item))
+        elif isinstance(item, TextPart):
+            parts.append(_GeminiTextPart(text=item.content))
+        else:
+            assert_never(item)
+    return _GeminiContent(role='model', parts=parts)
+
+
 class _GeminiTextPart(TypedDict):
     text: str
 
 
 class _GeminiFunctionCallPart(TypedDict):
-    function_call: Annotated[_GeminiFunctionCall, Field(alias='functionCall')]
+    function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
 
 
-def _function_call_part_from_call(tool:
-    …
-    return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args.args_dict))
+def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
+    return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))
 
 
-def
-    …
+def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
+    items: list[ModelResponsePart] = []
+    for part in parts:
+        if 'text' in part:
+            items.append(TextPart(part['text']))
+        elif 'function_call' in part:
+            items.append(ToolCallPart.from_raw_args(part['function_call']['name'], part['function_call']['args']))
+        elif 'function_response' in part:
+            raise exceptions.UnexpectedModelBehavior(
+                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
+            )
+    return ModelResponse(items, timestamp=timestamp or _utils.now_utc())
 
 
 class _GeminiFunctionCall(TypedDict):
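Because the `_Gemini*Part` types are TypedDicts, which erase to plain dicts at runtime, `_process_response_from_parts` cannot dispatch with `isinstance` and discriminates by key membership instead. The pattern in isolation:

```python
from typing import TypedDict, Union


class TextPart(TypedDict):
    text: str


class FunctionCallPart(TypedDict):
    function_call: dict


Part = Union[TextPart, FunctionCallPart]


def describe(part: Part) -> str:
    # TypedDicts are plain dicts at runtime, so check keys, not types.
    if 'text' in part:
        return f'text: {part["text"]}'
    elif 'function_call' in part:
        return f'call: {part["function_call"]["name"]}'
    raise ValueError(f'unexpected part: {part!r}')


assert describe({'text': 'hi'}) == 'text: hi'
assert describe({'function_call': {'name': 'get_weather'}}) == 'call: get_weather'
```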
@@ -446,7 +484,7 @@ class _GeminiFunctionCall(TypedDict):
 
 
 class _GeminiFunctionResponsePart(TypedDict):
-    function_response: Annotated[_GeminiFunctionResponse, Field(alias='functionResponse')]
+    function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]
 
 
 def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
@@ -476,11 +514,11 @@ def _part_discriminator(v: Any) -> str:
 # TODO discriminator
 _GeminiPartUnion = Annotated[
     Union[
-        Annotated[_GeminiTextPart, Tag('text')],
-        Annotated[_GeminiFunctionCallPart, Tag('function_call')],
-        Annotated[_GeminiFunctionResponsePart, Tag('function_response')],
+        Annotated[_GeminiTextPart, pydantic.Tag('text')],
+        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
+        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
     ],
-    Discriminator(_part_discriminator),
+    pydantic.Discriminator(_part_discriminator),
 ]
 
 
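The part union relies on pydantic's callable-discriminator feature: `_part_discriminator` inspects the raw value and returns a string matching one of the `pydantic.Tag` annotations, so validation jumps straight to the right variant instead of trying each union member in turn. A self-contained sketch of the mechanism:

```python
from typing import Annotated, Any, Union

import pydantic
from typing_extensions import TypedDict


class TextPart(TypedDict):
    text: str


class FunctionCallPart(TypedDict):
    function_call: dict


def part_discriminator(v: Any) -> str:
    # Return the Tag of the variant this raw dict should validate as.
    if isinstance(v, dict) and 'function_call' in v:
        return 'function_call'
    return 'text'


PartUnion = Annotated[
    Union[
        Annotated[TextPart, pydantic.Tag('text')],
        Annotated[FunctionCallPart, pydantic.Tag('function_call')],
    ],
    pydantic.Discriminator(part_discriminator),
]

adapter = pydantic.TypeAdapter(PartUnion)
assert adapter.validate_python({'text': 'hi'}) == {'text': 'hi'}
```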
@@ -490,7 +528,7 @@ class _GeminiTextContent(TypedDict):
 
 
 class _GeminiTools(TypedDict):
-    function_declarations: list[Annotated[_GeminiFunction, Field(alias='functionDeclarations')]]
+    function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]
 
 
 class _GeminiFunction(TypedDict):
@@ -531,6 +569,7 @@ class _GeminiFunctionCallingConfig(TypedDict):
     allowed_function_names: list[str]
 
 
+@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
 class _GeminiResponse(TypedDict):
     """Schema for the response from the Gemini API.
 
@@ -540,10 +579,11 @@ class _GeminiResponse(TypedDict):
 
     candidates: list[_GeminiCandidates]
     # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
-    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, Field(alias='usageMetadata')]]
-    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, Field(alias='promptFeedback')]]
+    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
+    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
 
 
+# TODO: Delete the next three functions once we've reworked streams to be more flexible
 def _extract_response_parts(
     response: _GeminiResponse,
 ) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
@@ -576,14 +616,14 @@ class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
 
     content: _GeminiContent
-    finish_reason: NotRequired[Annotated[Literal['STOP'], Field(alias='finishReason')]]
+    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
     """
     See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
     but let's wait until we see them and know what they mean to add them here.
     """
-    avg_log_probs: NotRequired[Annotated[float, Field(alias='avgLogProbs')]]
+    avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
     index: NotRequired[int]
-    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]]
+    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
 
 
 class _GeminiUsageMetaData(TypedDict, total=False):
@@ -592,20 +632,20 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
     """
 
-    prompt_token_count: Annotated[int, Field(alias='promptTokenCount')]
-    candidates_token_count: NotRequired[Annotated[int, Field(alias='candidatesTokenCount')]]
-    total_token_count: Annotated[int, Field(alias='totalTokenCount')]
-    cached_content_token_count: NotRequired[Annotated[int, Field(alias='cachedContentTokenCount')]]
+    prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
+    candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
+    total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
+    cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
 
 
-def
+def _metadata_as_usage(response: _GeminiResponse) -> result.Usage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return result.
+        return result.Usage()
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
         details['cached_content_token_count'] = cached_content_token_count
-    return result.
+    return result.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
         total_tokens=metadata.get('total_token_count', 0),
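`_metadata_as_usage` has to be defensive: intermediate streaming chunks omit `usageMetadata` entirely, and even a final chunk may lack individual counts. A standalone sketch of the same mapping (the `Usage` dataclass is a simplified stand-in for `result.Usage`):

```python
from dataclasses import dataclass


@dataclass
class Usage:  # simplified stand-in for result.Usage
    request_tokens: int = 0
    response_tokens: int = 0
    total_tokens: int = 0


def metadata_as_usage(response: dict) -> Usage:
    metadata = response.get('usage_metadata')
    if metadata is None:
        # Intermediate streaming chunks carry no usage metadata at all.
        return Usage()
    return Usage(
        request_tokens=metadata.get('prompt_token_count', 0),
        response_tokens=metadata.get('candidates_token_count', 0),
        total_tokens=metadata.get('total_token_count', 0),
    )


assert metadata_as_usage({}).total_tokens == 0
assert metadata_as_usage({'usage_metadata': {'total_token_count': 8}}).total_tokens == 8
```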
@@ -629,15 +669,15 @@ class _GeminiSafetyRating(TypedDict):
 class _GeminiPromptFeedback(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""
 
-    block_reason: Annotated[str, Field(alias='blockReason')]
-    safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]
+    block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
+    safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]
 
 
-_gemini_request_ta =
-_gemini_response_ta =
+_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
+_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)
 
 # steam requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
-_gemini_streamed_response_ta =
+_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
 
 
 class _GeminiJsonSchema:
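These module-level `TypeAdapter`s are what let the snake_case TypedDicts speak Gemini's camelCase wire format: `dump_json(..., by_alias=True)` emits the names declared via `pydantic.Field(alias=...)`, and validation accepts them on the way back in. A small round-trip sketch:

```python
from typing import Annotated

import pydantic
from typing_extensions import TypedDict


class UsageMetaData(TypedDict, total=False):
    prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
    total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]


adapter = pydantic.TypeAdapter(UsageMetaData)

# Validation accepts the wire (camelCase) names and yields snake_case keys...
data = adapter.validate_json('{"promptTokenCount": 3, "totalTokenCount": 7}')
assert data == {'prompt_token_count': 3, 'total_token_count': 7}

# ...and by_alias=True restores the camelCase names on serialization.
assert b'promptTokenCount' in adapter.dump_json(data, by_alias=True)
```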