google-genai 1.7.0__py3-none-any.whl → 1.53.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their respective registries. It is provided for informational purposes only.
- google/genai/__init__.py +4 -2
- google/genai/_adapters.py +55 -0
- google/genai/_api_client.py +1301 -299
- google/genai/_api_module.py +1 -1
- google/genai/_automatic_function_calling_util.py +54 -33
- google/genai/_base_transformers.py +26 -0
- google/genai/_base_url.py +50 -0
- google/genai/_common.py +560 -59
- google/genai/_extra_utils.py +371 -38
- google/genai/_live_converters.py +1467 -0
- google/genai/_local_tokenizer_loader.py +214 -0
- google/genai/_mcp_utils.py +117 -0
- google/genai/_operations_converters.py +394 -0
- google/genai/_replay_api_client.py +204 -92
- google/genai/_test_api_client.py +1 -1
- google/genai/_tokens_converters.py +520 -0
- google/genai/_transformers.py +633 -233
- google/genai/batches.py +1733 -538
- google/genai/caches.py +678 -1012
- google/genai/chats.py +48 -38
- google/genai/client.py +142 -15
- google/genai/documents.py +532 -0
- google/genai/errors.py +141 -35
- google/genai/file_search_stores.py +1296 -0
- google/genai/files.py +312 -744
- google/genai/live.py +617 -367
- google/genai/live_music.py +197 -0
- google/genai/local_tokenizer.py +395 -0
- google/genai/models.py +3598 -3116
- google/genai/operations.py +201 -362
- google/genai/pagers.py +23 -7
- google/genai/py.typed +1 -0
- google/genai/tokens.py +362 -0
- google/genai/tunings.py +1274 -496
- google/genai/types.py +14535 -5454
- google/genai/version.py +2 -2
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/METADATA +736 -234
- google_genai-1.53.0.dist-info/RECORD +41 -0
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/WHEEL +1 -1
- google_genai-1.7.0.dist-info/RECORD +0 -27
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info/licenses}/LICENSE +0 -0
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/top_level.txt +0 -0
google/genai/_transformers.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,7 +26,9 @@ import sys
 import time
 import types as builtin_types
 import typing
-from typing import Any, GenericAlias, Optional, Union  # type: ignore[attr-defined]
+from typing import Any, GenericAlias, List, Optional, Sequence, Union  # type: ignore[attr-defined]
+from ._mcp_utils import mcp_to_gemini_tool
+from ._common import get_value_by_path as getv

 if typing.TYPE_CHECKING:
   import PIL.Image
@@ -34,6 +36,7 @@ if typing.TYPE_CHECKING:
   import pydantic

 from . import _api_client
+from . import _common
 from . import types

 logger = logging.getLogger('google_genai._transformers')
@@ -43,17 +46,70 @@ if sys.version_info >= (3, 10):
   _UNION_TYPES = (typing.Union, builtin_types.UnionType)
   from typing import TypeGuard
 else:
-  VersionedUnionType = typing._UnionGenericAlias
+  VersionedUnionType = typing._UnionGenericAlias  # type: ignore[attr-defined]
   _UNION_TYPES = (typing.Union,)
   from typing_extensions import TypeGuard

+if typing.TYPE_CHECKING:
+  from mcp import ClientSession as McpClientSession
+  from mcp.types import Tool as McpTool
+else:
+  McpClientSession: typing.Type = Any
+  McpTool: typing.Type = Any
+  try:
+    from mcp import ClientSession as McpClientSession
+    from mcp.types import Tool as McpTool
+  except ImportError:
+    McpClientSession = None
+    McpTool = None
+
+
+metric_name_sdk_api_map = {
+    'exact_match': 'exactMatchSpec',
+    'bleu': 'bleuSpec',
+    'rouge_spec': 'rougeSpec',
+}
+metric_name_api_sdk_map = {v: k for k, v in metric_name_sdk_api_map.items()}
+
+
+def _is_duck_type_of(obj: Any, cls: type[pydantic.BaseModel]) -> bool:
+  """Checks if an object has all of the fields of a Pydantic model.
+
+  This is a duck-typing alternative to `isinstance` to solve dual-import
+  problems. It returns False for dictionaries, which should be handled by
+  `isinstance(obj, dict)`.
+
+  Args:
+    obj: The object to check.
+    cls: The Pydantic model class to duck-type against.
+
+  Returns:
+    True if the object has all the fields defined in the Pydantic model, False
+    otherwise.
+  """
+  if isinstance(obj, dict) or not hasattr(cls, 'model_fields'):
+    return False
+
+  # Check if the object has all of the Pydantic model's defined fields.
+  all_matched = all(hasattr(obj, field) for field in cls.model_fields)
+  if not all_matched and isinstance(obj, pydantic.BaseModel):
+    # Check the other way around if obj is a Pydantic model.
+    # Check if the Pydantic model has all of the object's defined fields.
+    try:
+      obj_private = cls()
+      all_matched = all(hasattr(obj_private, f) for f in type(obj).model_fields)
+    except ValueError:
+      return False
+  return all_matched
+
+
 def _resource_name(
     client: _api_client.BaseApiClient,
     resource_name: str,
     *,
     collection_identifier: str,
     collection_hierarchy_depth: int = 2,
-):
+) -> str:
   # pylint: disable=line-too-long
   """Prepends resource name with project, location, collection_identifier if needed.

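Throughout the rest of this file, `isinstance` checks against SDK types are replaced by the new `_is_duck_type_of` helper. The motivation is the dual-import problem its docstring mentions: when the same Pydantic model is reachable through two module paths, each import produces a distinct class object, so `isinstance` fails even though the objects are structurally identical. A minimal sketch of the idea, using stand-in models rather than the SDK's own types:

```python
import pydantic


class Blob(pydantic.BaseModel):
  mime_type: str = ''
  data: bytes = b''


class TwinBlob(pydantic.BaseModel):
  # Stands in for the same class imported under a second module path.
  mime_type: str = ''
  data: bytes = b''


def is_duck_type_of(obj, cls) -> bool:
  # Core of the check above: the object must expose every declared field.
  if isinstance(obj, dict) or not hasattr(cls, 'model_fields'):
    return False
  return all(hasattr(obj, field) for field in cls.model_fields)


blob = TwinBlob(mime_type='image/png', data=b'...')
print(isinstance(blob, Blob))       # False: distinct classes
print(is_duck_type_of(blob, Blob))  # True: same field surface
```

The reverse check in the real helper additionally lets an object that carries only a subset of the model's fields (for example, one built from an older release) still match.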
@@ -140,9 +196,11 @@ def _resource_name(
   return resource_name


-def t_model(client: _api_client.BaseApiClient, model: str):
+def t_model(client: _api_client.BaseApiClient, model: str) -> str:
   if not model:
     raise ValueError('model is required.')
+  if '..' in model or '?' in model or '&' in model:
+    raise ValueError('invalid model parameter.')
   if client.vertexai:
     if (
         model.startswith('projects/')
@@ -180,18 +238,26 @@ def t_models_url(


 def t_extract_models(
-
-
-) -> Optional[list[types.ModelDict]]:
+    response: _common.StringDict,
+) -> list[_common.StringDict]:
   if not response:
     return []
-
-
-
-  return
-
-
-
+
+  models: Optional[list[_common.StringDict]] = response.get('models')
+  if models is not None:
+    return models
+
+  tuned_models: Optional[list[_common.StringDict]] = response.get('tunedModels')
+  if tuned_models is not None:
+    return tuned_models
+
+  publisher_models: Optional[list[_common.StringDict]] = response.get(
+      'publisherModels'
+  )
+  if publisher_models is not None:
+    return publisher_models
+
+  if (
       response.get('httpHeaders') is not None
       and response.get('jsonPayload') is None
   ):
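`t_extract_models` now probes the response for each known key in order and returns the first list that is present, using explicit `is not None` tests so that an empty list still counts as a result (an `or`-chain would skip it). The pattern in isolation, with made-up data:

```python
from typing import Any, Optional


def first_present(
    response: dict[str, Any], keys: list[str]
) -> Optional[list[Any]]:
  # Return the first key whose value is present, even when it is [].
  for key in keys:
    value = response.get(key)
    if value is not None:
      return value
  return None


resp = {'tunedModels': []}
print(first_present(resp, ['models', 'tunedModels', 'publisherModels']))  # []
```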
@@ -202,7 +268,9 @@ def t_extract_models(
     return []


-def t_caches_model(
+def t_caches_model(
+    api_client: _api_client.BaseApiClient, model: str
+) -> Optional[str]:
   model = t_model(api_client, model)
   if not model:
     return None
@@ -217,7 +285,7 @@ def t_caches_model(api_client: _api_client.BaseApiClient, model: str):
   return model


-def pil_to_blob(img) -> types.Blob:
+def pil_to_blob(img: Any) -> types.Blob:
   PngImagePlugin: Optional[builtin_types.ModuleType]
   try:
     import PIL.PngImagePlugin
@@ -242,33 +310,119 @@ def pil_to_blob(img) -> types.Blob:
   return types.Blob(mime_type=mime_type, data=data)


-def
-
-
+def t_function_response(
+    function_response: types.FunctionResponseOrDict,
+) -> types.FunctionResponse:
+  if not function_response:
+    raise ValueError('function_response is required.')
+  if isinstance(function_response, dict):
+    return types.FunctionResponse.model_validate(function_response)
+  elif _is_duck_type_of(function_response, types.FunctionResponse):
+    return function_response
+  else:
+    raise TypeError(
+        'Could not parse input as FunctionResponse. Unsupported'
+        f' function_response type: {type(function_response)}'
+    )

-    PIL_Image = PIL.Image.Image
-  except ImportError:
-    PIL_Image = None

+def t_function_responses(
+    function_responses: Union[
+        types.FunctionResponseOrDict,
+        Sequence[types.FunctionResponseOrDict],
+    ],
+) -> list[types.FunctionResponse]:
+  if not function_responses:
+    raise ValueError('function_responses are required.')
+  if isinstance(function_responses, Sequence):
+    return [t_function_response(response) for response in function_responses]
+  else:
+    return [t_function_response(function_responses)]
+
+
+def t_blobs(
+    blobs: Union[types.BlobImageUnionDict, list[types.BlobImageUnionDict]],
+) -> list[types.Blob]:
+  if isinstance(blobs, list):
+    return [t_blob(blob) for blob in blobs]
+  else:
+    return [t_blob(blobs)]
+
+
+def t_blob(blob: types.BlobImageUnionDict) -> types.Blob:
+  if not blob:
+    raise ValueError('blob is required.')
+
+  if _is_duck_type_of(blob, types.Blob):
+    return blob  # type: ignore[return-value]
+
+  if isinstance(blob, dict):
+    return types.Blob.model_validate(blob)
+
+  if 'image' in blob.__class__.__name__.lower():
+    try:
+      import PIL.Image
+
+      PIL_Image = PIL.Image.Image
+    except ImportError:
+      PIL_Image = None
+
+    if PIL_Image is not None and isinstance(blob, PIL_Image):
+      return pil_to_blob(blob)
+
+  raise TypeError(
+      f'Could not parse input as Blob. Unsupported blob type: {type(blob)}'
+  )
+
+
+def t_image_blob(blob: types.BlobImageUnionDict) -> types.Blob:
+  blob = t_blob(blob)
+  if blob.mime_type and blob.mime_type.startswith('image/'):
+    return blob
+  raise ValueError(f'Unsupported mime type: {blob.mime_type!r}')
+
+
+def t_audio_blob(blob: types.BlobOrDict) -> types.Blob:
+  blob = t_blob(blob)
+  if blob.mime_type and blob.mime_type.startswith('audio/'):
+    return blob
+  raise ValueError(f'Unsupported mime type: {blob.mime_type!r}')
+
+
+def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
   if part is None:
     raise ValueError('content part is required.')
   if isinstance(part, str):
     return types.Part(text=part)
-  if
-
-  if isinstance(part, types.File):
-    if not part.uri or not part.mime_type:
+  if _is_duck_type_of(part, types.File):
+    if not part.uri or not part.mime_type:  # type: ignore[union-attr]
       raise ValueError('file uri and mime_type are required.')
-    return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)
+    return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)  # type: ignore[union-attr]
   if isinstance(part, dict):
-
-
-
+    try:
+      return types.Part.model_validate(part)
+    except pydantic.ValidationError:
+      return types.Part(file_data=types.FileData.model_validate(part))
+  if _is_duck_type_of(part, types.Part):
+    return part  # type: ignore[return-value]
+
+  if 'image' in part.__class__.__name__.lower():
+    try:
+      import PIL.Image
+
+      PIL_Image = PIL.Image.Image
+    except ImportError:
+      PIL_Image = None
+
+    if PIL_Image is not None and isinstance(part, PIL_Image):
+      return types.Part(inline_data=pil_to_blob(part))
   raise ValueError(f'Unsupported content part type: {type(part)}')


 def t_parts(
-    parts: Optional[
+    parts: Optional[
+        Union[list[types.PartUnionDict], types.PartUnionDict, list[types.Part]]
+    ],
 ) -> list[types.Part]:
   #
   if parts is None or (isinstance(parts, list) and not parts):
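The new `t_blob` family accepts a `types.Blob`, a plain dict, or a PIL image, and the mime-type guards `t_image_blob`/`t_audio_blob` sit on top of it. A usage sketch, assuming google-genai 1.53.0 is installed (`_transformers` is a private module, so this is illustration only):

```python
from google.genai import _transformers as t

# A dict is validated into types.Blob; a PIL.Image.Image would instead be
# routed through pil_to_blob().
blob = t.t_blob({'mime_type': 'image/png', 'data': b'\x89PNG'})
print(t.t_image_blob(blob).mime_type)  # image/png

# The audio guard rejects the same blob:
# t.t_audio_blob(blob)  ->  ValueError: Unsupported mime type: 'image/png'
```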
@@ -280,7 +434,6 @@ def t_parts(


 def t_image_predictions(
-    client: _api_client.BaseApiClient,
     predictions: Optional[Iterable[Mapping[str, Any]]],
 ) -> Optional[list[types.GeneratedImage]]:
   if not predictions:
@@ -303,30 +456,31 @@ ContentType = Union[types.Content, types.ContentDict, types.PartUnionDict]


 def t_content(
-
-    content: Optional[ContentType],
+    content: Union[ContentType, types.ContentDict, None],
 ) -> types.Content:
   if content is None:
     raise ValueError('content is required.')
-  if
-    return content
+  if _is_duck_type_of(content, types.Content):
+    return content  # type: ignore[return-value]
   if isinstance(content, dict):
     try:
       return types.Content.model_validate(content)
     except pydantic.ValidationError:
-      possible_part =
+      possible_part = t_part(content)  # type: ignore[arg-type]
       return (
           types.ModelContent(parts=[possible_part])
           if possible_part.function_call
           else types.UserContent(parts=[possible_part])
       )
-  if
+  if _is_duck_type_of(content, types.File):
+    return types.UserContent(parts=[t_part(content)])  # type: ignore[arg-type]
+  if _is_duck_type_of(content, types.Part):
     return (
-        types.ModelContent(parts=[content])
-        if content.function_call
-        else types.UserContent(parts=[content])
+        types.ModelContent(parts=[content])  # type: ignore[arg-type]
+        if content.function_call  # type: ignore[union-attr]
+        else types.UserContent(parts=[content])  # type: ignore[arg-type]
     )
-  return types.UserContent(parts=content)
+  return types.UserContent(parts=content)  # type: ignore[arg-type]


 def t_contents_for_embed(
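`t_content` has also dropped its client parameter, and dicts now go through `model_validate` with a fallback to part coercion. A hedged sketch of the resulting coercions (again private API, and the string-handling behavior of `UserContent` is assumed here):

```python
from google.genai import _transformers as t

print(t.t_content('hello').role)  # user: bare strings become UserContent
print(t.t_content({'role': 'model', 'parts': [{'text': 'hi'}]}).role)  # model
```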
@@ -334,9 +488,9 @@ def t_contents_for_embed(
     contents: Union[list[types.Content], list[types.ContentDict], ContentType],
 ) -> Union[list[str], list[types.Content]]:
   if isinstance(contents, list):
-    transformed_contents = [t_content(
+    transformed_contents = [t_content(content) for content in contents]
   else:
-    transformed_contents = [t_content(
+    transformed_contents = [t_content(contents)]

   if client.vertexai:
     text_parts = []
@@ -349,16 +503,13 @@ def t_contents_for_embed(
       if part.text:
         text_parts.append(part.text)
       else:
-        logger.warning(
-            f'Non-text part found, only returning text parts.'
-        )
+        logger.warning(f'Non-text part found, only returning text parts.')
     return text_parts
   else:
     return transformed_contents


 def t_contents(
-    client: _api_client.BaseApiClient,
     contents: Optional[
         Union[types.ContentListUnion, types.ContentListUnionDict, types.Content]
     ],
@@ -366,33 +517,45 @@ def t_contents(
   if contents is None or (isinstance(contents, list) and not contents):
     raise ValueError('contents are required.')
   if not isinstance(contents, list):
-    return [t_content(
-
-  try:
-    import PIL.Image
-
-    PIL_Image = PIL.Image.Image
-  except ImportError:
-    PIL_Image = None
+    return [t_content(contents)]

   result: list[types.Content] = []
   accumulated_parts: list[types.Part] = []

-  def _is_part(
+  def _is_part(
+      part: Union[types.PartUnionDict, Any],
+  ) -> TypeGuard[types.PartUnionDict]:
     if (
         isinstance(part, str)
-        or
-        or (
-        or isinstance(part, types.Part)
+        or _is_duck_type_of(part, types.File)
+        or _is_duck_type_of(part, types.Part)
     ):
       return True

     if isinstance(part, dict):
+      if not part:
+        # Empty dict should be considered as Content, not Part.
+        return False
       try:
         types.Part.model_validate(part)
         return True
       except pydantic.ValidationError:
-
+        try:
+          types.FileData.model_validate(part)
+          return True
+        except pydantic.ValidationError:
+          return False
+
+    if 'image' in part.__class__.__name__.lower():
+      try:
+        import PIL.Image
+
+        PIL_Image = PIL.Image.Image
+      except ImportError:
+        PIL_Image = None
+
+      if PIL_Image is not None and isinstance(part, PIL_Image):
+        return True

     return False

@@ -405,7 +568,7 @@ def t_contents(
   def _append_accumulated_parts_as_content(
       result: list[types.Content],
       accumulated_parts: list[types.Part],
-  ):
+  ) -> None:
     if not accumulated_parts:
       return
     result.append(
@@ -419,7 +582,7 @@ def t_contents(
       result: list[types.Content],
       accumulated_parts: list[types.Part],
       current_part: types.PartUnionDict,
-  ):
+  ) -> None:
     current_part = t_part(current_part)
     if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
       accumulated_parts.append(current_part)
@@ -435,17 +598,13 @@ def t_contents(
   # append to result
   # if list, we only accept a list of types.PartUnion
   for content in contents:
-    if (
-        isinstance(content, types.Content)
-        # only allowed inner list is a list of types.PartUnion
-        or isinstance(content, list)
-    ):
+    if _is_duck_type_of(content, types.Content) or isinstance(content, list):
       _append_accumulated_parts_as_content(result, accumulated_parts)
       if isinstance(content, list):
         result.append(types.UserContent(parts=content))  # type: ignore[arg-type]
       else:
-        result.append(content)
-    elif
+        result.append(content)  # type: ignore[arg-type]
+    elif _is_part(content):
       _handle_current_part(result, accumulated_parts, content)
     elif isinstance(content, dict):
       # PactDict is already handled in _is_part
@@ -458,7 +617,7 @@ def t_contents(
   return result


-def handle_null_fields(schema:
+def handle_null_fields(schema: _common.StringDict) -> None:
   """Process null fields in the schema so it is compatible with OpenAPI.

   The OpenAPI spec does not support 'type: 'null' in the schema. This function
@@ -517,16 +676,34 @@ def handle_null_fields(schema: dict[str, Any]):
     del schema['anyOf']


+def _raise_for_unsupported_schema_type(origin: Any) -> None:
+  """Raises an error if the schema type is unsupported."""
+  raise ValueError(f'Unsupported schema type: {origin}')
+
+
+def _raise_for_unsupported_mldev_properties(
+    schema: Any, client: Optional[_api_client.BaseApiClient]
+) -> None:
+  if (
+      client
+      and not client.vertexai
+      and (
+          schema.get('additionalProperties')
+          or schema.get('additional_properties')
+      )
+  ):
+    raise ValueError('additionalProperties is not supported in the Gemini API.')
+
+
 def process_schema(
-    schema:
-    client: _api_client.BaseApiClient,
-    defs: Optional[
+    schema: _common.StringDict,
+    client: Optional[_api_client.BaseApiClient],
+    defs: Optional[_common.StringDict] = None,
     *,
     order_properties: bool = True,
-):
+) -> None:
   """Updates the schema and each sub-schema inplace to be API-compatible.

-  - Removes the `title` field from the schema if the client is not vertexai.
   - Inlines the $defs.

   Example of a schema before and after (with mldev):
@@ -570,73 +747,76 @@ def process_schema(
       'items': {
         'properties': {
           'continent': {
-
+            'title': 'Continent',
+            'type': 'string'
           },
           'gdp': {
-
+            'title': 'Gdp',
+            'type': 'integer'
           },
         }
         'required':['continent', 'gdp'],
+        'title': 'CountryInfo',
         'type': 'object'
       },
       'type': 'array'
     }
   """
-  if not client.vertexai:
-    schema.pop('title', None)
-
-    if schema.get('default') is not None:
-      raise ValueError(
-          'Default value is not supported in the response schema for the Gemini'
-          ' API.'
-      )
-
   if schema.get('title') == 'PlaceholderLiteralEnum':
-    schema
-
-
-
-
-
+    del schema['title']
+
+  _raise_for_unsupported_mldev_properties(schema, client)
+
+  # Standardize spelling for relevant schema fields. For example, if a dict is
+  # provided directly to response_schema, it may use `any_of` instead of `anyOf.
+  # Otherwise, model_json_schema() uses `anyOf`.
+  for from_name, to_name in [
+      ('additional_properties', 'additionalProperties'),
+      ('any_of', 'anyOf'),
+      ('prefix_items', 'prefixItems'),
+      ('property_ordering', 'propertyOrdering'),
+  ]:
+    if (value := schema.pop(from_name, None)) is not None:
+      schema[to_name] = value

   if defs is None:
     defs = schema.pop('$defs', {})
     for _, sub_schema in defs.items():
-
+      # We can skip the '$ref' check, because JSON schema forbids a '$ref' from
+      # directly referencing another '$ref':
+      # https://json-schema.org/understanding-json-schema/structuring#recursion
+      process_schema(
+          sub_schema, client, defs, order_properties=order_properties
+      )

   handle_null_fields(schema)

   # After removing null fields, Optional fields with only one possible type
   # will have a $ref key that needs to be flattened
   # For example: {'default': None, 'description': 'Name of the person', 'nullable': True, '$ref': '#/$defs/TestPerson'}
-
-
-
-
-
-
-
-
-
-
-
-
-    if ref_key is None:
-      process_schema(sub_schema, client, defs)
-    else:
-      ref = defs[ref_key.split('defs/')[-1]]
-      any_of.append(ref)
-    schema['anyOf'] = [item for item in any_of if '$ref' not in item]
+  if (ref := schema.pop('$ref', None)) is not None:
+    schema.update(defs[ref.split('defs/')[-1]])
+
+  def _recurse(sub_schema: _common.StringDict) -> _common.StringDict:
+    """Returns the processed `sub_schema`, resolving its '$ref' if any."""
+    if (ref := sub_schema.pop('$ref', None)) is not None:
+      sub_schema = defs[ref.split('defs/')[-1]]
+    process_schema(sub_schema, client, defs, order_properties=order_properties)
+    return sub_schema
+
+  if (any_of := schema.get('anyOf')) is not None:
+    schema['anyOf'] = [_recurse(sub_schema) for sub_schema in any_of]
     return

-  schema_type = schema.get('type'
+  schema_type = schema.get('type')
   if isinstance(schema_type, Enum):
     schema_type = schema_type.value
-  schema_type
+  if isinstance(schema_type, str):
+    schema_type = schema_type.upper()

   # model_json_schema() returns a schema with a 'const' field when a Literal with one value is provided as a pydantic field
   # For example `genre: Literal['action']` becomes: {'const': 'action', 'title': 'Genre', 'type': 'string'}
-  const = schema.get('const'
+  const = schema.get('const')
   if const is not None:
     if schema_type == 'STRING':
       schema['enum'] = [const]
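The spelling-normalization loop added above folds snake_case keys that a hand-written dict schema might use into the camelCase spellings that later steps expect. The loop in isolation:

```python
schema = {'type': 'object', 'property_ordering': ['name', 'age'], 'any_of': None}
for from_name, to_name in [
    ('additional_properties', 'additionalProperties'),
    ('any_of', 'anyOf'),
    ('prefix_items', 'prefixItems'),
    ('property_ordering', 'propertyOrdering'),
]:
  if (value := schema.pop(from_name, None)) is not None:
    schema[to_name] = value

# A None-valued snake_case key is popped and dropped rather than renamed.
print(schema)  # {'type': 'object', 'propertyOrdering': ['name', 'age']}
```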
@@ -645,52 +825,49 @@ def process_schema(
       raise ValueError('Literal values must be strings.')

   if schema_type == 'OBJECT':
-    properties
-
-
-
-
-
-
-
-
-
-
-
-
-
-        and all(
-            ordering_key not in schema
-            for ordering_key in ['property_ordering', 'propertyOrdering']
-        )
-    ):
-      property_names = list(properties.keys())
-      schema['property_ordering'] = property_names
+    if (properties := schema.get('properties')) is not None:
+      for name, sub_schema in list(properties.items()):
+        properties[name] = _recurse(sub_schema)
+      if (
+          len(properties.items()) > 1
+          and order_properties
+          and 'propertyOrdering' not in schema
+      ):
+        schema['property_ordering'] = list(properties.keys())
+    if (additional := schema.get('additionalProperties')) is not None:
+      # It is legal to set 'additionalProperties' to a bool:
+      # https://json-schema.org/understanding-json-schema/reference/object#additionalproperties
+      if isinstance(additional, dict):
+        schema['additionalProperties'] = _recurse(additional)
   elif schema_type == 'ARRAY':
-
-
-
-
-    if ref_key is None:
-      process_schema(sub_schema, client, defs)
-    else:
-      ref = defs[ref_key.split('defs/')[-1]]
-      process_schema(ref, client, defs)
-      schema['items'] = ref
+    if (items := schema.get('items')) is not None:
+      schema['items'] = _recurse(items)
+    if (prefixes := schema.get('prefixItems')) is not None:
+      schema['prefixItems'] = [_recurse(prefix) for prefix in prefixes]


 def _process_enum(
-    enum: EnumMeta, client: _api_client.BaseApiClient
+    enum: EnumMeta, client: Optional[_api_client.BaseApiClient]
 ) -> types.Schema:
+  is_integer_enum = False
+
   for member in enum:  # type: ignore
-    if
+    if isinstance(member.value, int):
+      is_integer_enum = True
+    elif not isinstance(member.value, str):
       raise TypeError(
-          f'Enum member {member.name} value must be a string, got'
+          f'Enum member {member.name} value must be a string or integer, got'
          f' {type(member.value)}'
      )

+  enum_to_process = enum
+  if is_integer_enum:
+    str_members = [str(member.value) for member in enum]  # type: ignore
+    str_enum = Enum(enum.__name__, str_members, type=str)  # type: ignore
+    enum_to_process = str_enum
+
   class Placeholder(pydantic.BaseModel):
-    placeholder:
+    placeholder: enum_to_process  # type: ignore[valid-type]

   enum_schema = Placeholder.model_json_schema()
   process_schema(enum_schema, client)
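`_process_enum` now tolerates integer-valued enums by widening them into a string enum before building the placeholder model, since the generated schema needs string members. The widening step on its own:

```python
from enum import Enum


class Priority(Enum):
  LOW = 1
  HIGH = 2


# Mirrors the conversion above: stringified values become the member names
# of a str-mixin enum created through the functional API.
str_members = [str(member.value) for member in Priority]
StrPriority = Enum(Priority.__name__, str_members, type=str)
print([member.name for member in StrPriority])  # ['1', '2']
```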
@@ -698,7 +875,9 @@ def _process_enum(
   return types.Schema.model_validate(enum_schema)


-def _is_type_dict_str_any(
+def _is_type_dict_str_any(
+    origin: Union[types.SchemaUnionDict, Any],
+) -> TypeGuard[_common.StringDict]:
   """Verifies the schema is of type dict[str, Any] for mypy type checking."""
   return isinstance(origin, dict) and all(
       isinstance(key, str) for key in origin
@@ -706,21 +885,23 @@ def _is_type_dict_str_any(origin: Union[types.SchemaUnionDict, Any]) -> TypeGuar


 def t_schema(
-    client: _api_client.BaseApiClient,
+    client: Optional[_api_client.BaseApiClient],
+    origin: Union[types.SchemaUnionDict, Any],
 ) -> Optional[types.Schema]:
   if not origin:
     return None
   if isinstance(origin, dict) and _is_type_dict_str_any(origin):
-    process_schema(origin, client
+    process_schema(origin, client)
     return types.Schema.model_validate(origin)
   if isinstance(origin, EnumMeta):
     return _process_enum(origin, client)
-  if
-    if dict(origin) == dict(types.Schema()):
-      # response_schema value was coerced to an empty Schema instance because
-
-
-
+  if _is_duck_type_of(origin, types.Schema):
+    if dict(origin) == dict(types.Schema()):  # type: ignore [arg-type]
+      # response_schema value was coerced to an empty Schema instance because
+      # it did not adhere to the Schema field annotation
+      _raise_for_unsupported_schema_type(origin)
+    schema = origin.model_dump(exclude_unset=True)  # type: ignore[union-attr]
+    process_schema(schema, client)
     return types.Schema.model_validate(schema)

   if (
@@ -752,40 +933,43 @@ def t_schema(


 def t_speech_config(
-    _: _api_client.BaseApiClient,
     origin: Union[types.SpeechConfigUnionDict, Any],
 ) -> Optional[types.SpeechConfig]:
   if not origin:
     return None
-  if
-    return origin
+  if _is_duck_type_of(origin, types.SpeechConfig):
+    return origin  # type: ignore[return-value]
   if isinstance(origin, str):
     return types.SpeechConfig(
         voice_config=types.VoiceConfig(
             prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=origin)
         )
     )
-  if (
-
-
-      and origin['voice_config'] is not None
-      and 'prebuilt_voice_config' in origin['voice_config']
-      and origin['voice_config']['prebuilt_voice_config'] is not None
-      and 'voice_name' in origin['voice_config']['prebuilt_voice_config']
-  ):
-    return types.SpeechConfig(
-        voice_config=types.VoiceConfig(
-            prebuilt_voice_config=types.PrebuiltVoiceConfig(
-                voice_name=origin['voice_config']['prebuilt_voice_config'].get(
-                    'voice_name'
-                )
-            )
-        )
-    )
+  if isinstance(origin, dict):
+    return types.SpeechConfig.model_validate(origin)
+
   raise ValueError(f'Unsupported speechConfig type: {type(origin)}')


-def
+def t_live_speech_config(
+    origin: types.SpeechConfigOrDict,
+) -> Optional[types.SpeechConfig]:
+  if _is_duck_type_of(origin, types.SpeechConfig):
+    speech_config = origin
+  if isinstance(origin, dict):
+    speech_config = types.SpeechConfig.model_validate(origin)
+
+  if speech_config.multi_speaker_voice_config is not None:  # type: ignore[union-attr]
+    raise ValueError(
+        'multi_speaker_voice_config is not supported in the live API.'
+    )
+
+  return speech_config  # type: ignore[return-value]
+
+
+def t_tool(
+    client: _api_client.BaseApiClient, origin: Any
+) -> Optional[Union[types.Tool, Any]]:
   if not origin:
     return None
   if inspect.isfunction(origin) or inspect.ismethod(origin):
@@ -796,11 +980,14 @@ def t_tool(client: _api_client.BaseApiClient, origin) -> Optional[types.Tool]:
             )
         ]
     )
+  elif McpTool is not None and _is_duck_type_of(origin, McpTool):
+    return mcp_to_gemini_tool(origin)
+  elif isinstance(origin, dict):
+    return types.Tool.model_validate(origin)
   else:
     return origin


-# Only support functions now.
 def t_tools(
     client: _api_client.BaseApiClient, origin: list[Any]
 ) -> list[types.Tool]:
@@ -826,46 +1013,136 @@ def t_tools(
   return tools


-def t_cached_content_name(client: _api_client.BaseApiClient, name: str):
+def t_cached_content_name(client: _api_client.BaseApiClient, name: str) -> str:
   return _resource_name(client, name, collection_identifier='cachedContents')


-def t_batch_job_source(
-
-
-
-
-)
-
-
-
-        bigquery_uri=src,
+def t_batch_job_source(
+    client: _api_client.BaseApiClient,
+    src: types.BatchJobSourceUnionDict,
+) -> types.BatchJobSource:
+  if isinstance(src, dict):
+    src = types.BatchJobSource(**src)
+  if _is_duck_type_of(src, types.BatchJobSource):
+    vertex_sources = sum(
+        [src.gcs_uri is not None, src.bigquery_uri is not None]  # type: ignore[union-attr]
     )
-
-
+    mldev_sources = sum([
+        src.inlined_requests is not None,  # type: ignore[union-attr]
+        src.file_name is not None,  # type: ignore[union-attr]
+    ])
+    if client.vertexai:
+      if mldev_sources or vertex_sources != 1:
+        raise ValueError(
+            'Exactly one of `gcs_uri` or `bigquery_uri` must be set, other '
+            'sources are not supported in Vertex AI.'
+        )
+    else:
+      if vertex_sources or mldev_sources != 1:
+        raise ValueError(
+            'Exactly one of `inlined_requests`, `file_name`, '
+            '`inlined_embed_content_requests`, or `embed_content_file_name` '
+            'must be set, other sources are not supported in Gemini API.'
+        )
+    return src  # type: ignore[return-value]
+
+  elif isinstance(src, list):
+    return types.BatchJobSource(inlined_requests=src)
+  elif isinstance(src, str):
+    if src.startswith('gs://'):
+      return types.BatchJobSource(
+          format='jsonl',
+          gcs_uri=[src],
+      )
+    elif src.startswith('bq://'):
+      return types.BatchJobSource(
+          format='bigquery',
+          bigquery_uri=src,
+      )
+    elif src.startswith('files/'):
+      return types.BatchJobSource(
+          file_name=src,
+      )

+  raise ValueError(f'Unsupported source: {src}')

-
-
-
-
-
-
-
-
-
-
-
+
+def t_embedding_batch_job_source(
+    client: _api_client.BaseApiClient,
+    src: types.EmbeddingsBatchJobSourceOrDict,
+) -> types.EmbeddingsBatchJobSource:
+  if isinstance(src, dict):
+    src = types.EmbeddingsBatchJobSource(**src)
+
+  if _is_duck_type_of(src, types.EmbeddingsBatchJobSource):
+    mldev_sources = sum([
+        src.inlined_requests is not None,
+        src.file_name is not None,
+    ])
+    if mldev_sources != 1:
+      raise ValueError(
+          'Exactly one of `inlined_requests`, `file_name`, '
+          '`inlined_embed_content_requests`, or `embed_content_file_name` '
+          'must be set, other sources are not supported in Gemini API.'
+      )
+    return src
+  else:
+    raise ValueError(f'Unsupported source type: {type(src)}')
+
+
+def t_batch_job_destination(
+    dest: Union[str, types.BatchJobDestinationOrDict],
+) -> types.BatchJobDestination:
+  if isinstance(dest, dict):
+    dest = types.BatchJobDestination(**dest)
+    return dest
+  elif isinstance(dest, str):
+    if dest.startswith('gs://'):
+      return types.BatchJobDestination(
+          format='jsonl',
+          gcs_uri=dest,
+      )
+    elif dest.startswith('bq://'):
+      return types.BatchJobDestination(
+          format='bigquery',
+          bigquery_uri=dest,
+      )
+    else:
+      raise ValueError(f'Unsupported destination: {dest}')
+  elif _is_duck_type_of(dest, types.BatchJobDestination):
+    return dest
   else:
     raise ValueError(f'Unsupported destination: {dest}')


-def
+def t_recv_batch_job_destination(dest: dict[str, Any]) -> dict[str, Any]:
+  # Rename inlinedResponses if it looks like an embedding response.
+  inline_responses = dest.get('inlinedResponses', {}).get(
+      'inlinedResponses', []
+  )
+  if not inline_responses:
+    return dest
+  for response in inline_responses:
+    inner_response = response.get('response', {})
+    if not inner_response:
+      continue
+    if 'embedding' in inner_response:
+      dest['inlinedEmbedContentResponses'] = dest.pop('inlinedResponses')
+      break
+  return dest
+
+
+def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
   if not client.vertexai:
-
+    mldev_pattern = r'batches/[^/]+$'
+    if re.match(mldev_pattern, name):
+      return name.split('/')[-1]
+    else:
+      raise ValueError(f'Invalid batch job name: {name}.')
+
+    vertex_pattern = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'

-
-  if re.match(pattern, name):
+  if re.match(vertex_pattern, name):
     return name.split('/')[-1]
   elif name.isdigit():
     return name
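`t_batch_job_source` now also accepts a bare string and dispatches on its prefix. A self-contained sketch of the dispatch rules, returning plain dicts instead of `types.BatchJobSource`:

```python
def classify_batch_source(src: str) -> dict:
  # Same prefix rules as above: Cloud Storage, BigQuery, or the Files API.
  if src.startswith('gs://'):
    return {'format': 'jsonl', 'gcs_uri': [src]}
  if src.startswith('bq://'):
    return {'format': 'bigquery', 'bigquery_uri': src}
  if src.startswith('files/'):
    return {'file_name': src}
  raise ValueError(f'Unsupported source: {src}')


print(classify_batch_source('gs://bucket/requests.jsonl'))
# {'format': 'jsonl', 'gcs_uri': ['gs://bucket/requests.jsonl']}
```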
@@ -873,22 +1150,43 @@ def t_batch_job_name(client: _api_client.BaseApiClient, name: str):
     raise ValueError(f'Invalid batch job name: {name}.')


+def t_job_state(state: str) -> str:
+  if state == 'BATCH_STATE_UNSPECIFIED':
+    return 'JOB_STATE_UNSPECIFIED'
+  elif state == 'BATCH_STATE_PENDING':
+    return 'JOB_STATE_PENDING'
+  elif state == 'BATCH_STATE_RUNNING':
+    return 'JOB_STATE_RUNNING'
+  elif state == 'BATCH_STATE_SUCCEEDED':
+    return 'JOB_STATE_SUCCEEDED'
+  elif state == 'BATCH_STATE_FAILED':
+    return 'JOB_STATE_FAILED'
+  elif state == 'BATCH_STATE_CANCELLED':
+    return 'JOB_STATE_CANCELLED'
+  elif state == 'BATCH_STATE_EXPIRED':
+    return 'JOB_STATE_EXPIRED'
+  else:
+    return state
+
+
 LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
 LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
 LRO_POLLING_TIMEOUT_SECONDS = 900.0
 LRO_POLLING_MULTIPLIER = 1.5


-def t_resolve_operation(
+def t_resolve_operation(
+    api_client: _api_client.BaseApiClient, struct: _common.StringDict
+) -> Any:
   if (name := struct.get('name')) and '/operations/' in name:
-    operation:
+    operation: _common.StringDict = struct
     total_seconds = 0.0
     delay_seconds = LRO_POLLING_INITIAL_DELAY_SECONDS
     while operation.get('done') != True:
       if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
         raise RuntimeError(f'Operation {name} timed out.\n{operation}')
       # TODO(b/374433890): Replace with LRO module once it's available.
-      operation = api_client.request(
+      operation = api_client.request(  # type: ignore[assignment]
          http_method='GET', path=name, request_dict={}
      )
      time.sleep(delay_seconds)
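The new `t_job_state` maps Gemini API batch states onto Vertex-style job states, passing unknown values through unchanged. The same mapping as a table-driven sketch (a design alternative, not what the SDK ships):

```python
_BATCH_TO_JOB_STATE = {
    f'BATCH_STATE_{suffix}': f'JOB_STATE_{suffix}'
    for suffix in (
        'UNSPECIFIED', 'PENDING', 'RUNNING', 'SUCCEEDED',
        'FAILED', 'CANCELLED', 'EXPIRED',
    )
}


def t_job_state(state: str) -> str:
  # Unknown states fall through unchanged, as in the if/elif chain above.
  return _BATCH_TO_JOB_STATE.get(state, state)


print(t_job_state('BATCH_STATE_RUNNING'))  # JOB_STATE_RUNNING
print(t_job_state('JOB_STATE_PENDING'))    # JOB_STATE_PENDING (passthrough)
```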
@@ -908,17 +1206,16 @@ def t_resolve_operation(api_client: _api_client.BaseApiClient, struct: dict):


 def t_file_name(
-    api_client: _api_client.BaseApiClient,
     name: Optional[Union[str, types.File, types.Video, types.GeneratedVideo]],
-):
+) -> str:
   # Remove the files/ prefix since it's added to the url path.
-  if
-    name = name.name
-  elif
-    name = name.uri
-  elif
-    if name.video is not None:
-      name = name.video.uri
+  if _is_duck_type_of(name, types.File):
+    name = name.name  # type: ignore[union-attr]
+  elif _is_duck_type_of(name, types.Video):
+    name = name.uri  # type: ignore[union-attr]
+  elif _is_duck_type_of(name, types.GeneratedVideo):
+    if name.video is not None:  # type: ignore[union-attr]
+      name = name.video.uri  # type: ignore[union-attr]
     else:
       name = None

@@ -942,9 +1239,7 @@ def t_file_name(
   return name


-def t_tuning_job_status(
-    api_client: _api_client.BaseApiClient, status: str
-) -> Union[types.JobState, str]:
+def t_tuning_job_status(status: str) -> Union[types.JobState, str]:
   if status == 'STATE_UNSPECIFIED':
     return types.JobState.JOB_STATE_UNSPECIFIED
   elif status == 'CREATING':
@@ -960,11 +1255,116 @@ def t_tuning_job_status(
     return status


-
-
-
-
-
-
-
-
+def t_content_strict(content: types.ContentOrDict) -> types.Content:
+  if isinstance(content, dict):
+    return types.Content.model_validate(content)
+  elif _is_duck_type_of(content, types.Content):
+    return content
+  else:
+    raise ValueError(
+        f'Could not convert input (type "{type(content)}") to `types.Content`'
+    )
+
+
+def t_contents_strict(
+    contents: Union[Sequence[types.ContentOrDict], types.ContentOrDict],
+) -> list[types.Content]:
+  if isinstance(contents, Sequence):
+    return [t_content_strict(content) for content in contents]
+  else:
+    return [t_content_strict(contents)]
+
+
+def t_client_content(
+    turns: Optional[
+        Union[Sequence[types.ContentOrDict], types.ContentOrDict]
+    ] = None,
+    turn_complete: bool = True,
+) -> types.LiveClientContent:
+  if turns is None:
+    return types.LiveClientContent(turn_complete=turn_complete)
+
+  try:
+    return types.LiveClientContent(
+        turns=t_contents_strict(contents=turns),
+        turn_complete=turn_complete,
+    )
+  except Exception as e:
+    raise ValueError(
+        f'Could not convert input (type "{type(turns)}") to '
+        '`types.LiveClientContent`'
+    ) from e
+
+
+def t_tool_response(
+    input: Union[
+        types.FunctionResponseOrDict,
+        Sequence[types.FunctionResponseOrDict],
+    ],
+) -> types.LiveClientToolResponse:
+  if not input:
+    raise ValueError(f'A tool response is required, got: \n{input}')
+
+  try:
+    return types.LiveClientToolResponse(
+        function_responses=t_function_responses(function_responses=input)
+    )
+  except Exception as e:
+    raise ValueError(
+        f'Could not convert input (type "{type(input)}") to '
+        '`types.LiveClientToolResponse`'
+    ) from e
+
+
+def t_metrics(
+    metrics: list[types.MetricSubclass]
+) -> list[dict[str, Any]]:
+  """Prepares the metric payload for the evaluation request.
+
+  Args:
+    request_dict: The dictionary containing the request details.
+    resolved_metrics: A list of resolved metric objects.
+
+  Returns:
+    The updated request dictionary with the prepared metric payload.
+  """
+  metrics_payload = []
+
+  for metric in metrics:
+    metric_payload_item: dict[str, Any] = {}
+    metric_payload_item['aggregation_metrics'] = [
+        'AVERAGE',
+        'STANDARD_DEVIATION',
+    ]
+
+    metric_name = getv(metric, ['name']).lower()
+
+    if metric_name == 'exact_match':
+      metric_payload_item['exact_match_spec'] = {}
+    elif metric_name == 'bleu':
+      metric_payload_item['bleu_spec'] = {}
+    elif metric_name.startswith('rouge'):
+      rouge_type = metric_name.replace("_", "")
+      metric_payload_item['rouge_spec'] = {'rouge_type': rouge_type}
+
+    elif hasattr(metric, 'prompt_template') and metric.prompt_template:
+      pointwise_spec = {'metric_prompt_template': metric.prompt_template}
+      system_instruction = getv(
+          metric, ['judge_model_system_instruction']
+      )
+      if system_instruction:
+        pointwise_spec['system_instruction'] = system_instruction
+      return_raw_output = getv(
+          metric, ['return_raw_output']
+      )
+      if return_raw_output:
+        pointwise_spec['custom_output_format_config'] = {  # type: ignore[assignment]
+            'return_raw_output': return_raw_output
+        }
+      metric_payload_item['pointwise_metric_spec'] = pointwise_spec
+    else:
+      raise ValueError(
+          'Unsupported metric type or invalid metric name:' f' {metric_name}'
+      )
+    metrics_payload.append(metric_payload_item)
+  return metrics_payload