google-genai 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,16 +29,14 @@ import json
29
29
  import logging
30
30
  import os
31
31
  import sys
32
- from typing import Any, AsyncIterator, Optional, Tuple, TypedDict, Union
32
+ from typing import Any, AsyncIterator, Optional, Tuple, Union
33
33
  from urllib.parse import urlparse, urlunparse
34
34
  import google.auth
35
35
  import google.auth.credentials
36
36
  from google.auth.credentials import Credentials
37
- from google.auth.transport.requests import AuthorizedSession
38
37
  from google.auth.transport.requests import Request
39
38
  import httpx
40
- from pydantic import BaseModel, ConfigDict, Field, ValidationError
41
- import requests
39
+ from pydantic import BaseModel, Field, ValidationError
42
40
  from . import _common
43
41
  from . import errors
44
42
  from . import version
@@ -88,7 +86,8 @@ def _patch_http_options(
88
86
  copy_option[patch_key].update(patch_value)
89
87
  elif patch_value is not None: # Accept empty values.
90
88
  copy_option[patch_key] = patch_value
91
- _append_library_version_headers(copy_option['headers'])
89
+ if copy_option['headers']:
90
+ _append_library_version_headers(copy_option['headers'])
92
91
  return copy_option
93
92
 
94
93
 
@@ -103,7 +102,7 @@ def _join_url_path(base_url: str, path: str) -> str:
103
102
  return urlunparse(parsed_base._replace(path=base_path + '/' + path))
104
103
 
105
104
 
106
- def _load_auth(*, project: Union[str, None]) -> tuple[Credentials, str]:
105
+ def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
107
106
  """Loads google auth credentials and project id."""
108
107
  credentials, loaded_project_id = google.auth.default(
109
108
  scopes=['https://www.googleapis.com/auth/cloud-platform'],
@@ -273,7 +272,9 @@ class BaseApiClient:
273
272
  validated_http_options: dict[str, Any]
274
273
  if isinstance(http_options, dict):
275
274
  try:
276
- validated_http_options = HttpOptions.model_validate(http_options).model_dump()
275
+ validated_http_options = HttpOptions.model_validate(
276
+ http_options
277
+ ).model_dump()
277
278
  except ValidationError as e:
278
279
  raise ValueError(f'Invalid http_options: {e}')
279
280
  elif isinstance(http_options, HttpOptions):
@@ -359,7 +360,9 @@ class BaseApiClient:
359
360
  self._http_options['headers']['x-goog-api-key'] = self.api_key
360
361
  # Update the http options with the user provided http options.
361
362
  if http_options:
362
- self._http_options = _patch_http_options(self._http_options, validated_http_options)
363
+ self._http_options = _patch_http_options(
364
+ self._http_options, validated_http_options
365
+ )
363
366
  else:
364
367
  _append_library_version_headers(self._http_options['headers'])
365
368
 
@@ -367,8 +370,27 @@ class BaseApiClient:
367
370
  url_parts = urlparse(self._http_options['base_url'])
368
371
  return url_parts._replace(scheme='wss').geturl()
369
372
 
370
- async def _async_access_token(self) -> str:
373
+ def _access_token(self) -> str:
371
374
  """Retrieves the access token for the credentials."""
375
+ if not self._credentials:
376
+ self._credentials, project = _load_auth(project=self.project)
377
+ if not self.project:
378
+ self.project = project
379
+
380
+ if self._credentials:
381
+ if (
382
+ self._credentials.expired or not self._credentials.token
383
+ ):
384
+ # Only refresh when it needs to. Default expiration is 3600 seconds.
385
+ _refresh_auth(self._credentials)
386
+ if not self._credentials.token:
387
+ raise RuntimeError('Could not resolve API token from the environment')
388
+ return self._credentials.token
389
+ else:
390
+ raise RuntimeError('Could not resolve API token from the environment')
391
+
392
+ async def _async_access_token(self) -> str:
393
+ """Retrieves the access token for the credentials asynchronously."""
372
394
  if not self._credentials:
373
395
  async with self._auth_lock:
374
396
  # This ensures that only one coroutine can execute the auth logic at a
@@ -437,8 +459,8 @@ class BaseApiClient:
437
459
  ):
438
460
  path = f'projects/{self.project}/locations/{self.location}/' + path
439
461
  url = _join_url_path(
440
- patched_http_options['base_url'],
441
- patched_http_options['api_version'] + '/' + path,
462
+ patched_http_options.get('base_url', ''),
463
+ patched_http_options.get('api_version', '') + '/' + path,
442
464
  )
443
465
 
444
466
  timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
@@ -464,59 +486,56 @@ class BaseApiClient:
464
486
  http_request: HttpRequest,
465
487
  stream: bool = False,
466
488
  ) -> HttpResponse:
489
+ data: Optional[Union[str, bytes]] = None
467
490
  if self.vertexai and not self.api_key:
468
- if not self._credentials:
469
- self._credentials, _ = _load_auth(project=self.project)
470
- if self._credentials.quota_project_id:
491
+ http_request.headers['Authorization'] = (
492
+ f'Bearer {self._access_token()}'
493
+ )
494
+ if self._credentials and self._credentials.quota_project_id:
471
495
  http_request.headers['x-goog-user-project'] = (
472
496
  self._credentials.quota_project_id
473
497
  )
474
- authed_session = AuthorizedSession(self._credentials)
475
- authed_session.stream = stream
476
- response = authed_session.request(
477
- http_request.method.upper(),
478
- http_request.url,
498
+ data = json.dumps(http_request.data)
499
+ else:
500
+ if http_request.data:
501
+ if not isinstance(http_request.data, bytes):
502
+ data = json.dumps(http_request.data)
503
+ else:
504
+ data = http_request.data
505
+
506
+ if stream:
507
+ client = httpx.Client()
508
+ httpx_request = client.build_request(
509
+ method=http_request.method,
510
+ url=http_request.url,
511
+ content=data,
479
512
  headers=http_request.headers,
480
- data=json.dumps(http_request.data) if http_request.data else None,
481
513
  timeout=http_request.timeout,
482
514
  )
515
+ response = client.send(httpx_request, stream=stream)
483
516
  errors.APIError.raise_for_response(response)
484
517
  return HttpResponse(
485
518
  response.headers, response if stream else [response.text]
486
519
  )
487
520
  else:
488
- return self._request_unauthorized(http_request, stream)
489
-
490
- def _request_unauthorized(
491
- self,
492
- http_request: HttpRequest,
493
- stream: bool = False,
494
- ) -> HttpResponse:
495
- data: Optional[Union[str, bytes]] = None
496
- if http_request.data:
497
- if not isinstance(http_request.data, bytes):
498
- data = json.dumps(http_request.data)
499
- else:
500
- data = http_request.data
501
-
502
- http_session = requests.Session()
503
- response = http_session.request(
504
- method=http_request.method,
505
- url=http_request.url,
506
- headers=http_request.headers,
507
- data=data,
508
- timeout=http_request.timeout,
509
- stream=stream,
510
- )
511
- errors.APIError.raise_for_response(response)
512
- return HttpResponse(
513
- response.headers, response if stream else [response.text]
514
- )
521
+ with httpx.Client() as client:
522
+ response = client.request(
523
+ method=http_request.method,
524
+ url=http_request.url,
525
+ headers=http_request.headers,
526
+ content=data,
527
+ timeout=http_request.timeout,
528
+ )
529
+ errors.APIError.raise_for_response(response)
530
+ return HttpResponse(
531
+ response.headers, response if stream else [response.text]
532
+ )
515
533
 
516
534
  async def _async_request(
517
535
  self, http_request: HttpRequest, stream: bool = False
518
536
  ):
519
- if self.vertexai:
537
+ data: Optional[Union[str, bytes]] = None
538
+ if self.vertexai and not self.api_key:
520
539
  http_request.headers['Authorization'] = (
521
540
  f'Bearer {await self._async_access_token()}'
522
541
  )
@@ -524,12 +543,20 @@ class BaseApiClient:
524
543
  http_request.headers['x-goog-user-project'] = (
525
544
  self._credentials.quota_project_id
526
545
  )
546
+ data = json.dumps(http_request.data)
547
+ else:
548
+ if http_request.data:
549
+ if not isinstance(http_request.data, bytes):
550
+ data = json.dumps(http_request.data)
551
+ else:
552
+ data = http_request.data
553
+
527
554
  if stream:
528
555
  aclient = httpx.AsyncClient()
529
556
  httpx_request = aclient.build_request(
530
557
  method=http_request.method,
531
558
  url=http_request.url,
532
- content=json.dumps(http_request.data),
559
+ content=data,
533
560
  headers=http_request.headers,
534
561
  timeout=http_request.timeout,
535
562
  )
@@ -547,7 +574,7 @@ class BaseApiClient:
547
574
  method=http_request.method,
548
575
  url=http_request.url,
549
576
  headers=http_request.headers,
550
- content=json.dumps(http_request.data) if http_request.data else None,
577
+ content=data,
551
578
  timeout=http_request.timeout,
552
579
  )
