google-genai 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,7 +29,7 @@ import google.auth
29
29
  from requests.exceptions import HTTPError
30
30
 
31
31
  from . import errors
32
- from ._api_client import ApiClient
32
+ from ._api_client import BaseApiClient
33
33
  from ._api_client import HttpOptions
34
34
  from ._api_client import HttpRequest
35
35
  from ._api_client import HttpResponse
@@ -60,6 +60,10 @@ def _redact_request_headers(headers):
60
60
  redacted_headers[header_name] = _redact_language_label(
61
61
  _redact_version_numbers(header_value)
62
62
  )
63
+ elif header_name.lower() == 'x-goog-user-project':
64
+ continue
65
+ elif header_name.lower() == 'authorization':
66
+ continue
63
67
  else:
64
68
  redacted_headers[header_name] = header_value
65
69
  return redacted_headers
@@ -175,7 +179,7 @@ class ReplayFile(BaseModel):
175
179
  interactions: list[ReplayInteraction]
176
180
 
177
181
 
178
- class ReplayApiClient(ApiClient):
182
+ class ReplayApiClient(BaseApiClient):
179
183
  """For integration testing, send recorded response or records a response."""
180
184
 
181
185
  def __init__(
@@ -409,6 +413,34 @@ class ReplayApiClient(ApiClient):
409
413
  else:
410
414
  return self._build_response_from_replay(http_request)
411
415
 
416
+ async def _async_request(
417
+ self,
418
+ http_request: HttpRequest,
419
+ stream: bool = False,
420
+ ) -> HttpResponse:
421
+ self._initialize_replay_session_if_not_loaded()
422
+ if self._should_call_api():
423
+ _debug_print('api mode request: %s' % http_request)
424
+ try:
425
+ result = await super()._async_request(http_request, stream)
426
+ except errors.APIError as e:
427
+ self._record_interaction(http_request, e)
428
+ raise e
429
+ if stream:
430
+ result_segments = []
431
+ async for segment in result.async_segments():
432
+ result_segments.append(json.dumps(segment))
433
+ result = HttpResponse(result.headers, result_segments)
434
+ self._record_interaction(http_request, result)
435
+ # Need to return a RecordedResponse that rebuilds the response
436
+ # segments since the stream has been consumed.
437
+ else:
438
+ self._record_interaction(http_request, result)
439
+ _debug_print('api mode result: %s' % result.json)
440
+ return result
441
+ else:
442
+ return self._build_response_from_replay(http_request)
443
+
412
444
  def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int):
413
445
  if isinstance(file_path, io.IOBase):
414
446
  offset = file_path.tell()
@@ -425,6 +457,7 @@ class ReplayApiClient(ApiClient):
425
457
  method='POST', url='', data={'file_path': file_path}, headers={}
426
458
  )
427
459
  if self._should_call_api():
460
+ result: Union[str, HttpResponse]
428
461
  try:
429
462
  result = super().upload_file(file_path, upload_url, upload_size)
430
463
  except HTTPError as e:
@@ -453,4 +486,3 @@ class ReplayApiClient(ApiClient):
453
486
  return result
454
487
  else:
455
488
  return self._build_response_from_replay(request)
456
-
@@ -17,13 +17,13 @@ import asyncio
17
17
  import time
18
18
  from unittest.mock import MagicMock, patch
19
19
  import pytest
20
- from .api_client import ApiClient
20
+ from .api_client import BaseApiClient
21
21
 
22
22
 
23
- @patch('genai.api_client.ApiClient._build_request')
24
- @patch('genai.api_client.ApiClient._request')
23
+ @patch('genai.api_client.BaseApiClient._build_request')
24
+ @patch('genai.api_client.BaseApiClient._request')
25
25
  def test_request_streamed_non_blocking(mock_request, mock_build_request):
26
- api_client = ApiClient(api_key='test_api_key')
26
+ api_client = BaseApiClient(api_key='test_api_key')
27
27
  http_method = 'GET'
