google-genai 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +93 -67
- google/genai/_automatic_function_calling_util.py +4 -14
- google/genai/_transformers.py +61 -37
- google/genai/batches.py +4 -0
- google/genai/caches.py +20 -26
- google/genai/client.py +3 -2
- google/genai/errors.py +11 -19
- google/genai/files.py +7 -7
- google/genai/live.py +276 -93
- google/genai/models.py +131 -66
- google/genai/operations.py +30 -2
- google/genai/pagers.py +3 -5
- google/genai/tunings.py +31 -21
- google/genai/types.py +88 -33
- google/genai/version.py +1 -1
- {google_genai-1.5.0.dist-info → google_genai-1.6.0.dist-info}/METADATA +194 -25
- google_genai-1.6.0.dist-info/RECORD +27 -0
- {google_genai-1.5.0.dist-info → google_genai-1.6.0.dist-info}/WHEEL +1 -1
- google_genai-1.5.0.dist-info/RECORD +0 -27
- {google_genai-1.5.0.dist-info → google_genai-1.6.0.dist-info}/LICENSE +0 -0
- {google_genai-1.5.0.dist-info → google_genai-1.6.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -29,16 +29,14 @@ import json
 import logging
 import os
 import sys
-from typing import Any, AsyncIterator, Optional, Tuple,
+from typing import Any, AsyncIterator, Optional, Tuple, Union
 from urllib.parse import urlparse, urlunparse
 import google.auth
 import google.auth.credentials
 from google.auth.credentials import Credentials
-from google.auth.transport.requests import AuthorizedSession
 from google.auth.transport.requests import Request
 import httpx
-from pydantic import BaseModel,
-import requests
+from pydantic import BaseModel, Field, ValidationError
 from . import _common
 from . import errors
 from . import version
@@ -88,7 +86,8 @@ def _patch_http_options(
       copy_option[patch_key].update(patch_value)
     elif patch_value is not None:  # Accept empty values.
       copy_option[patch_key] = patch_value
-
+  if copy_option['headers']:
+    _append_library_version_headers(copy_option['headers'])
   return copy_option


@@ -103,7 +102,7 @@ def _join_url_path(base_url: str, path: str) -> str:
   return urlunparse(parsed_base._replace(path=base_path + '/' + path))


-def _load_auth(*, project: Union[str, None]) ->
+def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
   """Loads google auth credentials and project id."""
   credentials, loaded_project_id = google.auth.default(
       scopes=['https://www.googleapis.com/auth/cloud-platform'],
@@ -273,7 +272,9 @@ class BaseApiClient:
     validated_http_options: dict[str, Any]
     if isinstance(http_options, dict):
      try:
-        validated_http_options = HttpOptions.model_validate(
+        validated_http_options = HttpOptions.model_validate(
+            http_options
+        ).model_dump()
      except ValidationError as e:
        raise ValueError(f'Invalid http_options: {e}')
    elif isinstance(http_options, HttpOptions):
@@ -359,7 +360,9 @@ class BaseApiClient:
      self._http_options['headers']['x-goog-api-key'] = self.api_key
    # Update the http options with the user provided http options.
    if http_options:
-      self._http_options = _patch_http_options(
+      self._http_options = _patch_http_options(
+          self._http_options, validated_http_options
+      )
    else:
      _append_library_version_headers(self._http_options['headers'])

@@ -367,8 +370,27 @@ class BaseApiClient:
    url_parts = urlparse(self._http_options['base_url'])
    return url_parts._replace(scheme='wss').geturl()

-
+  def _access_token(self) -> str:
    """Retrieves the access token for the credentials."""
+    if not self._credentials:
+      self._credentials, project = _load_auth(project=self.project)
+      if not self.project:
+        self.project = project
+
+    if self._credentials:
+      if (
+          self._credentials.expired or not self._credentials.token
+      ):
+        # Only refresh when it needs to. Default expiration is 3600 seconds.
+        _refresh_auth(self._credentials)
+      if not self._credentials.token:
+        raise RuntimeError('Could not resolve API token from the environment')
+      return self._credentials.token
+    else:
+      raise RuntimeError('Could not resolve API token from the environment')
+
+  async def _async_access_token(self) -> str:
+    """Retrieves the access token for the credentials asynchronously."""
    if not self._credentials:
      async with self._auth_lock:
        # This ensures that only one coroutine can execute the auth logic at a
@@ -437,8 +459,8 @@ class BaseApiClient:
    ):
      path = f'projects/{self.project}/locations/{self.location}/' + path
    url = _join_url_path(
-        patched_http_options
-        patched_http_options
+        patched_http_options.get('base_url', ''),
+        patched_http_options.get('api_version', '') + '/' + path,
    )

    timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
@@ -464,59 +486,56 @@ class BaseApiClient:
      http_request: HttpRequest,
      stream: bool = False,
  ) -> HttpResponse:
+    data: Optional[Union[str, bytes]] = None
    if self.vertexai and not self.api_key:
-
-
-
+      http_request.headers['Authorization'] = (
+          f'Bearer {self._access_token()}'
+      )
+      if self._credentials and self._credentials.quota_project_id:
        http_request.headers['x-goog-user-project'] = (
            self._credentials.quota_project_id
        )
-
-
-
-
-      http_request.
+      data = json.dumps(http_request.data)
+    else:
+      if http_request.data:
+        if not isinstance(http_request.data, bytes):
+          data = json.dumps(http_request.data)
+        else:
+          data = http_request.data
+
+    if stream:
+      client = httpx.Client()
+      httpx_request = client.build_request(
+          method=http_request.method,
+          url=http_request.url,
+          content=data,
          headers=http_request.headers,
-          data=json.dumps(http_request.data) if http_request.data else None,
          timeout=http_request.timeout,
      )
+      response = client.send(httpx_request, stream=stream)
      errors.APIError.raise_for_response(response)
      return HttpResponse(
          response.headers, response if stream else [response.text]
      )
    else:
-
-
-
-
-
-
-
-
-
-
-
-
-        data = http_request.data
-
-      http_session = requests.Session()
-      response = http_session.request(
-          method=http_request.method,
-          url=http_request.url,
-          headers=http_request.headers,
-          data=data,
-          timeout=http_request.timeout,
-          stream=stream,
-      )
-      errors.APIError.raise_for_response(response)
-      return HttpResponse(
-          response.headers, response if stream else [response.text]
-      )
+      with httpx.Client() as client:
+        response = client.request(
+            method=http_request.method,
+            url=http_request.url,
+            headers=http_request.headers,
+            content=data,
+            timeout=http_request.timeout,
+        )
+        errors.APIError.raise_for_response(response)
+        return HttpResponse(
+            response.headers, response if stream else [response.text]
+        )

  async def _async_request(
      self, http_request: HttpRequest, stream: bool = False
  ):
-
+    data: Optional[Union[str, bytes]] = None
+    if self.vertexai and not self.api_key:
      http_request.headers['Authorization'] = (
          f'Bearer {await self._async_access_token()}'
      )
@@ -524,12 +543,20 @@ class BaseApiClient:
        http_request.headers['x-goog-user-project'] = (
            self._credentials.quota_project_id
        )
+      data = json.dumps(http_request.data)
+    else:
+      if http_request.data:
+        if not isinstance(http_request.data, bytes):
+          data = json.dumps(http_request.data)
+        else:
+          data = http_request.data
+
    if stream:
      aclient = httpx.AsyncClient()
      httpx_request = aclient.build_request(
          method=http_request.method,
          url=http_request.url,
-          content=
+          content=data,
          headers=http_request.headers,
          timeout=http_request.timeout,
      )
@@ -547,7 +574,7 @@ class BaseApiClient:
            method=http_request.method,
            url=http_request.url,
            headers=http_request.headers,
-            content=
+            content=data,
            timeout=http_request.timeout,
        )
        errors.APIError.raise_for_response(response)
@@ -633,7 +660,7 @@ class BaseApiClient:

  def upload_file(
      self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
-  ) -> str:
+  ) -> dict[str, str]:
    """Transfers a file to the given URL.

    Args:
@@ -655,7 +682,7 @@ class BaseApiClient:

  def _upload_fd(
      self, file: io.IOBase, upload_url: str, upload_size: int
-  ) -> str:
+  ) -> dict[str, str]:
    """Transfers a file to the given URL.

    Args:
@@ -689,7 +716,7 @@ class BaseApiClient:
          data=file_chunk,
      )