553
580
  errors.APIError.raise_for_response(response)
@@ -633,7 +660,7 @@ class BaseApiClient:
633
660
 
634
661
  def upload_file(
635
662
  self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
636
- ) -> str:
663
+ ) -> dict[str, str]:
637
664
  """Transfers a file to the given URL.
638
665
 
639
666
  Args:
@@ -655,7 +682,7 @@ class BaseApiClient:
655
682
 
656
683
  def _upload_fd(
657
684
  self, file: io.IOBase, upload_url: str, upload_size: int
658
- ) -> str:
685
+ ) -> dict[str, str]:
659
686
  """Transfers a file to the given URL.
660
687
 
661
688
  Args:
@@ -689,7 +716,7 @@ class BaseApiClient:
689
716
  data=file_chunk,
690
717
  )
691
718
 
692
- response = self._request_unauthorized(request, stream=False)
719
+ response = self._request(request, stream=False)
693
720
  offset += chunk_size
694
721
  if response.headers['X-Goog-Upload-Status'] != 'active':
695
722
  break # upload is complete or it has been interrupted.
@@ -732,25 +759,24 @@ class BaseApiClient:
732
759
  else:
733
760
  data = http_request.data
734
761
 
735
- http_session = requests.Session()
736
- response = http_session.request(
737
- method=http_request.method,
738
- url=http_request.url,
739
- headers=http_request.headers,
740
- data=data,
741
- timeout=http_request.timeout,
742
- stream=False,
743
- )
762
+ with httpx.Client(follow_redirects=True) as client:
763
+ response = client.request(
764
+ method=http_request.method,
765
+ url=http_request.url,
766
+ headers=http_request.headers,
767
+ content=data,
768
+ timeout=http_request.timeout,
769
+ )
744
770
 
745
- errors.APIError.raise_for_response(response)
746
- return HttpResponse(response.headers, byte_stream=[response.content])
771
+ errors.APIError.raise_for_response(response)
772
+ return HttpResponse(response.headers, byte_stream=[response.read()])
747
773
 
748
774
  async def async_upload_file(
749
775
  self,
750
776
  file_path: Union[str, io.IOBase],
751
777
  upload_url: str,
752
778
  upload_size: int,
753
- ) -> str:
779
+ ) -> dict[str, str]:
754
780
  """Transfers a file asynchronously to the given URL.
755
781
 
756
782
  Args:
@@ -776,7 +802,7 @@ class BaseApiClient:
776
802
  file: Union[io.IOBase, anyio.AsyncFile],
777
803
  upload_url: str,
778
804
  upload_size: int,
779
- ) -> str:
805
+ ) -> dict[str, str]:
780
806
  """Transfers a file asynchronously to the given URL.
781
807
 
782
808
  Args:
@@ -842,7 +868,7 @@ class BaseApiClient:
842
868
  'get', path=path, request_dict={}, http_options=http_options
843
869
  )
844
870
 
845
- data: Optional[Union[str, bytes]]
871
+ data: Optional[Union[str, bytes]] = None
846
872
  if http_request.data:
847
873
  if not isinstance(http_request.data, bytes):
848
874
  data = json.dumps(http_request.data)
@@ -17,7 +17,7 @@ import inspect
17
17
  import sys
18
18
  import types as builtin_types
19
19
  import typing
20
- from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union
20
+ from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union # type: ignore[attr-defined]
21
21
 
22
22
  import pydantic
23
23
 
