google-genai 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,9 +24,12 @@ import logging
24
24
  import re
25
25
  import sys
26
26
  import time
27
+ import types as builtin_types
27
28
  import typing
28
29
  from typing import Any, GenericAlias, Optional, Union
29
30
 
31
+ import types as builtin_types
32
+
30
33
  if typing.TYPE_CHECKING:
31
34
  import PIL.Image
32
35
 
@@ -38,15 +41,15 @@ from . import types
38
41
  logger = logging.getLogger('google_genai._transformers')
39
42
 
40
43
  if sys.version_info >= (3, 10):
41
- VersionedUnionType = typing.types.UnionType
42
- _UNION_TYPES = (typing.Union, typing.types.UnionType)
44
+ VersionedUnionType = builtin_types.UnionType
45
+ _UNION_TYPES = (typing.Union, builtin_types.UnionType)
43
46
  else:
44
47
  VersionedUnionType = typing._UnionGenericAlias
45
48
  _UNION_TYPES = (typing.Union,)
46
49
 
47
50
 
48
51
  def _resource_name(
49
- client: _api_client.ApiClient,
52
+ client: _api_client.BaseApiClient,
50
53
  resource_name: str,
51
54
  *,
52
55
  collection_identifier: str,
@@ -138,7 +141,7 @@ def _resource_name(
138
141
  return resource_name
139
142
 
140
143
 
141
- def t_model(client: _api_client.ApiClient, model: str):
144
+ def t_model(client: _api_client.BaseApiClient, model: str):
142
145
  if not model:
143
146
  raise ValueError('model is required.')
144
147
  if client.vertexai:
@@ -162,7 +165,7 @@ def t_model(client: _api_client.ApiClient, model: str):
162
165
  return f'models/{model}'
163
166
 
164
167
 
165
- def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
168
+ def t_models_url(api_client: _api_client.BaseApiClient, base_models: bool) -> str:
166
169
  if api_client.vertexai:
167
170
  if base_models:
168
171
  return 'publishers/google/models'
@@ -176,8 +179,8 @@ def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
176
179
 
177
180
 
178
181
  def t_extract_models(
179
- api_client: _api_client.ApiClient, response: dict
180
- ) -> list[types.Model]:
182
+ api_client: _api_client.BaseApiClient, response: dict[str, list[types.ModelDict]]
183
+ ) -> Optional[list[types.ModelDict]]:
181
184
  if not response:
182
185
  return []
183
186
  elif response.get('models') is not None:
@@ -197,7 +200,7 @@ def t_extract_models(
197
200
  return []
198
201
 
199
202
 
200
- def t_caches_model(api_client: _api_client.ApiClient, model: str):
203
+ def t_caches_model(api_client: _api_client.BaseApiClient, model: str):
201
204
  model = t_model(api_client, model)
202
205
  if not model:
203
206
  return None
@@ -213,6 +216,7 @@ def t_caches_model(api_client: _api_client.ApiClient, model: str):
213
216
 
214
217
 
215
218
  def pil_to_blob(img) -> types.Blob:
219
+ PngImagePlugin: Optional[builtin_types.ModuleType]
216
220
  try:
217
221
  import PIL.PngImagePlugin
218
222
 
@@ -236,10 +240,9 @@ def pil_to_blob(img) -> types.Blob:
236
240
  return types.Blob(mime_type=mime_type, data=data)
237
241
 
238
242
 
239
- PartType = Union[types.Part, types.PartDict, str, 'PIL.Image.Image']
240
-
241
-
242
- def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
243
+ def t_part(
244
+ part: Optional[types.PartUnionDict]
245
+ ) -> types.Part:
243
246
  try:
244
247
  import PIL.Image
245
248
 
@@ -247,7 +250,7 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
247
250
  except ImportError:
248
251
  PIL_Image = None
249
252
 
250
- if not part:
253
+ if part is None:
251
254
  raise ValueError('content part is required.')
252
255
  if isinstance(part, str):
253
256
  return types.Part(text=part)
@@ -257,25 +260,29 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
257
260
  if not part.uri or not part.mime_type:
258
261
  raise ValueError('file uri and mime_type are required.')
259
262
  return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)
260
- else:
263
+ if isinstance(part, dict):
264
+ return types.Part.model_validate(part)
265
+ if isinstance(part, types.Part):
261
266
  return part
267
+ raise ValueError(f'Unsupported content part type: {type(part)}')
262
268
 
263
269
 
264
270
  def t_parts(
265
- client: _api_client.ApiClient, parts: Union[list, PartType]
271
+ parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict]],
266
272
  ) -> list[types.Part]:
267
- if not parts:
273
+ #
274
+ if parts is None or (isinstance(parts, list) and not parts):
268
275
  raise ValueError('content parts are required.')
269
276
  if isinstance(parts, list):
270
- return [t_part(client, part) for part in parts]
277
+ return [t_part(part) for part in parts]
271
278
  else:
272
- return [t_part(client, parts)]
279
+ return [t_part(parts)]
273
280
 
274
281
 
275
282
  def t_image_predictions(
276
- client: _api_client.ApiClient,
283
+ client: _api_client.BaseApiClient,
277
284
  predictions: Optional[Iterable[Mapping[str, Any]]],
278
- ) -> list[types.GeneratedImage]:
285
+ ) -> Optional[list[types.GeneratedImage]]:
279
286
  if not predictions:
280
287
  return None
281
288
  images = []
@@ -292,24 +299,38 @@ def t_image_predictions(
292
299
  return images
293
300
 
294
301
 
295
- ContentType = Union[types.Content, types.ContentDict, PartType]
302
+ ContentType = Union[types.Content, types.ContentDict, types.PartUnionDict]
296
303
 
297
304
 
298
305
  def t_content(
299
- client: _api_client.ApiClient,
300
- content: ContentType,
301
- ):
302
- if not content:
306
+ client: _api_client.BaseApiClient,
307
+ content: Optional[ContentType],
308
+ ) -> types.Content:
309
+ if content is None:
303
310
  raise ValueError('content is required.')
304
311
  if isinstance(content, types.Content):
305
312
  return content
306
313
  if isinstance(content, dict):
307
- return types.Content.model_validate(content)
308
- return types.Content(role='user', parts=t_parts(client, content))
314
+ try:
315
+ return types.Content.model_validate(content)
316
+ except pydantic.ValidationError:
317
+ possible_part = types.Part.model_validate(content)
318
+ return (
319
+ types.ModelContent(parts=[possible_part])
320
+ if possible_part.function_call
321
+ else types.UserContent(parts=[possible_part])
322
+ )
323
+ if isinstance(content, types.Part):
324
+ return (
325
+ types.ModelContent(parts=[content])
326
+ if content.function_call
327
+ else types.UserContent(parts=[content])
328
+ )
329
+ return types.UserContent(parts=content)
309
330
 
310
331
 
311
332
  def t_contents_for_embed(
312
- client: _api_client.ApiClient,
333
+ client: _api_client.BaseApiClient,
313
334
  contents: Union[list[types.Content], list[types.ContentDict], ContentType],
314
335
  ):
315
336
  if client.vertexai and isinstance(contents, list):
@@ -324,16 +345,105 @@ def t_contents_for_embed(
324
345
 
325
346
 
326
347
  def t_contents(
327
- client: _api_client.ApiClient,
328
- contents: Union[list[types.Content], list[types.ContentDict], ContentType],
329
- ):
330
- if not contents:
348
+ client: _api_client.BaseApiClient,
349
+ contents: Optional[
350
+ Union[types.ContentListUnion, types.ContentListUnionDict]
351
+ ],
352
+ ) -> list[types.Content]:
353
+ if contents is None or (isinstance(contents, list) and not contents):
331
354
  raise ValueError('contents are required.')
332
- if isinstance(contents, list):
333
- return [t_content(client, content) for content in contents]
334
- else:
355
+ if not isinstance(contents, list):
335
356
  return [t_content(client, contents)]
336
357
 