28
28
  path = 'test/path'
29
29
  request_dict = {'key': 'value'}
@@ -56,8 +56,8 @@ def test_request_streamed_non_blocking(mock_request, mock_build_request):
56
56
  assert end_time - start_time > 0.3
57
57
 
58
58
 
59
- @patch('genai.api_client.ApiClient._build_request')
60
- @patch('genai.api_client.ApiClient._async_request')
59
+ @patch('genai.api_client.BaseApiClient._build_request')
60
+ @patch('genai.api_client.BaseApiClient._async_request')
61
61
  @pytest.mark.asyncio
62
62
  async def test_async_request(mock_async_request, mock_build_request):
63
63
  api_client = ApiClient(api_key='test_api_key')
@@ -99,8 +99,8 @@ async def test_async_request(mock_async_request, mock_build_request):
99
99
  assert 0.1 <= end_time - start_time < 0.15
100
100
 
101
101
 
102
- @patch('genai.api_client.ApiClient._build_request')
103
- @patch('genai.api_client.ApiClient._async_request')
102
+ @patch('genai.api_client.BaseApiClient._build_request')
103
+ @patch('genai.api_client.BaseApiClient._async_request')
104
104
  @pytest.mark.asyncio
105
105
  async def test_async_request_streamed_non_blocking(
106
106
  mock_async_request, mock_build_request
@@ -20,12 +20,16 @@ from collections.abc import Iterable, Mapping
20
20
  from enum import Enum, EnumMeta
21
21
  import inspect
22
22
  import io
23
+ import logging
23
24
  import re
24
25
  import sys
25
26
  import time
27
+ import types as builtin_types
26
28
  import typing
27
29
  from typing import Any, GenericAlias, Optional, Union
28
30
 
31
+ import types as builtin_types
32
+
29
33
  if typing.TYPE_CHECKING:
30
34
  import PIL.Image
31
35
 
@@ -34,16 +38,18 @@ import pydantic
34
38
  from . import _api_client
35
39
  from . import types
36
40
 
41
+ logger = logging.getLogger('google_genai._transformers')
42
+
37
43
  if sys.version_info >= (3, 10):
38
- VersionedUnionType = typing.types.UnionType
39
- _UNION_TYPES = (typing.Union, typing.types.UnionType)
44
+ VersionedUnionType = builtin_types.UnionType
45
+ _UNION_TYPES = (typing.Union, builtin_types.UnionType)
40
46
  else:
41
47
  VersionedUnionType = typing._UnionGenericAlias
42
48
  _UNION_TYPES = (typing.Union,)
43
49
 
44
50
 
45
51
  def _resource_name(
46
- client: _api_client.ApiClient,
52
+ client: _api_client.BaseApiClient,
47
53
  resource_name: str,
48
54
  *,
49
55
  collection_identifier: str,
@@ -135,7 +141,7 @@ def _resource_name(
135
141
  return resource_name
136
142
 
137
143
 
138
- def t_model(client: _api_client.ApiClient, model: str):
144
+ def t_model(client: _api_client.BaseApiClient, model: str):
139
145
  if not model:
140
146
  raise ValueError('model is required.')
141
147
  if client.vertexai:
@@ -159,7 +165,7 @@ def t_model(client: _api_client.ApiClient, model: str):
159
165
  return f'models/{model}'
160
166
 
161
167
 
162
- def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
168
+ def t_models_url(api_client: _api_client.BaseApiClient, base_models: bool) -> str:
163
169
  if api_client.vertexai:
164
170
  if base_models:
165
171
  return 'publishers/google/models'
@@ -173,7 +179,7 @@ def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
173
179
 
174
180
 
175
181
  def t_extract_models(
176
- api_client: _api_client.ApiClient, response: dict
182
+ api_client: _api_client.BaseApiClient, response: dict
177
183
  ) -> list[types.Model]:
178
184
  if not response:
179
185
  return []
@@ -183,11 +189,18 @@ def t_extract_models(
183
189
  return response.get('tunedModels')
184
190
  elif response.get('publisherModels') is not None:
185
191
  return response.get('publisherModels')
192
+ elif (
193
+ response.get('httpHeaders') is not None
194
+ and response.get('jsonPayload') is None
195
+ ):
196
+ return []
186
197
  else:
187
- raise ValueError('Cannot determine the models type.')
198
+ logger.warning('Cannot determine the models type.')
199
+ logger.debug('Cannot determine the models type for response: %s', response)
200
+ return []
188
201
 
189
202
 
190
- def t_caches_model(api_client: _api_client.ApiClient, model: str):
203
+ def t_caches_model(api_client: _api_client.BaseApiClient, model: str):
191
204
  model = t_model(api_client, model)
192
205
  if not model:
193
206
  return None
@@ -203,6 +216,7 @@ def t_caches_model(api_client: _api_client.ApiClient, model: str):
203
216
 
204
217
 
205
218
  def pil_to_blob(img) -> types.Blob:
219
+ PngImagePlugin: Optional[builtin_types.ModuleType]
206
220
  try:
207
221
  import PIL.PngImagePlugin
208
222
 
@@ -226,10 +240,9 @@ def pil_to_blob(img) -> types.Blob:
226
240
  return types.Blob(mime_type=mime_type, data=data)
227
241
 
228
242
 
229
- PartType = Union[types.Part, types.PartDict, str, 'PIL.Image.Image']
230
-
231
-
232
- def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
243
+ def t_part(
244
+ client: _api_client.BaseApiClient, part: Optional[types.PartUnionDict]
245
+ ) -> types.Part:
233
246
  try:
234
247
  import PIL.Image
235
248
 
@@ -237,7 +250,7 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
237
250
  except ImportError:
238
251
  PIL_Image = None
239
252
 
240
- if not part:
253
+ if part is None:
241
254
  raise ValueError('content part is required.')
242
255
  if isinstance(part, str):
243
256
  return types.Part(text=part)
@@ -247,14 +260,19 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
247
260
  if not part.uri or not part.mime_type:
248
261
  raise ValueError('file uri and mime_type are required.')
249
262
  return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)
250
- else:
263
+ if isinstance(part, dict):
264
+ return types.Part.model_validate(part)
265
+ if isinstance(part, types.Part):
251
266
  return part
267
+ raise ValueError(f'Unsupported content part type: {type(part)}')
252
268
 
253
269
 
254
270
  def t_parts(
255
- client: _api_client.ApiClient, parts: Union[list, PartType]
271
+ client: _api_client.BaseApiClient,
272
+ parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict]],
256
273
  ) -> list[types.Part]:
257
- if parts is None:
274
+ #
275
+ if parts is None or (isinstance(parts, list) and not parts):
258
276
  raise ValueError('content parts are required.')
259
277
  if isinstance(parts, list):
260
278
  return [t_part(client, part) for part in parts]
@@ -263,7 +281,7 @@ def t_parts(
263
281
 
264
282
 
265
283
  def t_image_predictions(
266
- client: _api_client.ApiClient,
284
+ client: _api_client.BaseApiClient,
267
285
  predictions: Optional[Iterable[Mapping[str, Any]]],
268
286
  ) -> list[types.GeneratedImage]:
269
287
  if not predictions:
@@ -282,24 +300,38 @@ def t_image_predictions(
282
300
  return images
283
301
 
284
302
 
285
- ContentType = Union[types.Content, types.ContentDict, PartType]
303
+ ContentType = Union[types.Content, types.ContentDict, types.PartUnionDict]
286
304
 
287
305
 
288
306
  def t_content(
289
- client: _api_client.ApiClient,
290
- content: ContentType,
291
- ):
292
- if not content:
307
+ client: _api_client.BaseApiClient,
308
+ content: Optional[ContentType],
309
+ ) -> types.Content:
310
+ if content is None:
293
311
  raise ValueError('content is required.')
294
312
  if isinstance(content, types.Content):
295
313
  return content
296
314
  if isinstance(content, dict):
297
- return types.Content.model_validate(content)
298
- return types.Content(role='user', parts=t_parts(client, content))
315
+ try:
316
+ return types.Content.model_validate(content)
317
+ except pydantic.ValidationError:
318
+ possible_part = types.Part.model_validate(content)
319
+ return (
320
+ types.ModelContent(parts=[possible_part])
321
+ if possible_part.function_call
322
+ else types.UserContent(parts=[possible_part])
323
+ )
324
+ if isinstance(content, types.Part):
325
+ return (
326
+ types.ModelContent(parts=[content])
327
+ if content.function_call
328
+ else types.UserContent(parts=[content])
329
+ )
330
+ return types.UserContent(parts=content)
299
331
 
300
332
 
301
333
  def t_contents_for_embed(
302
- client: _api_client.ApiClient,
334
+ client: _api_client.BaseApiClient,
303
335
  contents: Union[list[types.Content], list[types.ContentDict], ContentType],
304
336
  ):
305
337
  if client.vertexai and isinstance(contents, list):
@@ -314,16 +346,105 @@ def t_contents_for_embed(
314
346
 
315
347
 
316
348
  def t_contents(
317
- client: _api_client.ApiClient,
318
- contents: Union[list[types.Content], list[types.ContentDict], ContentType],
319
- ):
320
- if not contents:
349
+ client: _api_client.BaseApiClient,
350
+ contents: Optional[
351
+ Union[types.ContentListUnion, types.ContentListUnionDict]
352
+ ],
353
+ ) -> list[types.Content]:
354
+ if contents is None or (isinstance(contents, list) and not contents):
321
355
  raise ValueError('contents are required.')
322
- if isinstance(contents, list):
323
- return [t_content(client, content) for content in contents]
324
- else:
356
+ if not isinstance(contents, list):
325
357
  return [t_content(client, contents)]
326
358
 
359
+ try:
360
+ import PIL.Image
361
+
362
+ PIL_Image = PIL.Image.Image
363
+ except ImportError:
364
+ PIL_Image = None
365
+
366
+ result: list[types.Content] = []
367
+ accumulated_parts: list[types.Part] = []
368
+
369
+ def _is_part(part: types.PartUnionDict) -> bool:
370
+ if (
371
+ isinstance(part, str)
372
+ or isinstance(part, types.File)
373
+ or (PIL_Image is not None and isinstance(part, PIL_Image))
374
+ or isinstance(part, types.Part)
375
+ ):
376
+ return True
377
+
378
+ if isinstance(part, dict):
379
+ try:
380
+ types.Part.model_validate(part)
381
+ return True
382
+ except pydantic.ValidationError:
383
+ return False
384
+
385
+ return False
386
+
387
+ def _is_user_part(part: types.Part) -> bool:
388
+ return not part.function_call
389
+
390
+ def _are_user_parts(parts: list[types.Part]) -> bool:
391
+ return all(_is_user_part(part) for part in parts)
392
+
393
+ def _append_accumulated_parts_as_content(
394
+ result: list[types.Content],
395
+ accumulated_parts: list[types.Part],
396
+ ):
397
+ if not accumulated_parts:
398
+ return
399
+ result.append(
400
+ types.UserContent(parts=accumulated_parts)
401
+ if _are_user_parts(accumulated_parts)
402
+ else types.ModelContent(parts=accumulated_parts)
403
+ )
404
+ accumulated_parts[:] = []
405
+
406
+ def _handle_current_part(
407
+ result: list[types.Content],
408
+ accumulated_parts: list[types.Part],
409
+ current_part: types.PartUnionDict,
410
+ ):
411
+ current_part = t_part(client, current_part)
412
+ if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
413
+ accumulated_parts.append(current_part)
414
+ else:
415
+ _append_accumulated_parts_as_content(result, accumulated_parts)
416
+ accumulated_parts[:] = [current_part]
417
+
418
+ # iterator over contents
419
+ # if content type or content dict, append to result
420
+ # if consecutive part(s),
421
+ # group consecutive user part(s) to a UserContent
422
+ # group consecutive model part(s) to a ModelContent
423
+ # append to result
424
+ # if list, we only accept a list of types.PartUnion
425
+ for content in contents:
426
+ if (
427
+ isinstance(content, types.Content)
428
+ # only allowed inner list is a list of types.PartUnion
429
+ or isinstance(content, list)
430
+ ):
431
+ _append_accumulated_parts_as_content(result, accumulated_parts)
432
+ if isinstance(content, list):
433
+ result.append(types.UserContent(parts=content))
434
+ else:
435
+ result.append(content)
436
+ elif (_is_part(content)): # type: ignore
437
+ _handle_current_part(result, accumulated_parts, content) # type: ignore
438
+ elif isinstance(content, dict):
439
+ # PactDict is already handled in _is_part
440
+ result.append(types.Content.model_validate(content))
441
+ else:
442
+ raise ValueError(f'Unsupported content type: {type(content)}')
443
+
444
+ _append_accumulated_parts_as_content(result, accumulated_parts)
445
+
446
+ return result
447
+
327
448
 
328
449
  def handle_null_fields(schema: dict[str, Any]):
329
450
  """Process null fields in the schema so it is compatible with OpenAPI.
@@ -386,7 +507,7 @@ def handle_null_fields(schema: dict[str, Any]):
386
507
 
387
508
  def process_schema(
388
509
  schema: dict[str, Any],
389
- client: Optional[_api_client.ApiClient] = None,
510
+ client: Optional[_api_client.BaseApiClient] = None,
390
511
  defs: Optional[dict[str, Any]] = None,
391
512
  *,
392
513
  order_properties: bool = True,
@@ -549,9 +670,9 @@ def process_schema(
549
670
 
550
671
 
551
672
  def _process_enum(
552
- enum: EnumMeta, client: Optional[_api_client.ApiClient] = None
673
+ enum: EnumMeta, client: Optional[_api_client.BaseApiClient] = None # type: ignore
553
674
  ) -> types.Schema:
554
- for member in enum:
675
+ for member in enum: # type: ignore
555
676
  if not isinstance(member.value, str):
556
677
  raise TypeError(
557
678
  f'Enum member {member.name} value must be a string, got'
@@ -568,7 +689,7 @@ def _process_enum(
568
689
 
569
690
 
570
691
  def t_schema(
571
- client: _api_client.ApiClient, origin: Union[types.SchemaUnionDict, Any]
692
+ client: _api_client.BaseApiClient, origin: Union[types.SchemaUnionDict, Any]
572
693
  ) -> Optional[types.Schema]:
573
694
  if not origin:
574
695
  return None
@@ -614,7 +735,7 @@ def t_schema(
614
735
 
615
736
 
616
737
  def t_speech_config(
617
- _: _api_client.ApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
738
+ _: _api_client.BaseApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
618
739
  ) -> Optional[types.SpeechConfig]:
619
740
  if not origin:
620
741
  return None
@@ -643,7 +764,7 @@ def t_speech_config(
643
764
  raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
644
765
 
645
766
 
646
- def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
767
+ def t_tool(client: _api_client.BaseApiClient, origin) -> types.Tool:
647
768
  if not origin:
648
769
  return None
649
770
  if inspect.isfunction(origin) or inspect.ismethod(origin):
@@ -660,7 +781,7 @@ def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
660
781
 
661
782
  # Only support functions now.
662
783
  def t_tools(
663
- client: _api_client.ApiClient, origin: list[Any]
784
+ client: _api_client.BaseApiClient, origin: list[Any]
664
785
  ) -> list[types.Tool]:
665
786
  if not origin:
666
787
  return []
@@ -680,11 +801,11 @@ def t_tools(
680
801
  return tools
681
802
 
682
803
 
683
- def t_cached_content_name(client: _api_client.ApiClient, name: str):
804
+ def t_cached_content_name(client: _api_client.BaseApiClient, name: str):
684
805
  return _resource_name(client, name, collection_identifier='cachedContents')
685
806
 
686
807
 
687
- def t_batch_job_source(client: _api_client.ApiClient, src: str):
808
+ def t_batch_job_source(client: _api_client.BaseApiClient, src: str):
688
809
  if src.startswith('gs://'):
689
810
  return types.BatchJobSource(
690
811
  format='jsonl',
@@ -699,7 +820,7 @@ def t_batch_job_source(client: _api_client.ApiClient, src: str):
699
820
  raise ValueError(f'Unsupported source: {src}')
700
821
 
701
822
 
702
- def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
823
+ def t_batch_job_destination(client: _api_client.BaseApiClient, dest: str):
703
824
  if dest.startswith('gs://'):
704
825
  return types.BatchJobDestination(
705
826
  format='jsonl',
@@ -714,7 +835,7 @@ def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
714
835
  raise ValueError(f'Unsupported destination: {dest}')
715
836
 
716
837
 
717
- def t_batch_job_name(client: _api_client.ApiClient, name: str):
838
+ def t_batch_job_name(client: _api_client.BaseApiClient, name: str):
718
839
  if not client.vertexai:
719
840
  return name
720
841
 
@@ -733,7 +854,7 @@ LRO_POLLING_TIMEOUT_SECONDS = 900.0
733
854
  LRO_POLLING_MULTIPLIER = 1.5
734
855
 
735
856
 
736
- def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
857
+ def t_resolve_operation(api_client: _api_client.BaseApiClient, struct: dict):
737
858
  if (name := struct.get('name')) and '/operations/' in name:
738
859
  operation: dict[str, Any] = struct
739
860
  total_seconds = 0.0
@@ -742,7 +863,7 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
742
863
  if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
743
864
  raise RuntimeError(f'Operation {name} timed out.\n{operation}')
744
865
  # TODO(b/374433890): Replace with LRO module once it's available.
745
- operation: dict[str, Any] = api_client.request(
866
+ operation = api_client.request(
746
867
  http_method='GET', path=name, request_dict={}
747
868
  )
748
869
  time.sleep(delay_seconds)
@@ -762,7 +883,7 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
762
883
 
763
884
 
764
885
  def t_file_name(
765
- api_client: _api_client.ApiClient, name: Union[str, types.File]
886
+ api_client: _api_client.BaseApiClient, name: Optional[Union[str, types.File]]
766
887
  ):
767
888
  # Remove the files/ prefix since it's added to the url path.
768
889
  if isinstance(name, types.File):
@@ -784,7 +905,7 @@ def t_file_name(
784
905
 
785
906
 
786
907
  def t_tuning_job_status(
787
- api_client: _api_client.ApiClient, status: str
908
+ api_client: _api_client.BaseApiClient, status: str
788
909
  ) -> types.JobState:
789
910
  if status == 'STATE_UNSPECIFIED':
790
911
  return 'JOB_STATE_UNSPECIFIED'
@@ -802,7 +923,7 @@ def t_tuning_job_status(
802
923
  # We shouldn't use this transformer if the backend adhere to Cloud Type
803
924
  # format https://cloud.google.com/docs/discovery/type-format.
804
925
  # TODO(b/389133914,b/390320301): Remove the hack after backend fix the issue.
805
- def t_bytes(api_client: _api_client.ApiClient, data: bytes) -> str:
926
+ def t_bytes(api_client: _api_client.BaseApiClient, data: bytes) -> str:
806
927
  if not isinstance(data, bytes):
807
928
  return data
808
929
  return base64.b64encode(data).decode('ascii')