google-genai 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,9 +24,12 @@ import logging
24
24
  import re
25
25
  import sys
26
26
  import time
27
+ import types as builtin_types
27
28
  import typing
28
29
  from typing import Any, GenericAlias, Optional, Union
29
30
 
31
+ import types as builtin_types
32
+
30
33
  if typing.TYPE_CHECKING:
31
34
  import PIL.Image
32
35
 
@@ -38,15 +41,15 @@ from . import types
38
41
  logger = logging.getLogger('google_genai._transformers')
39
42
 
40
43
  if sys.version_info >= (3, 10):
41
- VersionedUnionType = typing.types.UnionType
42
- _UNION_TYPES = (typing.Union, typing.types.UnionType)
44
+ VersionedUnionType = builtin_types.UnionType
45
+ _UNION_TYPES = (typing.Union, builtin_types.UnionType)
43
46
  else:
44
47
  VersionedUnionType = typing._UnionGenericAlias
45
48
  _UNION_TYPES = (typing.Union,)
46
49
 
47
50
 
48
51
  def _resource_name(
49
- client: _api_client.ApiClient,
52
+ client: _api_client.BaseApiClient,
50
53
  resource_name: str,
51
54
  *,
52
55
  collection_identifier: str,
@@ -138,7 +141,7 @@ def _resource_name(
138
141
  return resource_name
139
142
 
140
143
 
141
- def t_model(client: _api_client.ApiClient, model: str):
144
+ def t_model(client: _api_client.BaseApiClient, model: str):
142
145
  if not model:
143
146
  raise ValueError('model is required.')
144
147
  if client.vertexai:
@@ -162,7 +165,7 @@ def t_model(client: _api_client.ApiClient, model: str):
162
165
  return f'models/{model}'
163
166
 
164
167
 
165
- def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
168
+ def t_models_url(api_client: _api_client.BaseApiClient, base_models: bool) -> str:
166
169
  if api_client.vertexai:
167
170
  if base_models:
168
171
  return 'publishers/google/models'
@@ -176,7 +179,7 @@ def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
176
179
 
177
180
 
178
181
  def t_extract_models(
179
- api_client: _api_client.ApiClient, response: dict
182
+ api_client: _api_client.BaseApiClient, response: dict
180
183
  ) -> list[types.Model]:
181
184
  if not response:
182
185
  return []
@@ -197,7 +200,7 @@ def t_extract_models(
197
200
  return []
198
201
 
199
202
 
200
- def t_caches_model(api_client: _api_client.ApiClient, model: str):
203
+ def t_caches_model(api_client: _api_client.BaseApiClient, model: str):
201
204
  model = t_model(api_client, model)
202
205
  if not model:
203
206
  return None
@@ -213,6 +216,7 @@ def t_caches_model(api_client: _api_client.ApiClient, model: str):
213
216
 
214
217
 
215
218
  def pil_to_blob(img) -> types.Blob:
219
+ PngImagePlugin: Optional[builtin_types.ModuleType]
216
220
  try:
217
221
  import PIL.PngImagePlugin
218
222
 
@@ -236,10 +240,9 @@ def pil_to_blob(img) -> types.Blob:
236
240
  return types.Blob(mime_type=mime_type, data=data)
237
241
 
238
242
 
239
- PartType = Union[types.Part, types.PartDict, str, 'PIL.Image.Image']
240
-
241
-
242
- def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
243
+ def t_part(
244
+ client: _api_client.BaseApiClient, part: Optional[types.PartUnionDict]
245
+ ) -> types.Part:
243
246
  try:
244
247
  import PIL.Image
245
248
 
@@ -247,7 +250,7 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
247
250
  except ImportError:
248
251
  PIL_Image = None
249
252
 
250
- if not part:
253
+ if part is None:
251
254
  raise ValueError('content part is required.')
252
255
  if isinstance(part, str):
253
256
  return types.Part(text=part)
@@ -257,14 +260,19 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
257
260
  if not part.uri or not part.mime_type:
258
261
  raise ValueError('file uri and mime_type are required.')
259
262
  return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)
260
- else:
263
+ if isinstance(part, dict):
264
+ return types.Part.model_validate(part)
265
+ if isinstance(part, types.Part):
261
266
  return part
267
+ raise ValueError(f'Unsupported content part type: {type(part)}')
262
268
 
263
269
 
264
270
  def t_parts(
265
- client: _api_client.ApiClient, parts: Union[list, PartType]
271
+ client: _api_client.BaseApiClient,
272
+ parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict]],
266
273
  ) -> list[types.Part]:
267
- if not parts:
274
+ # Reject only None and an empty list; other falsy values are not treated as missing.
275
+ if parts is None or (isinstance(parts, list) and not parts):
268
276
  raise ValueError('content parts are required.')
269
277
  if isinstance(parts, list):
270
278
  return [t_part(client, part) for part in parts]
@@ -273,7 +281,7 @@ def t_parts(
273
281
 
274
282
 
275
283
  def t_image_predictions(
276
- client: _api_client.ApiClient,
284
+ client: _api_client.BaseApiClient,
277
285
  predictions: Optional[Iterable[Mapping[str, Any]]],
278
286
  ) -> list[types.GeneratedImage]:
279
287
  if not predictions:
@@ -292,24 +300,38 @@ def t_image_predictions(
292
300
  return images
293
301
 
294
302
 
295
- ContentType = Union[types.Content, types.ContentDict, PartType]
303
+ ContentType = Union[types.Content, types.ContentDict, types.PartUnionDict]
296
304
 
297
305
 
298
306
  def t_content(
299
- client: _api_client.ApiClient,
300
- content: ContentType,
301
- ):
302
- if not content:
307
+ client: _api_client.BaseApiClient,
308
+ content: Optional[ContentType],
309
+ ) -> types.Content:
310
+ if content is None:
303
311
  raise ValueError('content is required.')
304
312
  if isinstance(content, types.Content):
305
313
  return content
306
314
  if isinstance(content, dict):
307
- return types.Content.model_validate(content)
308
- return types.Content(role='user', parts=t_parts(client, content))
315
+ try:
316
+ return types.Content.model_validate(content)
317
+ except pydantic.ValidationError:
318
+ possible_part = types.Part.model_validate(content)
319
+ return (
320
+ types.ModelContent(parts=[possible_part])
321
+ if possible_part.function_call
322
+ else types.UserContent(parts=[possible_part])
323
+ )
324
+ if isinstance(content, types.Part):
325
+ return (
326
+ types.ModelContent(parts=[content])
327
+ if content.function_call
328
+ else types.UserContent(parts=[content])
329
+ )
330
+ return types.UserContent(parts=content)
309
331
 
310
332
 
311
333
  def t_contents_for_embed(
312
- client: _api_client.ApiClient,
334
+ client: _api_client.BaseApiClient,
313
335
  contents: Union[list[types.Content], list[types.ContentDict], ContentType],
314
336
  ):
315
337
  if client.vertexai and isinstance(contents, list):
@@ -324,16 +346,105 @@ def t_contents_for_embed(
324
346
 
325
347
 
326
348
  def t_contents(
327
- client: _api_client.ApiClient,
328
- contents: Union[list[types.Content], list[types.ContentDict], ContentType],
329
- ):
330
- if not contents:
349
+ client: _api_client.BaseApiClient,
350
+ contents: Optional[
351
+ Union[types.ContentListUnion, types.ContentListUnionDict]
352
+ ],
353
+ ) -> list[types.Content]:
354
+ if contents is None or (isinstance(contents, list) and not contents):
331
355
  raise ValueError('contents are required.')
332
- if isinstance(contents, list):
333
- return [t_content(client, content) for content in contents]
334
- else:
356
+ if not isinstance(contents, list):
335
357
  return [t_content(client, contents)]
336
358
 
359
+ try:
360
+ import PIL.Image
361
+
362
+ PIL_Image = PIL.Image.Image
363
+ except ImportError:
364
+ PIL_Image = None
365
+
366
+ result: list[types.Content] = []
367
+ accumulated_parts: list[types.Part] = []
368
+
369
+ def _is_part(part: types.PartUnionDict) -> bool:
370
+ if (
371
+ isinstance(part, str)
372
+ or isinstance(part, types.File)
373
+ or (PIL_Image is not None and isinstance(part, PIL_Image))
374
+ or isinstance(part, types.Part)
375
+ ):
376
+ return True
377
+
378
+ if isinstance(part, dict):
379
+ try:
380
+ types.Part.model_validate(part)
381
+ return True
382
+ except pydantic.ValidationError:
383
+ return False
384
+
385
+ return False
386
+
387
+ def _is_user_part(part: types.Part) -> bool:
388
+ return not part.function_call
389
+
390
+ def _are_user_parts(parts: list[types.Part]) -> bool:
391
+ return all(_is_user_part(part) for part in parts)
392
+
393
+ def _append_accumulated_parts_as_content(
394
+ result: list[types.Content],
395
+ accumulated_parts: list[types.Part],
396
+ ):
397
+ if not accumulated_parts:
398
+ return
399
+ result.append(
400
+ types.UserContent(parts=accumulated_parts)
401
+ if _are_user_parts(accumulated_parts)
402
+ else types.ModelContent(parts=accumulated_parts)
403
+ )
404
+ accumulated_parts[:] = []
405
+
406
+ def _handle_current_part(
407
+ result: list[types.Content],
408
+ accumulated_parts: list[types.Part],
409
+ current_part: types.PartUnionDict,
410
+ ):
411
+ current_part = t_part(client, current_part)
412
+ if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
413
+ accumulated_parts.append(current_part)
414
+ else:
415
+ _append_accumulated_parts_as_content(result, accumulated_parts)
416
+ accumulated_parts[:] = [current_part]
417
+
418
+ # iterate over contents
419
+ # if content type or content dict, append to result
420
+ # if consecutive part(s),
421
+ # group consecutive user part(s) to a UserContent
422
+ # group consecutive model part(s) to a ModelContent
423
+ # append to result
424
+ # if list, we only accept a list of types.PartUnion
425
+ for content in contents:
426
+ if (
427
+ isinstance(content, types.Content)
428
+ # only allowed inner list is a list of types.PartUnion
429
+ or isinstance(content, list)
430
+ ):
431
+ _append_accumulated_parts_as_content(result, accumulated_parts)
432
+ if isinstance(content, list):
433
+ result.append(types.UserContent(parts=content))
434
+ else:
435
+ result.append(content)
436
+ elif (_is_part(content)): # type: ignore
437
+ _handle_current_part(result, accumulated_parts, content) # type: ignore
438
+ elif isinstance(content, dict):
439
+ # PartDict is already handled in _is_part
440
+ result.append(types.Content.model_validate(content))
441
+ else:
442
+ raise ValueError(f'Unsupported content type: {type(content)}')
443
+
444
+ _append_accumulated_parts_as_content(result, accumulated_parts)
445
+
446
+ return result
447
+
337
448
 
338
449
  def handle_null_fields(schema: dict[str, Any]):
339
450
  """Process null fields in the schema so it is compatible with OpenAPI.
@@ -396,7 +507,7 @@ def handle_null_fields(schema: dict[str, Any]):
396
507
 
397
508
  def process_schema(
398
509
  schema: dict[str, Any],
399
- client: Optional[_api_client.ApiClient] = None,
510
+ client: Optional[_api_client.BaseApiClient] = None,
400
511
  defs: Optional[dict[str, Any]] = None,
401
512
  *,
402
513
  order_properties: bool = True,
@@ -559,9 +670,9 @@ def process_schema(
559
670
 
560
671
 
561
672
  def _process_enum(
562
- enum: EnumMeta, client: Optional[_api_client.ApiClient] = None
673
+ enum: EnumMeta, client: Optional[_api_client.BaseApiClient] = None # type: ignore
563
674
  ) -> types.Schema:
564
- for member in enum:
675
+ for member in enum: # type: ignore
565
676
  if not isinstance(member.value, str):
566
677
  raise TypeError(
567
678
  f'Enum member {member.name} value must be a string, got'
@@ -578,7 +689,7 @@ def _process_enum(
578
689
 
579
690
 
580
691
  def t_schema(
581
- client: _api_client.ApiClient, origin: Union[types.SchemaUnionDict, Any]
692
+ client: _api_client.BaseApiClient, origin: Union[types.SchemaUnionDict, Any]
582
693
  ) -> Optional[types.Schema]:
583
694
  if not origin:
584
695
  return None
@@ -624,7 +735,7 @@ def t_schema(
624
735
 
625
736
 
626
737
  def t_speech_config(
627
- _: _api_client.ApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
738
+ _: _api_client.BaseApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
628
739
  ) -> Optional[types.SpeechConfig]:
629
740
  if not origin:
630
741
  return None
@@ -653,7 +764,7 @@ def t_speech_config(
653
764
  raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
654
765
 
655
766
 
656
- def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
767
+ def t_tool(client: _api_client.BaseApiClient, origin) -> types.Tool:
657
768
  if not origin:
658
769
  return None
659
770
  if inspect.isfunction(origin) or inspect.ismethod(origin):
@@ -670,7 +781,7 @@ def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
670
781
 
671
782
  # Only support functions now.
672
783
  def t_tools(
673
- client: _api_client.ApiClient, origin: list[Any]
784
+ client: _api_client.BaseApiClient, origin: list[Any]
674
785
  ) -> list[types.Tool]:
675
786
  if not origin:
676
787
  return []
@@ -690,11 +801,11 @@ def t_tools(
690
801
  return tools
691
802
 
692
803
 
693
- def t_cached_content_name(client: _api_client.ApiClient, name: str):
804
+ def t_cached_content_name(client: _api_client.BaseApiClient, name: str):
694
805
  return _resource_name(client, name, collection_identifier='cachedContents')
695
806
 
696
807
 
697
- def t_batch_job_source(client: _api_client.ApiClient, src: str):
808
+ def t_batch_job_source(client: _api_client.BaseApiClient, src: str):
698
809
  if src.startswith('gs://'):
699
810
  return types.BatchJobSource(
700
811
  format='jsonl',
@@ -709,7 +820,7 @@ def t_batch_job_source(client: _api_client.ApiClient, src: str):
709
820
  raise ValueError(f'Unsupported source: {src}')
710
821
 
711
822
 
712
- def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
823
+ def t_batch_job_destination(client: _api_client.BaseApiClient, dest: str):
713
824
  if dest.startswith('gs://'):
714
825
  return types.BatchJobDestination(
715
826
  format='jsonl',
@@ -724,7 +835,7 @@ def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
724
835
  raise ValueError(f'Unsupported destination: {dest}')
725
836
 
726
837
 
727
- def t_batch_job_name(client: _api_client.ApiClient, name: str):
838
+ def t_batch_job_name(client: _api_client.BaseApiClient, name: str):
728
839
  if not client.vertexai:
729
840
  return name
730
841
 
@@ -743,7 +854,7 @@ LRO_POLLING_TIMEOUT_SECONDS = 900.0
743
854
  LRO_POLLING_MULTIPLIER = 1.5
744
855
 
745
856
 
746
- def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
857
+ def t_resolve_operation(api_client: _api_client.BaseApiClient, struct: dict):
747
858
  if (name := struct.get('name')) and '/operations/' in name:
748
859
  operation: dict[str, Any] = struct
749
860
  total_seconds = 0.0
@@ -752,7 +863,7 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
752
863
  if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
753
864
  raise RuntimeError(f'Operation {name} timed out.\n{operation}')
754
865
  # TODO(b/374433890): Replace with LRO module once it's available.
755
- operation: dict[str, Any] = api_client.request(
866
+ operation = api_client.request(
756
867
  http_method='GET', path=name, request_dict={}
757
868
  )
758
869
  time.sleep(delay_seconds)
@@ -772,7 +883,7 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
772
883
 
773
884
 
774
885
  def t_file_name(
775
- api_client: _api_client.ApiClient, name: Union[str, types.File]
886
+ api_client: _api_client.BaseApiClient, name: Optional[Union[str, types.File]]
776
887
  ):
777
888
  # Remove the files/ prefix since it's added to the url path.
778
889
  if isinstance(name, types.File):
@@ -794,7 +905,7 @@ def t_file_name(
794
905
 
795
906
 
796
907
  def t_tuning_job_status(
797
- api_client: _api_client.ApiClient, status: str
908
+ api_client: _api_client.BaseApiClient, status: str
798
909
  ) -> types.JobState:
799
910
  if status == 'STATE_UNSPECIFIED':
800
911
  return 'JOB_STATE_UNSPECIFIED'
@@ -812,7 +923,7 @@ def t_tuning_job_status(
812
923
  # We shouldn't use this transformer if the backend adheres to Cloud Type
813
924
  # format https://cloud.google.com/docs/discovery/type-format.
814
925
  # TODO(b/389133914,b/390320301): Remove the hack after the backend fixes the issue.
815
- def t_bytes(api_client: _api_client.ApiClient, data: bytes) -> str:
926
+ def t_bytes(api_client: _api_client.BaseApiClient, data: bytes) -> str:
816
927
  if not isinstance(data, bytes):
817
928
  return data
818
929
  return base64.b64encode(data).decode('ascii')