@@ -41,19 +41,11 @@ _py_builtin_type_to_schema_type = {
41
41
 
42
42
 
43
43
  def _is_builtin_primitive_or_compound(
44
- annotation: inspect.Parameter.annotation,
44
+ annotation: inspect.Parameter.annotation, # type: ignore[valid-type]
45
45
  ) -> bool:
46
46
  return annotation in _py_builtin_type_to_schema_type.keys()
47
47
 
48
48
 
49
- def _raise_for_any_of_if_mldev(schema: types.Schema):
50
- if schema.any_of:
51
- raise ValueError(
52
- 'AnyOf is not supported in function declaration schema for'
53
- ' the Gemini API.'
54
- )
55
-
56
-
57
49
  def _raise_for_default_if_mldev(schema: types.Schema):
58
50
  if schema.default is not None:
59
51
  raise ValueError(
@@ -64,12 +56,11 @@ def _raise_for_default_if_mldev(schema: types.Schema):
64
56
 
65
57
  def _raise_if_schema_unsupported(api_option: Literal['VERTEX_AI', 'GEMINI_API'], schema: types.Schema):
66
58
  if api_option == 'GEMINI_API':
67
- _raise_for_any_of_if_mldev(schema)
68
59
  _raise_for_default_if_mldev(schema)
69
60
 
70
61
 
71
62
  def _is_default_value_compatible(
72
- default_value: Any, annotation: inspect.Parameter.annotation
63
+ default_value: Any, annotation: inspect.Parameter.annotation # type: ignore[valid-type]
73
64
  ) -> bool:
74
65
  # None type is expected to be handled external to this function
75
66
  if _is_builtin_primitive_or_compound(annotation):
@@ -292,8 +283,7 @@ def _parse_schema_from_parameter(
292
283
  ),
293
284
  func_name,
294
285
  )
295
- if api_option == 'VERTEX_AI':
296
- schema.required = _get_required_fields(schema)
286
+ schema.required = _get_required_fields(schema)
297
287
  _raise_if_schema_unsupported(api_option, schema)
298
288
  return schema
299
289
  raise ValueError(
@@ -26,9 +26,7 @@ import sys
26
26
  import time
27
27
  import types as builtin_types
28
28
  import typing
29
- from typing import Any, GenericAlias, Optional, Union
30
-
31
- import types as builtin_types
29
+ from typing import Any, GenericAlias, Optional, Union # type: ignore[attr-defined]
32
30
 
33
31
  if typing.TYPE_CHECKING:
34
32
  import PIL.Image
@@ -43,10 +41,11 @@ logger = logging.getLogger('google_genai._transformers')
43
41
  if sys.version_info >= (3, 10):
44
42
  VersionedUnionType = builtin_types.UnionType
45
43
  _UNION_TYPES = (typing.Union, builtin_types.UnionType)
44
+ from typing import TypeGuard
46
45
  else:
47
46
  VersionedUnionType = typing._UnionGenericAlias
48
47
  _UNION_TYPES = (typing.Union,)
49
-
48
+ from typing_extensions import TypeGuard
50
49
 
51
50
  def _resource_name(
52
51
  client: _api_client.BaseApiClient,
@@ -165,7 +164,9 @@ def t_model(client: _api_client.BaseApiClient, model: str):
165
164
  return f'models/{model}'
166
165
 
167
166
 
168
- def t_models_url(api_client: _api_client.BaseApiClient, base_models: bool) -> str:
167
+ def t_models_url(
168
+ api_client: _api_client.BaseApiClient, base_models: bool
169
+ ) -> str:
169
170
  if api_client.vertexai:
170
171
  if base_models:
171
172
  return 'publishers/google/models'
@@ -179,7 +180,8 @@ def t_models_url(api_client: _api_client.BaseApiClient, base_models: bool) -> st
179
180
 
180
181
 
181
182
  def t_extract_models(
182
- api_client: _api_client.BaseApiClient, response: dict[str, list[types.ModelDict]]
183
+ api_client: _api_client.BaseApiClient,
184
+ response: dict[str, list[types.ModelDict]],
183
185
  ) -> Optional[list[types.ModelDict]]:
184
186
  if not response:
185
187
  return []
@@ -240,9 +242,7 @@ def pil_to_blob(img) -> types.Blob:
240
242
  return types.Blob(mime_type=mime_type, data=data)
241
243
 
242
244
 
243
- def t_part(
244
- part: Optional[types.PartUnionDict]
245
- ) -> types.Part:
245
+ def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
246
246
  try:
247
247
  import PIL.Image
248
248
 
@@ -268,7 +268,7 @@ def t_part(
268
268
 
269
269
 
270
270
  def t_parts(
271
- parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict]],
271
+ parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict, list[types.Part]]],
272
272
  ) -> list[types.Part]:
273
273
  #
274
274
  if parts is None or (isinstance(parts, list) and not parts):
