pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- pydantic_ai/_pydantic.py +13 -29
- pydantic_ai/_result.py +52 -38
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +20 -8
- pydantic_ai/agent.py +431 -167
- pydantic_ai/messages.py +90 -48
- pydantic_ai/models/__init__.py +59 -42
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +66 -44
- pydantic_ai/models/gemini.py +160 -117
- pydantic_ai/models/groq.py +125 -108
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +116 -0
- pydantic_ai/models/openai.py +145 -114
- pydantic_ai/models/test.py +109 -77
- pydantic_ai/models/vertexai.py +14 -9
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +140 -45
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.11.dist-info/RECORD +0 -22
pydantic_ai/models/gemini.py
CHANGED
@@ -2,31 +2,35 @@ from __future__ import annotations as _annotations
 
 import os
 import re
-from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
+from collections.abc import AsyncIterator, Iterable, Sequence
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union
 
+import pydantic
 import pydantic_core
-from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from pydantic import Discriminator, Field, Tag
+from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
 
-from .. import UnexpectedModelBehavior,
+from .. import UnexpectedModelBehavior, _utils, exceptions, result
 from ..messages import (
     ArgsDict,
-    Message,
-    ModelAnyResponse,
-    ModelStructuredResponse,
-    ModelTextResponse,
-    RetryPrompt,
-    ToolCall,
-    ToolReturn,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
 from . import (
-    AbstractToolDefinition,
     AgentModel,
     EitherStreamedResponse,
     Model,
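The import changes summarize the messages rework that dominates this release: the 0.0.11 message classes (`Message`, `ModelAnyResponse`, `ModelTextResponse`, `ModelStructuredResponse`, ...) are replaced by `ModelRequest`/`ModelResponse` containers holding typed parts. A minimal sketch of the new shapes, inferred from how they are constructed later in this diff (keyword names are assumptions):

```python
from pydantic_ai.messages import (
    ModelRequest,
    ModelResponse,
    SystemPromptPart,
    TextPart,
    UserPromptPart,
)

# a request to the model is one object holding typed parts...
request = ModelRequest(
    parts=[
        SystemPromptPart(content='Be concise.'),
        UserPromptPart(content='What is the capital of France?'),
    ]
)
# ...and a model response is likewise a list of parts (text and/or tool calls)
response = ModelResponse(parts=[TextPart(content='Paris.')])
```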
@@ -37,7 +41,9 @@ from . import (
     get_user_agent,
 )
 
-GeminiModelName = Literal[
+GeminiModelName = Literal[
+    'gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro', 'gemini-2.0-flash-exp'
+]
 """Named Gemini models.
 
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
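`GeminiModelName` picks up `'gemini-2.0-flash-exp'` and is reflowed across lines; it remains a plain `Literal`, so model names passed to `GeminiModel` are checked statically. Usage sketch (assumptions: the constructor takes the model name positionally and falls back to an API-key environment variable when `api_key` is not given):

```python
from pydantic_ai.models.gemini import GeminiModel

# a type checker flags typos here, because the argument must be a GeminiModelName literal
model = GeminiModel('gemini-2.0-flash-exp')
```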
@@ -90,9 +96,10 @@ class GeminiModel(Model):
 
     async def agent_model(
         self,
-        function_tools: Mapping[str, AbstractToolDefinition],
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools: Sequence[AbstractToolDefinition] | None,
+        result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
         return GeminiAgentModel(
             http_client=self.http_client,
@@ -142,13 +149,13 @@ class GeminiAgentModel(AgentModel):
         model_name: GeminiModelName,
         auth: AuthProtocol,
         url: str,
-        function_tools: Mapping[str, AbstractToolDefinition],
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools: Sequence[AbstractToolDefinition] | None,
+        result_tools: list[ToolDefinition],
     ):
         check_allow_model_requests()
-        tools = [_function_from_abstract_tool(t) for t in function_tools.values()]
-        if result_tools is not None:
+        tools = [_function_from_abstract_tool(t) for t in function_tools]
+        if result_tools:
             tools += [_function_from_abstract_tool(t) for t in result_tools]
 
         if allow_text_result:
@@ -163,26 +170,25 @@ class GeminiAgentModel(AgentModel):
         self.tool_config = tool_config
         self.url = url
 
-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
-        async with self._make_request(messages, False) as http_response:
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_cost(response)
 
     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
-        async with self._make_request(messages, True) as http_response:
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        async with self._make_request(messages, True, model_settings) as http_response:
             yield await self._process_streamed_response(http_response)
 
     @asynccontextmanager
-    async def _make_request(self, messages: list[Message], streamed: bool) -> AsyncIterator[HTTPResponse]:
-        sys_prompt_parts: list[_GeminiTextPart] = []
-        contents: list[_GeminiContent] = []
-        for m in messages:
-            either_content = self._message_to_gemini(m)
-            if left := either_content.left:
-                sys_prompt_parts.append(left.value)
-            else:
-                contents.append(either_content.right)
+    async def _make_request(
+        self, messages: list[ModelMessage], streamed: bool, model_settings: ModelSettings | None
+    ) -> AsyncIterator[HTTPResponse]:
+        sys_prompt_parts, contents = self._message_to_gemini_content(messages)
 
         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
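`request` and `request_stream` now take the `ModelSettings` added in `pydantic_ai/settings.py` (+72 lines in this release) and thread it through `_make_request`. The `.get()` calls below imply it behaves like a `total=False` TypedDict; a sketch of just the keys this module reads (field types are assumptions):

```python
from typing_extensions import TypedDict


class ModelSettings(TypedDict, total=False):
    """Settings read by gemini.py; the real class in settings.py may have more keys."""

    max_tokens: int
    temperature: float
    top_p: float
    timeout: float  # forwarded to httpx, which also accepts an httpx.Timeout
```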
@@ -192,6 +198,17 @@ class GeminiAgentModel(AgentModel):
         if self.tool_config is not None:
             request_data['tool_config'] = self.tool_config
 
+        generation_config: _GeminiGenerationConfig = {}
+        if model_settings:
+            if (max_tokens := model_settings.get('max_tokens')) is not None:
+                generation_config['max_output_tokens'] = max_tokens
+            if (temperature := model_settings.get('temperature')) is not None:
+                generation_config['temperature'] = temperature
+            if (top_p := model_settings.get('top_p')) is not None:
+                generation_config['top_p'] = top_p
+        if generation_config:
+            request_data['generation_config'] = generation_config
+
         url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
 
         headers = {
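This block is the core of the new settings support: each `ModelSettings` key is copied into Gemini's generation config only when present, so empty settings leave the request untouched. End to end it is used roughly like this (hedged sketch; the `model_settings` argument on the run methods comes from the `agent.py` changes shipped in the same release):

```python
from pydantic_ai import Agent

agent = Agent('gemini-1.5-flash')
result = agent.run_sync(
    'Write a haiku about code review.',
    model_settings={'max_tokens': 64, 'temperature': 0.2, 'top_p': 0.95},
)
print(result.data)
```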
@@ -202,19 +219,24 @@ class GeminiAgentModel(AgentModel):
 
         request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
 
-        async with self.http_client.stream('POST', url, content=request_json, headers=headers) as r:
+        async with self.http_client.stream(
+            'POST',
+            url,
+            content=request_json,
+            headers=headers,
+            timeout=(model_settings or {}).get('timeout', USE_CLIENT_DEFAULT),
+        ) as r:
             if r.status_code != 200:
                 await r.aread()
                 raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
             yield r
 
     @staticmethod
-    def _process_response(response: _GeminiResponse) -> ModelAnyResponse:
-        either = _extract_response_parts(response)
-        if left := either.left:
-            return _structured_response_from_parts(left.value)
-        else:
-            return ModelTextResponse(content=''.join(part['text'] for part in either.right))
+    def _process_response(response: _GeminiResponse) -> ModelResponse:
+        if len(response['candidates']) != 1:
+            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
+        parts = response['candidates'][0]['content']['parts']
+        return _process_response_from_parts(parts)
 
     @staticmethod
     async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse:
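`_process_response` now insists on exactly one candidate and hands its parts straight to `_process_response_from_parts`, replacing the old Either-based split into text vs. structured responses. For reference, the JSON it parses looks roughly like this (shape taken from the `_GeminiResponse`/`_GeminiCandidates` TypedDicts later in the file; values invented):

```python
response_json = {
    'candidates': [
        {
            'content': {
                'role': 'model',
                'parts': [
                    {'text': 'Checking the weather now.'},
                    {'functionCall': {'name': 'get_weather', 'args': {'city': 'Paris'}}},
                ],
            },
            'finishReason': 'STOP',
        }
    ],
    # omitted while streaming until the final chunk, hence NotRequired in the schema
    'usageMetadata': {'promptTokenCount': 12, 'totalTokenCount': 30},
}
```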
@@ -238,34 +260,37 @@ class GeminiAgentModel(AgentModel):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
+        # TODO: Update this once we rework stream responses to be more flexible
         if _extract_response_parts(start_response).is_left():
             return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes)
         else:
             return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)
 
-    @staticmethod
-    def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]:
+    @classmethod
+    def _message_to_gemini_content(
+        cls, messages: list[ModelMessage]
+    ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
+        sys_prompt_parts: list[_GeminiTextPart] = []
+        contents: list[_GeminiContent] = []
+        for m in messages:
+            if isinstance(m, ModelRequest):
+                for part in m.parts:
+                    if isinstance(part, SystemPromptPart):
+                        sys_prompt_parts.append(_GeminiTextPart(text=part.content))
+                    elif isinstance(part, UserPromptPart):
+                        contents.append(_content_user_prompt(part))
+                    elif isinstance(part, ToolReturnPart):
+                        contents.append(_content_tool_return(part))
+                    elif isinstance(part, RetryPromptPart):
+                        contents.append(_content_retry_prompt(part))
+                    else:
+                        assert_never(part)
+            elif isinstance(m, ModelResponse):
+                contents.append(_content_model_response(m))
+            else:
+                assert_never(m)
+
+        return sys_prompt_parts, contents
 
 
 @dataclass
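The old per-message `_message_to_gemini` (returning an `Either` of system text vs. content) becomes one pass over the whole history that fans out on part types. A worked example of input and output, using the TypedDict shapes defined further down (untested sketch):

```python
messages = [
    ModelRequest(parts=[SystemPromptPart(content='Be terse.'), UserPromptPart(content='Hi')]),
    ModelResponse(parts=[TextPart(content='Hello!')]),
]
sys_prompt_parts, contents = GeminiAgentModel._message_to_gemini_content(messages)

# system prompts are pulled out for the request's system-instructions field
assert sys_prompt_parts == [{'text': 'Be terse.'}]
# everything else becomes alternating user/model content
assert contents == [
    {'role': 'user', 'parts': [{'text': 'Hi'}]},
    {'role': 'model', 'parts': [{'text': 'Hello!'}]},
]
```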
@@ -326,8 +351,8 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
             chunk = await self._stream.__anext__()
             self._content.extend(chunk)
 
-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        """Get the `ModelStructuredResponse` at this point.
+    def get(self, *, final: bool = False) -> ModelResponse:
+        """Get the `ModelResponse` at this point.
 
         NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always
         reply with a single response, when returning a structured data.
@@ -339,20 +364,13 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
             self._content,
             experimental_allow_partial='off' if final else 'trailing-strings',
         )
-        combined_parts: list[_GeminiFunctionCallPart] = []
+        combined_parts: list[_GeminiPartUnion] = []
         self._cost = result.Cost()
         for r in responses:
            self._cost += _metadata_as_cost(r)
             candidate = r['candidates'][0]
-                combined_parts.extend(parts)
-            elif not candidate.get('finish_reason'):
-                # you can get an empty text part along with the finish_reason, so we ignore that case
-                raise UnexpectedModelBehavior(
-                    'Streamed response with unexpected content, expected all parts to be function calls'
-                )
-        return _structured_response_from_parts(combined_parts, timestamp=self._timestamp)
+            combined_parts.extend(candidate['content']['parts'])
+        return _process_response_from_parts(combined_parts, timestamp=self._timestamp)
 
     def cost(self) -> result.Cost:
         return self._cost
@@ -366,6 +384,7 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
 # TypeAdapters take care of validation and serialization
 
 
+@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
 class _GeminiRequest(TypedDict):
     """Schema for an API request to the Gemini API.
 
@@ -381,32 +400,37 @@ class _GeminiRequest(TypedDict):
     Developer generated system instructions, see
     <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
     """
+    generation_config: NotRequired[_GeminiGenerationConfig]
 
 
-class _GeminiContent(TypedDict):
-    role: Literal['user', 'model']
-    parts: list[_GeminiPartUnion]
+class _GeminiGenerationConfig(TypedDict, total=False):
+    """Schema for an API request to the Gemini API.
 
+    Note there are many additional fields available that have not been added yet.
 
+    See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
+    """
+
+    max_output_tokens: int
+    temperature: float
+    top_p: float
 
 
+class _GeminiContent(TypedDict):
+    role: Literal['user', 'model']
+    parts: list[_GeminiPartUnion]
 
 
-def _content_function_call(m: ModelStructuredResponse) -> _GeminiContent:
-    parts = [_function_call_part_from_call(t) for t in m.calls]
-    return _GeminiContent(role='model', parts=parts)
+def _content_user_prompt(m: UserPromptPart) -> _GeminiContent:
+    return _GeminiContent(role='user', parts=[_GeminiTextPart(text=m.content)])
 
 
-def _content_function_return(m: ToolReturn) -> _GeminiContent:
+def _content_tool_return(m: ToolReturnPart) -> _GeminiContent:
     f_response = _response_part_from_response(m.tool_name, m.model_response_object())
     return _GeminiContent(role='user', parts=[f_response])
 
 
-def _content_function_retry(m: RetryPrompt) -> _GeminiContent:
+def _content_retry_prompt(m: RetryPromptPart) -> _GeminiContent:
     if m.tool_name is None:
         part = _GeminiTextPart(text=m.model_response())
     else:
@@ -415,26 +439,43 @@ def _content_function_retry(m: RetryPrompt) -> _GeminiContent:
     return _GeminiContent(role='user', parts=[part])
 
 
+def _content_model_response(m: ModelResponse) -> _GeminiContent:
+    parts: list[_GeminiPartUnion] = []
+    for item in m.parts:
+        if isinstance(item, ToolCallPart):
+            parts.append(_function_call_part_from_call(item))
+        elif isinstance(item, TextPart):
+            parts.append(_GeminiTextPart(text=item.content))
+        else:
+            assert_never(item)
+    return _GeminiContent(role='model', parts=parts)
+
+
 class _GeminiTextPart(TypedDict):
     text: str
 
 
 class _GeminiFunctionCallPart(TypedDict):
-    function_call: Annotated[_GeminiFunctionCall, Field(alias='functionCall')]
+    function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
 
 
-def _function_call_part_from_call(tool: ToolCall) -> _GeminiFunctionCallPart:
+def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
     assert isinstance(tool.args, ArgsDict), f'Expected ArgsObject, got {tool.args}'
     return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args.args_dict))
 
 
-def _structured_response_from_parts(
+def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
+    items: list[ModelResponsePart] = []
+    for part in parts:
+        if 'text' in part:
+            items.append(TextPart(part['text']))
+        elif 'function_call' in part:
+            items.append(ToolCallPart.from_dict(part['function_call']['name'], part['function_call']['args']))
+        elif 'function_response' in part:
+            raise exceptions.UnexpectedModelBehavior(
+                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
+            )
+    return ModelResponse(items, timestamp=timestamp or _utils.now_utc())
 
 
 class _GeminiFunctionCall(TypedDict):
|
|
|
445
486
|
|
|
446
487
|
|
|
447
488
|
class _GeminiFunctionResponsePart(TypedDict):
|
|
448
|
-
function_response: Annotated[_GeminiFunctionResponse, Field(alias='functionResponse')]
|
|
489
|
+
function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]
|
|
449
490
|
|
|
450
491
|
|
|
451
492
|
def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
|
|
@@ -475,11 +516,11 @@ def _part_discriminator(v: Any) -> str:
 # TODO discriminator
 _GeminiPartUnion = Annotated[
     Union[
-        Annotated[_GeminiTextPart, Tag('text')],
-        Annotated[_GeminiFunctionCallPart, Tag('function_call')],
-        Annotated[_GeminiFunctionResponsePart, Tag('function_response')],
+        Annotated[_GeminiTextPart, pydantic.Tag('text')],
+        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
+        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
     ],
-    Discriminator(_part_discriminator),
+    pydantic.Discriminator(_part_discriminator),
 ]
 
 
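Only the spelling changes here: `Tag` and `Discriminator` now come via the `pydantic` namespace import added at the top of the file. The mechanics are unchanged: `_part_discriminator` inspects which key a part dict carries and returns the matching tag, so validation dispatches on shape. Illustrative sketch (`part_adapter` is hypothetical, not defined in this module):

```python
import pydantic

part_adapter = pydantic.TypeAdapter(_GeminiPartUnion)

# the 'text' key selects the _GeminiTextPart branch of the union
text_part = part_adapter.validate_python({'text': 'Paris'})
```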
@@ -489,7 +530,7 @@ class _GeminiTextContent(TypedDict):
 
 
 class _GeminiTools(TypedDict):
-    function_declarations: list[Annotated[_GeminiFunction, Field(alias='functionDeclarations')]]
+    function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]
 
 
 class _GeminiFunction(TypedDict):
@@ -504,8 +545,8 @@ class _GeminiFunction(TypedDict):
     """
 
 
-def _function_from_abstract_tool(tool: AbstractToolDefinition) -> _GeminiFunction:
-    json_schema = _GeminiJsonSchema(tool.json_schema).simplify()
+def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
+    json_schema = _GeminiJsonSchema(tool.parameters_json_schema).simplify()
     f = _GeminiFunction(
         name=tool.name,
         description=tool.description,
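`AbstractToolDefinition` and its `json_schema` attribute give way to `ToolDefinition` from the new `pydantic_ai/tools.py`, where the schema lives on `parameters_json_schema`. A hedged sketch of the conversion (the `ToolDefinition` constructor call is assumed from the field names used here):

```python
from pydantic_ai.tools import ToolDefinition

tool_def = ToolDefinition(
    name='get_weather',
    description='Get the current weather for a city.',
    parameters_json_schema={
        'type': 'object',
        'properties': {'city': {'type': 'string'}},
        'required': ['city'],
    },
)
gemini_function = _function_from_abstract_tool(tool_def)
# -> {'name': 'get_weather', 'description': ..., 'parameters': <schema simplified
#    by _GeminiJsonSchema to the subset Gemini accepts>}
```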
@@ -530,6 +571,7 @@ class _GeminiFunctionCallingConfig(TypedDict):
     allowed_function_names: list[str]
 
 
+@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
 class _GeminiResponse(TypedDict):
     """Schema for the response from the Gemini API.
 
@@ -539,10 +581,11 @@ class _GeminiResponse(TypedDict):
 
     candidates: list[_GeminiCandidates]
     # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
-    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, Field(alias='usageMetadata')]]
-    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, Field(alias='promptFeedback')]]
+    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
+    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
 
 
+# TODO: Delete the next three functions once we've reworked streams to be more flexible
 def _extract_response_parts(
     response: _GeminiResponse,
 ) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
@@ -575,14 +618,14 @@ class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
 
     content: _GeminiContent
-    finish_reason: NotRequired[Annotated[Literal['STOP'], Field(alias='finishReason')]]
+    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
     """
     See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
    but let's wait until we see them and know what they mean to add them here.
     """
-    avg_log_probs: NotRequired[Annotated[float, Field(alias='avgLogProbs')]]
+    avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
     index: NotRequired[int]
-    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]]
+    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
 
 
 class _GeminiUsageMetaData(TypedDict, total=False):
@@ -591,10 +634,10 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
     """
 
-    prompt_token_count: Annotated[int, Field(alias='promptTokenCount')]
-    candidates_token_count: NotRequired[Annotated[int, Field(alias='candidatesTokenCount')]]
-    total_token_count: Annotated[int, Field(alias='totalTokenCount')]
-    cached_content_token_count: NotRequired[Annotated[int, Field(alias='cachedContentTokenCount')]]
+    prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
+    candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
+    total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
+    cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
 
 
 def _metadata_as_cost(response: _GeminiResponse) -> result.Cost:
@@ -628,15 +671,15 @@ class _GeminiSafetyRating(TypedDict):
 class _GeminiPromptFeedback(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""
 
-    block_reason: Annotated[str, Field(alias='blockReason')]
-    safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]
+    block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
+    safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]
 
 
-_gemini_request_ta =
-_gemini_response_ta =
+_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
+_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)
 
 # steam requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
-_gemini_streamed_response_ta =
+_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
 
 
 class _GeminiJsonSchema:
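The plain `pydantic.TypeAdapter(...)` calls replace whatever lazy construction 0.0.11 used here; together with the `defer_build` config added to `_GeminiRequest` and `_GeminiResponse` above, schema building moves from import time to first use, keeping `import pydantic_ai` cheap. A standalone illustration of the pattern (not from this codebase):

```python
import pydantic
from typing_extensions import TypedDict


@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class Point(TypedDict):
    x: int
    y: int


point_ta = pydantic.TypeAdapter(Point)      # no pydantic-core schema built yet
point_ta.validate_python({'x': 1, 'y': 2})  # schema is built lazily, on first use
```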