-      response = self.
+      response = self._request(request, stream=False)
      offset += chunk_size
      if response.headers['X-Goog-Upload-Status'] != 'active':
        break  # upload is complete or it has been interrupted.
@@ -732,25 +759,24 @@ class BaseApiClient:
      else:
        data = http_request.data

-
-
-
-
-
-
-
-
-      )
+    with httpx.Client(follow_redirects=True) as client:
+      response = client.request(
+          method=http_request.method,
+          url=http_request.url,
+          headers=http_request.headers,
+          content=data,
+          timeout=http_request.timeout,
+      )

-
-
+      errors.APIError.raise_for_response(response)
+      return HttpResponse(response.headers, byte_stream=[response.read()])

  async def async_upload_file(
      self,
      file_path: Union[str, io.IOBase],
      upload_url: str,
      upload_size: int,
-  ) -> str:
+  ) -> dict[str, str]:
    """Transfers a file asynchronously to the given URL.

    Args:
@@ -776,7 +802,7 @@ class BaseApiClient:
      file: Union[io.IOBase, anyio.AsyncFile],
      upload_url: str,
      upload_size: int,
-  ) -> str:
+  ) -> dict[str, str]:
    """Transfers a file asynchronously to the given URL.

    Args:
@@ -842,7 +868,7 @@ class BaseApiClient:
        'get', path=path, request_dict={}, http_options=http_options
    )