@@ -332,22 +332,35 @@ def t_content(
332
332
  def t_contents_for_embed(
333
333
  client: _api_client.BaseApiClient,
334
334
  contents: Union[list[types.Content], list[types.ContentDict], ContentType],
335
- ):
336
- if client.vertexai and isinstance(contents, list):
337
- # TODO: Assert that only text is supported.
338
- return [t_content(client, content).parts[0].text for content in contents]
339
- elif client.vertexai:
340
- return [t_content(client, contents).parts[0].text]
341
- elif isinstance(contents, list):
342
- return [t_content(client, content) for content in contents]
335
+ ) -> Union[list[str], list[types.Content]]:
336
+ if isinstance(contents, list):
337
+ transformed_contents = [t_content(client, content) for content in contents]
343
338
  else:
344
- return [t_content(client, contents)]
339
+ transformed_contents = [t_content(client, contents)]
340
+
341
+ if client.vertexai:
342
+ text_parts = []
343
+ for content in transformed_contents:
344
+ if content is not None:
345
+ if isinstance(content, dict):
346
+ content = types.Content.model_validate(content)
347
+ if content.parts is not None:
348
+ for part in content.parts:
349
+ if part.text:
350
+ text_parts.append(part.text)
351
+ else:
352
+ logger.warning(
353
+ f'Non-text part found, only returning text parts.'
354
+ )
355
+ return text_parts
356
+ else:
357
+ return transformed_contents
345
358
 
346
359
 
347
360
  def t_contents(
348
361
  client: _api_client.BaseApiClient,
349
362
  contents: Optional[
350
- Union[types.ContentListUnion, types.ContentListUnionDict]
363
+ Union[types.ContentListUnion, types.ContentListUnionDict, types.Content]
351
364
  ],
352
365
  ) -> list[types.Content]:
353
366
  if contents is None or (isinstance(contents, list) and not contents):