358
+ try:
359
+ import PIL.Image
360
+
361
+ PIL_Image = PIL.Image.Image
362
+ except ImportError:
363
+ PIL_Image = None
364
+
365
+ result: list[types.Content] = []
366
+ accumulated_parts: list[types.Part] = []
367
+
368
+ def _is_part(part: types.PartUnionDict) -> bool:
369
+ if (
370
+ isinstance(part, str)
371
+ or isinstance(part, types.File)
372
+ or (PIL_Image is not None and isinstance(part, PIL_Image))
373
+ or isinstance(part, types.Part)
374
+ ):
375
+ return True
376
+
377
+ if isinstance(part, dict):
378
+ try:
379
+ types.Part.model_validate(part)
380
+ return True
381
+ except pydantic.ValidationError:
382
+ return False
383
+
384
+ return False
385
+
386
+ def _is_user_part(part: types.Part) -> bool:
387
+ return not part.function_call
388
+
389
+ def _are_user_parts(parts: list[types.Part]) -> bool:
390
+ return all(_is_user_part(part) for part in parts)
391
+
392
+ def _append_accumulated_parts_as_content(
393
+ result: list[types.Content],
394
+ accumulated_parts: list[types.Part],
395
+ ):
396
+ if not accumulated_parts:
397
+ return
398
+ result.append(
399
+ types.UserContent(parts=accumulated_parts)
400
+ if _are_user_parts(accumulated_parts)
401
+ else types.ModelContent(parts=accumulated_parts)
402
+ )
403
+ accumulated_parts[:] = []
404
+
405
+ def _handle_current_part(
406
+ result: list[types.Content],
407
+ accumulated_parts: list[types.Part],
408
+ current_part: types.PartUnionDict,
409
+ ):
410
+ current_part = t_part(current_part)
411
+ if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
412
+ accumulated_parts.append(current_part)
413
+ else:
414
+ _append_accumulated_parts_as_content(result, accumulated_parts)
415
+ accumulated_parts[:] = [current_part]
416
+
417
+ # iterator over contents
418
+ # if content type or content dict, append to result
419
+ # if consecutive part(s),
420
+ # group consecutive user part(s) to a UserContent
421
+ # group consecutive model part(s) to a ModelContent
422
+ # append to result
423
+ # if list, we only accept a list of types.PartUnion
424
+ for content in contents:
425
+ if (
426
+ isinstance(content, types.Content)
427
+ # only allowed inner list is a list of types.PartUnion
428
+ or isinstance(content, list)
429
+ ):
430
+ _append_accumulated_parts_as_content(result, accumulated_parts)
431
+ if isinstance(content, list):
432
+ result.append(types.UserContent(parts=content))
433
+ else:
434
+ result.append(content)
435
+ elif (_is_part(content)): # type: ignore
436
+ _handle_current_part(result, accumulated_parts, content) # type: ignore
437
+ elif isinstance(content, dict):
438
+ # PactDict is already handled in _is_part
439
+ result.append(types.Content.model_validate(content))
440
+ else:
441
+ raise ValueError(f'Unsupported content type: {type(content)}')
442
+
443
+ _append_accumulated_parts_as_content(result, accumulated_parts)
444
+
445
+ return result
446
+
337
447
 
338
448
  def handle_null_fields(schema: dict[str, Any]):
339
449
  """Process null fields in the schema so it is compatible with OpenAPI.
@@ -396,7 +506,7 @@ def handle_null_fields(schema: dict[str, Any]):
396
506
 
397
507
  def process_schema(
398
508
  schema: dict[str, Any],
399
- client: Optional[_api_client.ApiClient] = None,
509
+ client: _api_client.BaseApiClient,
400
510
  defs: Optional[dict[str, Any]] = None,
401
511
  *,
402
512
  order_properties: bool = True,
@@ -459,7 +569,7 @@ def process_schema(
459
569
  'type': 'array'
460
570
  }
461
571
  """
462
- if client and not client.vertexai:
572
+ if not client.vertexai:
463
573
  schema.pop('title', None)
464
574
 
465
575
  if schema.get('default') is not None:
@@ -485,15 +595,16 @@ def process_schema(
485
595
  # After removing null fields, Optional fields with only one possible type
486
596
  # will have a $ref key that needs to be flattened
487
597
  # For example: {'default': None, 'description': 'Name of the person', 'nullable': True, '$ref': '#/$defs/TestPerson'}
488
- if schema.get('$ref', None):
489
- ref = defs[schema.get('$ref').split('defs/')[-1]]
598
+ schema_ref = schema.get('$ref', None)
599
+ if schema_ref is not None:
600
+ ref = defs[schema_ref.split('defs/')[-1]]
490
601
  for schema_key in list(ref.keys()):
491
602
  schema[schema_key] = ref[schema_key]
492
603
  del schema['$ref']
493
604
 
494
605
  any_of = schema.get('anyOf', None)
495
606
  if any_of is not None:
496
- if not client.vertexai:
607
+ if client and not client.vertexai:
497
608
  raise ValueError(
498
609
  'AnyOf is not supported in the response schema for the Gemini API.'
499
610
  )
@@ -559,9 +670,9 @@ def process_schema(
559
670
 
560
671
 
561
672
  def _process_enum(
562
- enum: EnumMeta, client: Optional[_api_client.ApiClient] = None
673
+ enum: EnumMeta, client: Optional[_api_client.BaseApiClient] = None # type: ignore
563
674
  ) -> types.Schema:
564
- for member in enum:
675
+ for member in enum: # type: ignore
565
676
  if not isinstance(member.value, str):
566
677
  raise TypeError(
567
678
  f'Enum member {member.name} value must be a string, got'
@@ -578,7 +689,7 @@ def _process_enum(
578
689
 
579
690
 
580
691
  def t_schema(
581
- client: _api_client.ApiClient, origin: Union[types.SchemaUnionDict, Any]
692
+ client: _api_client.BaseApiClient, origin: Union[types.SchemaUnionDict, Any]
582
693
  ) -> Optional[types.Schema]:
583
694
  if not origin:
584
695
  return None
@@ -624,7 +735,7 @@ def t_schema(
624
735
 
625
736
 
626
737
  def t_speech_config(
627
- _: _api_client.ApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
738
+ _: _api_client.BaseApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
628
739
  ) -> Optional[types.SpeechConfig]:
629
740
  if not origin:
630
741
  return None
@@ -639,7 +750,10 @@ def t_speech_config(
639
750
  if (
640
751
  isinstance(origin, dict)
641
752
  and 'voice_config' in origin
753
+ and origin['voice_config'] is not None
642
754
  and 'prebuilt_voice_config' in origin['voice_config']
755
+ and origin['voice_config']['prebuilt_voice_config'] is not None
756
+ and 'voice_name' in origin['voice_config']['prebuilt_voice_config']
643
757
  ):
644
758
  return types.SpeechConfig(
645
759
  voice_config=types.VoiceConfig(
@@ -653,7 +767,7 @@ def t_speech_config(
653
767
  raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
654
768
 
655
769
 
656
- def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
770
+ def t_tool(client: _api_client.BaseApiClient, origin) -> Optional[types.Tool]:
657
771
  if not origin:
658
772
  return None
659
773
  if inspect.isfunction(origin) or inspect.ismethod(origin):
@@ -670,7 +784,7 @@ def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
670
784
 
671
785
  # Only support functions now.
672
786
  def t_tools(
673
- client: _api_client.ApiClient, origin: list[Any]
787
+ client: _api_client.BaseApiClient, origin: list[Any]
674
788
  ) -> list[types.Tool]:
675
789
  if not origin:
676
790
  return []
@@ -679,22 +793,23 @@ def t_tools(
679
793
  for tool in origin:
680
794
  transformed_tool = t_tool(client, tool)
681
795
  # All functions should be merged into one tool.
682
- if transformed_tool.function_declarations:
683
- function_tool.function_declarations += (
684
- transformed_tool.function_declarations
685
- )
686
- else:
687
- tools.append(transformed_tool)
796
+ if transformed_tool is not None:
797
+ if transformed_tool.function_declarations:
798
+ function_tool.function_declarations += (
799
+ transformed_tool.function_declarations
800
+ )
801
+ else:
802
+ tools.append(transformed_tool)
688
803
  if function_tool.function_declarations:
689
804
  tools.append(function_tool)
690
805
  return tools
691
806
 
692
807
 
693
- def t_cached_content_name(client: _api_client.ApiClient, name: str):
808
+ def t_cached_content_name(client: _api_client.BaseApiClient, name: str):
694
809
  return _resource_name(client, name, collection_identifier='cachedContents')
695
810
 
696
811
 
697
- def t_batch_job_source(client: _api_client.ApiClient, src: str):
812
+ def t_batch_job_source(client: _api_client.BaseApiClient, src: str):
698
813
  if src.startswith('gs://'):
699
814
  return types.BatchJobSource(
700
815
  format='jsonl',
@@ -709,7 +824,7 @@ def t_batch_job_source(client: _api_client.ApiClient, src: str):
709
824
  raise ValueError(f'Unsupported source: {src}')
710
825
 
711
826
 
712
- def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
827
+ def t_batch_job_destination(client: _api_client.BaseApiClient, dest: str):
713
828
  if dest.startswith('gs://'):
714
829
  return types.BatchJobDestination(
715
830
  format='jsonl',
@@ -724,7 +839,7 @@ def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
724
839
  raise ValueError(f'Unsupported destination: {dest}')
725
840
 
726
841
 
727
- def t_batch_job_name(client: _api_client.ApiClient, name: str):
842
+ def t_batch_job_name(client: _api_client.BaseApiClient, name: str):
728
843
  if not client.vertexai:
729
844
  return name
730
845
 
@@ -743,7 +858,7 @@ LRO_POLLING_TIMEOUT_SECONDS = 900.0
743
858
  LRO_POLLING_MULTIPLIER = 1.5
744
859
 
745
860
 
746
- def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
861
+ def t_resolve_operation(api_client: _api_client.BaseApiClient, struct: dict):
747
862
  if (name := struct.get('name')) and '/operations/' in name:
748
863
  operation: dict[str, Any] = struct
749
864
  total_seconds = 0.0
@@ -752,7 +867,7 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
752
867
  if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
753
868
  raise RuntimeError(f'Operation {name} timed out.\n{operation}')
754
869
  # TODO(b/374433890): Replace with LRO module once it's available.
755
- operation: dict[str, Any] = api_client.request(
870
+ operation = api_client.request(
756
871
  http_method='GET', path=name, request_dict={}
757
872
  )
758
873
  time.sleep(delay_seconds)
@@ -772,15 +887,25 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
772
887
 
773
888
 
774
889
  def t_file_name(
775
- api_client: _api_client.ApiClient, name: Union[str, types.File]
890
+ api_client: _api_client.BaseApiClient,
891
+ name: Optional[Union[str, types.File, types.Video, types.GeneratedVideo]],
776
892
  ):
777
893
  # Remove the files/ prefix since it's added to the url path.
778
894
  if isinstance(name, types.File):
779
895
  name = name.name
896
+ elif isinstance(name, types.Video):
897
+ name = name.uri
898
+ elif isinstance(name, types.GeneratedVideo):
899
+ name = name.video.uri
780
900
 
781
901
  if name is None:
782
902
  raise ValueError('File name is required.')
783
903
 
904
+ if not isinstance(name, str):
905
+ raise ValueError(
906
+ f'Could not convert object of type `{type(name)}` to a file name.'
907
+ )
908
+
784
909
  if name.startswith('https://'):
785
910
  suffix = name.split('files/')[1]
786
911
  match = re.match('[a-z0-9]+', suffix)
@@ -794,17 +919,20 @@ def t_file_name(
794
919
 
795
920
 
796
921
  def t_tuning_job_status(
797
- api_client: _api_client.ApiClient, status: str
798
- ) -> types.JobState:
922
+ api_client: _api_client.BaseApiClient, status: str
923
+ ) -> Union[types.JobState, str]:
799
924
  if status == 'STATE_UNSPECIFIED':
800
- return 'JOB_STATE_UNSPECIFIED'
925
+ return types.JobState.JOB_STATE_UNSPECIFIED
801
926
  elif status == 'CREATING':
802
- return 'JOB_STATE_RUNNING'
927
+ return types.JobState.JOB_STATE_RUNNING
803
928
  elif status == 'ACTIVE':
804
- return 'JOB_STATE_SUCCEEDED'
929
+ return types.JobState.JOB_STATE_SUCCEEDED
805
930
  elif status == 'FAILED':
806
- return 'JOB_STATE_FAILED'
931
+ return types.JobState.JOB_STATE_FAILED
807
932
  else:
933
+ for state in types.JobState:
934
+ if str(state.value) == status:
935
+ return state
808
936
  return status
809
937
 
810
938
 
@@ -812,7 +940,7 @@ def t_tuning_job_status(
812
940
  # We shouldn't use this transformer if the backend adhere to Cloud Type
813
941
  # format https://cloud.google.com/docs/discovery/type-format.
814
942
  # TODO(b/389133914,b/390320301): Remove the hack after backend fix the issue.
815
- def t_bytes(api_client: _api_client.ApiClient, data: bytes) -> str:
943
+ def t_bytes(api_client: _api_client.BaseApiClient, data: bytes) -> str:
816
944
  if not isinstance(data, bytes):
817
945
  return data
818
946
  return base64.b64encode(data).decode('ascii')