-    data: Optional[Union[str, bytes]]
+    data: Optional[Union[str, bytes]] = None
    if http_request.data:
      if not isinstance(http_request.data, bytes):
        data = json.dumps(http_request.data)
google/genai/_automatic_function_calling_util.py
CHANGED
@@ -17,7 +17,7 @@ import inspect
 import sys
 import types as builtin_types
 import typing
-from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union
+from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union  # type: ignore[attr-defined]

 import pydantic

@@ -41,19 +41,11 @@ _py_builtin_type_to_schema_type = {


 def _is_builtin_primitive_or_compound(
-    annotation: inspect.Parameter.annotation,
+    annotation: inspect.Parameter.annotation,  # type: ignore[valid-type]
 ) -> bool:
   return annotation in _py_builtin_type_to_schema_type.keys()


-def _raise_for_any_of_if_mldev(schema: types.Schema):
-  if schema.any_of:
-    raise ValueError(
-        'AnyOf is not supported in function declaration schema for'
-        ' the Gemini API.'
-    )
-
-
 def _raise_for_default_if_mldev(schema: types.Schema):
   if schema.default is not None:
     raise ValueError(
@@ -64,12 +56,11 @@ def _raise_for_default_if_mldev(schema: types.Schema):

 def _raise_if_schema_unsupported(api_option: Literal['VERTEX_AI', 'GEMINI_API'], schema: types.Schema):
   if api_option == 'GEMINI_API':
-    _raise_for_any_of_if_mldev(schema)
     _raise_for_default_if_mldev(schema)


 def _is_default_value_compatible(
-    default_value: Any, annotation: inspect.Parameter.annotation
+    default_value: Any, annotation: inspect.Parameter.annotation  # type: ignore[valid-type]
 ) -> bool:
   # None type is expected to be handled external to this function
   if _is_builtin_primitive_or_compound(annotation):
@@ -292,8 +283,7 @@ def _parse_schema_from_parameter(
         ),
         func_name,
     )
-
-    schema.required = _get_required_fields(schema)
+    schema.required = _get_required_fields(schema)
     _raise_if_schema_unsupported(api_option, schema)
     return schema
   raise ValueError(
google/genai/_transformers.py
CHANGED
@@ -26,9 +26,7 @@ import sys
 import time
 import types as builtin_types
 import typing
-from typing import Any, GenericAlias, Optional, Union
-
-import types as builtin_types
+from typing import Any, GenericAlias, Optional, Union  # type: ignore[attr-defined]

 if typing.TYPE_CHECKING:
   import PIL.Image
@@ -43,10 +41,11 @@ logger = logging.getLogger('google_genai._transformers')
 if sys.version_info >= (3, 10):
   VersionedUnionType = builtin_types.UnionType
   _UNION_TYPES = (typing.Union, builtin_types.UnionType)
+  from typing import TypeGuard
 else:
   VersionedUnionType = typing._UnionGenericAlias
   _UNION_TYPES = (typing.Union,)
-
+  from typing_extensions import TypeGuard

 def _resource_name(
     client: _api_client.BaseApiClient,
@@ -165,7 +164,9 @@ def t_model(client: _api_client.BaseApiClient, model: str):
    return f'models/{model}'


-def t_models_url(
+def t_models_url(
+    api_client: _api_client.BaseApiClient, base_models: bool
+) -> str:
  if api_client.vertexai:
    if base_models:
      return 'publishers/google/models'
@@ -179,7 +180,8 @@ def t_models_url(api_client: _api_client.BaseApiClient, base_models: bool) -> st


 def t_extract_models(
-    api_client: _api_client.BaseApiClient,
+    api_client: _api_client.BaseApiClient,
+    response: dict[str, list[types.ModelDict]],
 ) -> Optional[list[types.ModelDict]]:
   if not response:
     return []
@@ -240,9 +242,7 @@ def pil_to_blob(img) -> types.Blob:
   return types.Blob(mime_type=mime_type, data=data)


-def t_part(
-    part: Optional[types.PartUnionDict]
-) -> types.Part:
+def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
   try:
     import PIL.Image

@@ -268,7 +268,7 @@ def t_part(


 def t_parts(
-    parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict]],
+    parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict, list[types.Part]]],
 ) -> list[types.Part]:
   #
   if parts is None or (isinstance(parts, list) and not parts):
@@ -332,22 +332,35 @@ def t_content(
 def t_contents_for_embed(
     client: _api_client.BaseApiClient,
     contents: Union[list[types.Content], list[types.ContentDict], ContentType],
-):
-  if
-
-    return [t_content(client, content).parts[0].text for content in contents]
-  elif client.vertexai:
-    return [t_content(client, contents).parts[0].text]
-  elif isinstance(contents, list):
-    return [t_content(client, content) for content in contents]
+) -> Union[list[str], list[types.Content]]:
+  if isinstance(contents, list):
+    transformed_contents = [t_content(client, content) for content in contents]
   else:
-
+    transformed_contents = [t_content(client, contents)]
+
+  if client.vertexai:
+    text_parts = []
+    for content in transformed_contents:
+      if content is not None:
+        if isinstance(content, dict):
+          content = types.Content.model_validate(content)
+        if content.parts is not None:
+          for part in content.parts:
+            if part.text:
+              text_parts.append(part.text)
+            else:
+              logger.warning(
+                  f'Non-text part found, only returning text parts.'
+              )
+    return text_parts
+  else:
+    return transformed_contents


 def t_contents(
     client: _api_client.BaseApiClient,
     contents: Optional[
-        Union[types.ContentListUnion, types.ContentListUnionDict]
+        Union[types.ContentListUnion, types.ContentListUnionDict, types.Content]
     ],
 ) -> list[types.Content]:
   if contents is None or (isinstance(contents, list) and not contents):
@@ -365,7 +378,7 @@ def t_contents(
   result: list[types.Content] = []
   accumulated_parts: list[types.Part] = []

-  def _is_part(part: types.PartUnionDict) ->
+  def _is_part(part: Union[types.PartUnionDict, Any]) -> TypeGuard[types.PartUnionDict]:
     if (
         isinstance(part, str)
         or isinstance(part, types.File)
@@ -429,11 +442,11 @@ def t_contents(
    ):
      _append_accumulated_parts_as_content(result, accumulated_parts)
      if isinstance(content, list):
-        result.append(types.UserContent(parts=content))
+        result.append(types.UserContent(parts=content))  # type: ignore[arg-type]
      else:
        result.append(content)
-    elif (_is_part(content)):
-      _handle_current_part(result, accumulated_parts, content)
+    elif (_is_part(content)):
+      _handle_current_part(result, accumulated_parts, content)
    elif isinstance(content, dict):
      # PactDict is already handled in _is_part
      result.append(types.Content.model_validate(content))
@@ -499,7 +512,7 @@ def handle_null_fields(schema: dict[str, Any]):
      schema['anyOf'].remove({'type': 'null'})
      if len(schema['anyOf']) == 1:
        # If there is only one type left after removing null, remove the anyOf field.
-        for key,val in schema['anyOf'][0].items():
+        for key, val in schema['anyOf'][0].items():
          schema[key] = val
        del schema['anyOf']

@@ -574,7 +587,8 @@ def process_schema(

    if schema.get('default') is not None:
      raise ValueError(
-          'Default value is not supported in the response schema for the Gemini
+          'Default value is not supported in the response schema for the Gemini'
+          ' API.'
      )

  if schema.get('title') == 'PlaceholderLiteralEnum':
@@ -604,10 +618,6 @@ def process_schema(

  any_of = schema.get('anyOf', None)
  if any_of is not None:
-    if client and not client.vertexai:
-      raise ValueError(
-          'AnyOf is not supported in the response schema for the Gemini API.'
-      )
    for sub_schema in any_of:
      # $ref is present in any_of if the schema is a union of Pydantic classes
      ref_key = sub_schema.get('$ref', None)
@@ -670,7 +680,7 @@ def _process_enum(


 def _process_enum(
-    enum: EnumMeta, client:
+    enum: EnumMeta, client: _api_client.BaseApiClient
 ) -> types.Schema:
   for member in enum:  # type: ignore
     if not isinstance(member.value, str):
@@ -680,7 +690,7 @@ def _process_enum(
      )

  class Placeholder(pydantic.BaseModel):
-    placeholder: enum
+    placeholder: enum  # type: ignore[valid-type]

  enum_schema = Placeholder.model_json_schema()
  process_schema(enum_schema, client)
@@ -688,12 +698,19 @@ def _process_enum(
   return types.Schema.model_validate(enum_schema)


+def _is_type_dict_str_any(origin: Union[types.SchemaUnionDict, Any]) -> TypeGuard[dict[str, Any]]:
+  """Verifies the schema is of type dict[str, Any] for mypy type checking."""
+  return isinstance(origin, dict) and all(
+      isinstance(key, str) for key in origin
+  )
+
+
 def t_schema(
     client: _api_client.BaseApiClient, origin: Union[types.SchemaUnionDict, Any]
 ) -> Optional[types.Schema]:
   if not origin:
     return None
-  if isinstance(origin, dict):
+  if isinstance(origin, dict) and _is_type_dict_str_any(origin):
     process_schema(origin, client, order_properties=False)
     return types.Schema.model_validate(origin)
   if isinstance(origin, EnumMeta):
@@ -724,7 +741,7 @@ def t_schema(
  ):

    class Placeholder(pydantic.BaseModel):
-      placeholder: origin
+      placeholder: origin  # type: ignore[valid-type]

    schema = Placeholder.model_json_schema()
    process_schema(schema, client)
@@ -735,7 +752,8 @@ def t_schema(


 def t_speech_config(
-    _: _api_client.BaseApiClient,
+    _: _api_client.BaseApiClient,
+    origin: Union[types.SpeechConfigUnionDict, Any],
 ) -> Optional[types.SpeechConfig]:
   if not origin:
     return None
@@ -794,7 +812,10 @@ def t_tools(
      transformed_tool = t_tool(client, tool)
      # All functions should be merged into one tool.
      if transformed_tool is not None:
-        if
+        if (
+            transformed_tool.function_declarations
+            and function_tool.function_declarations is not None
+        ):
          function_tool.function_declarations += (
              transformed_tool.function_declarations
          )
@@ -896,7 +917,10 @@ def t_file_name(
  elif isinstance(name, types.Video):
    name = name.uri
  elif isinstance(name, types.GeneratedVideo):
-
+    if name.video is not None:
+      name = name.video.uri
+    else:
+      name = None

  if name is None:
    raise ValueError('File name is required.')
google/genai/batches.py
CHANGED
@@ -998,6 +998,8 @@ class Batches(_api_module.BaseModule):
      for batch_job in batch_jobs:
        print(f"Batch job: {batch_job.name}, state {batch_job.state}")
    """
+    if config is None:
+      config = types.ListBatchJobsConfig()
    return Pager(
        'batch_jobs',
        self._list,
@@ -1373,6 +1375,8 @@ class AsyncBatches(_api_module.BaseModule):
      await batch_jobs_pager.next_page()
      print(f"next page: {batch_jobs_pager.page}")
    """
+    if config is None:
+      config = types.ListBatchJobsConfig()
    return AsyncPager(
        'batch_jobs',
        self._list,