@@ -365,7 +378,7 @@ def t_contents(
365
378
  result: list[types.Content] = []
366
379
  accumulated_parts: list[types.Part] = []
367
380
 
368
- def _is_part(part: types.PartUnionDict) -> bool:
381
+ def _is_part(part: Union[types.PartUnionDict, Any]) -> TypeGuard[types.PartUnionDict]:
369
382
  if (
370
383
  isinstance(part, str)
371
384
  or isinstance(part, types.File)
@@ -429,11 +442,11 @@ def t_contents(
429
442
  ):
430
443
  _append_accumulated_parts_as_content(result, accumulated_parts)
431
444
  if isinstance(content, list):
432
- result.append(types.UserContent(parts=content))
445
+ result.append(types.UserContent(parts=content)) # type: ignore[arg-type]
433
446
  else:
434
447
  result.append(content)
435
- elif (_is_part(content)): # type: ignore
436
- _handle_current_part(result, accumulated_parts, content) # type: ignore
448
+ elif (_is_part(content)):
449
+ _handle_current_part(result, accumulated_parts, content)
437
450
  elif isinstance(content, dict):
438
451
  # PactDict is already handled in _is_part
439
452
  result.append(types.Content.model_validate(content))
@@ -499,7 +512,7 @@ def handle_null_fields(schema: dict[str, Any]):
499
512
  schema['anyOf'].remove({'type': 'null'})
500
513
  if len(schema['anyOf']) == 1:
501
514
  # If there is only one type left after removing null, remove the anyOf field.
502
- for key,val in schema['anyOf'][0].items():
515
+ for key, val in schema['anyOf'][0].items():
503
516
  schema[key] = val
504
517
  del schema['anyOf']
505
518
 
@@ -574,7 +587,8 @@ def process_schema(
574
587
 
575
588
  if schema.get('default') is not None:
576
589
  raise ValueError(
577
- 'Default value is not supported in the response schema for the Gemini API.'
590
+ 'Default value is not supported in the response schema for the Gemini'
591
+ ' API.'
578
592
  )
579
593
 
580
594
  if schema.get('title') == 'PlaceholderLiteralEnum':
@@ -604,10 +618,6 @@ def process_schema(
604
618
 
605
619
  any_of = schema.get('anyOf', None)
606
620
  if any_of is not None:
607
- if client and not client.vertexai:
608
- raise ValueError(
609
- 'AnyOf is not supported in the response schema for the Gemini API.'
610
- )
611
621
  for sub_schema in any_of:
612
622
  # $ref is present in any_of if the schema is a union of Pydantic classes
613
623
  ref_key = sub_schema.get('$ref', None)
@@ -670,7 +680,7 @@ def process_schema(
670
680
 
671
681
 
672
682
  def _process_enum(
673
- enum: EnumMeta, client: Optional[_api_client.BaseApiClient] = None # type: ignore
683
+ enum: EnumMeta, client: _api_client.BaseApiClient
674
684
  ) -> types.Schema:
675
685
  for member in enum: # type: ignore
676
686
  if not isinstance(member.value, str):
@@ -680,7 +690,7 @@ def _process_enum(
680
690
  )
681
691
 
682
692
  class Placeholder(pydantic.BaseModel):
683
- placeholder: enum
693
+ placeholder: enum # type: ignore[valid-type]
684
694
 
685
695
  enum_schema = Placeholder.model_json_schema()
686
696
  process_schema(enum_schema, client)
@@ -688,12 +698,19 @@ def _process_enum(
688
698
  return types.Schema.model_validate(enum_schema)
689
699
 
690
700
 
701
+ def _is_type_dict_str_any(origin: Union[types.SchemaUnionDict, Any]) -> TypeGuard[dict[str, Any]]:
702
+ """Verifies the schema is of type dict[str, Any] for mypy type checking."""
703
+ return isinstance(origin, dict) and all(
704
+ isinstance(key, str) for key in origin
705
+ )
706
+
707
+
691
708
  def t_schema(
692
709
  client: _api_client.BaseApiClient, origin: Union[types.SchemaUnionDict, Any]
693
710
  ) -> Optional[types.Schema]:
694
711
  if not origin:
695
712
  return None
696
- if isinstance(origin, dict):
713
+ if isinstance(origin, dict) and _is_type_dict_str_any(origin):
697
714
  process_schema(origin, client, order_properties=False)
698
715
  return types.Schema.model_validate(origin)
699
716
  if isinstance(origin, EnumMeta):
@@ -724,7 +741,7 @@ def t_schema(
724
741
  ):
725
742
 
726
743
  class Placeholder(pydantic.BaseModel):
727
- placeholder: origin
744
+ placeholder: origin # type: ignore[valid-type]
728
745
 
729
746
  schema = Placeholder.model_json_schema()
730
747
  process_schema(schema, client)
@@ -735,7 +752,8 @@ def t_schema(
735
752
 
736
753
 
737
754
  def t_speech_config(
738
- _: _api_client.BaseApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
755
+ _: _api_client.BaseApiClient,
756
+ origin: Union[types.SpeechConfigUnionDict, Any],
739
757
  ) -> Optional[types.SpeechConfig]:
740
758
  if not origin:
741
759
  return None
@@ -794,7 +812,10 @@ def t_tools(
794
812
  transformed_tool = t_tool(client, tool)
795
813
  # All functions should be merged into one tool.
796
814
  if transformed_tool is not None:
797
- if transformed_tool.function_declarations:
815
+ if (
816
+ transformed_tool.function_declarations
817
+ and function_tool.function_declarations is not None
818
+ ):
798
819
  function_tool.function_declarations += (
799
820
  transformed_tool.function_declarations
800
821
  )
@@ -896,7 +917,10 @@ def t_file_name(
896
917
  elif isinstance(name, types.Video):
897
918
  name = name.uri
898
919
  elif isinstance(name, types.GeneratedVideo):
899
- name = name.video.uri
920
+ if name.video is not None:
921
+ name = name.video.uri
922
+ else:
923
+ name = None
900
924
 
901
925
  if name is None:
902
926
  raise ValueError('File name is required.')
google/genai/batches.py CHANGED
@@ -998,6 +998,8 @@ class Batches(_api_module.BaseModule):
998
998
  for batch_job in batch_jobs:
999
999
  print(f"Batch job: {batch_job.name}, state {batch_job.state}")
1000
1000
  """
1001
+ if config is None:
1002
+ config = types.ListBatchJobsConfig()
1001
1003
  return Pager(
1002
1004
  'batch_jobs',
1003
1005
  self._list,
@@ -1373,6 +1375,8 @@ class AsyncBatches(_api_module.BaseModule):
1373
1375
  await batch_jobs_pager.next_page()
1374
1376
  print(f"next page: {batch_jobs_pager.page}")
1375
1377
  """
1378
+ if config is None:
1379
+ config = types.ListBatchJobsConfig()
1376
1380
  return AsyncPager(
1377
1381
  'batch_jobs',
1378
1382
  self._list,