google-genai 1.21.1__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
@@ -996,14 +996,14 @@ class BaseApiClient:
       path: str,
       request_dict: dict[str, object],
       http_options: Optional[HttpOptionsOrDict] = None,
-  ) -> Generator[Any, None, None]:
+  ) -> Generator[SdkHttpResponse, None, None]:
     http_request = self._build_request(
         http_method, path, request_dict, http_options
     )

     session_response = self._request(http_request, stream=True)
     for chunk in session_response.segments():
-      yield chunk
+      yield SdkHttpResponse(headers=session_response.headers, body=json.dumps(chunk))

   async def async_request(
       self,
@@ -1038,7 +1038,7 @@ class BaseApiClient:

     async def async_generator(): # type: ignore[no-untyped-def]
       async for chunk in response:
-        yield chunk
+        yield SdkHttpResponse(headers=response.headers, body=json.dumps(chunk))

     return async_generator() # type: ignore[no-untyped-call]

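Both streaming paths now wrap each raw chunk in an SdkHttpResponse that carries the HTTP headers plus a JSON-encoded body, instead of yielding the parsed chunk dict directly. A minimal sketch of how a caller of these internal generators would recover the parsed payload (the parsed_chunks helper is illustrative, not part of the package):

import json

def parsed_chunks(sdk_responses):
  # Each yielded item has `.headers` and a JSON string `.body`, per the change
  # above, so the original chunk dict is recovered with json.loads().
  for sdk_response in sdk_responses:
    yield json.loads(sdk_response.body)
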
google/genai/_common.py CHANGED
@@ -16,12 +16,13 @@
 """Common utilities for the SDK."""

 import base64
+import collections.abc
 import datetime
 import enum
 import functools
 import logging
 import typing
-from typing import Any, Callable, Optional, Union, get_origin, get_args
+from typing import Any, Callable, Optional, FrozenSet, Union, get_args, get_origin
 import uuid
 import warnings

@@ -233,6 +234,179 @@ def _remove_extra_fields(
 T = typing.TypeVar('T', bound='BaseModel')


+def _pretty_repr(
+    obj: Any,
+    *,
+    indent_level: int = 0,
+    indent_delta: int = 2,
+    max_len: int = 100,
+    max_items: int = 5,
+    depth: int = 6,
+    visited: Optional[FrozenSet[int]] = None,
+) -> str:
+  """Returns a representation of the given object."""
+  if visited is None:
+    visited = frozenset()
+
+  obj_id = id(obj)
+  if obj_id in visited:
+    return '<... Circular reference ...>'
+
+  if depth < 0:
+    return '<... Max depth ...>'
+
+  visited = frozenset(list(visited) + [obj_id])
+
+  indent = ' ' * indent_level
+  next_indent_str = ' ' * (indent_level + indent_delta)
+
+  if isinstance(obj, pydantic.BaseModel):
+    cls_name = obj.__class__.__name__
+    items = []
+    # Sort fields for consistent output
+    fields = sorted(type(obj).model_fields)
+
+    for field_name in fields:
+      field_info = type(obj).model_fields[field_name]
+      if not field_info.repr: # Respect Field(repr=False)
+        continue
+
+      try:
+        value = getattr(obj, field_name)
+      except AttributeError:
+        continue
+
+      if value is None:
+        continue
+
+      value_repr = _pretty_repr(
+          value,
+          indent_level=indent_level + indent_delta,
+          indent_delta=indent_delta,
+          max_len=max_len,
+          max_items=max_items,
+          depth=depth - 1,
+          visited=visited,
+      )
+      items.append(f'{next_indent_str}{field_name}={value_repr}')
+
+    if not items:
+      return f'{cls_name}()'
+    return f'{cls_name}(\n' + ',\n'.join(items) + f'\n{indent})'
+  elif isinstance(obj, str):
+    if '\n' in obj:
+      escaped = obj.replace('"""', '\\"\\"\\"')
+      # Indent the multi-line string block contents
+      return f'"""{escaped}"""'
+    return repr(obj)
+  elif isinstance(obj, bytes):
+    if len(obj) > max_len:
+      return f"{repr(obj[:max_len-3])[:-1]}...'"
+    return repr(obj)
+  elif isinstance(obj, collections.abc.Mapping):
+    if not obj:
+      return '{}'
+    if len(obj) > max_items:
+      return f'<dict len={len(obj)}>'
+    items = []
+    try:
+      sorted_keys = sorted(obj.keys(), key=str)
+    except TypeError:
+      sorted_keys = list(obj.keys())
+
+    for k in sorted_keys:
+      v = obj[k]
+      k_repr = _pretty_repr(
+          k,
+          indent_level=indent_level + indent_delta,
+          indent_delta=indent_delta,
+          max_len=max_len,
+          max_items=max_items,
+          depth=depth - 1,
+          visited=visited,
+      )
+      v_repr = _pretty_repr(
+          v,
+          indent_level=indent_level + indent_delta,
+          indent_delta=indent_delta,
+          max_len=max_len,
+          max_items=max_items,
+          depth=depth - 1,
+          visited=visited,
+      )
+      items.append(f'{next_indent_str}{k_repr}: {v_repr}')
+    return f'{{\n' + ',\n'.join(items) + f'\n{indent}}}'
+  elif isinstance(obj, (list, tuple, set)):
+    return _format_collection(
+        obj,
+        indent_level=indent_level,
+        indent_delta=indent_delta,
+        max_len=max_len,
+        max_items=max_items,
+        depth=depth,
+        visited=visited,
+    )
+  else:
+    # Fallback to standard repr, indenting subsequent lines only
+    raw_repr = repr(obj)
+    # Replace newlines with newline + indent
+    return raw_repr.replace('\n', f'\n{next_indent_str}')
+
+
+
+def _format_collection(
+    obj: Any,
+    *,
+    indent_level: int,
+    indent_delta: int,
+    max_len: int,
+    max_items: int,
+    depth: int,
+    visited: FrozenSet[int],
+) -> str:
+  """Formats a collection (list, tuple, set)."""
+  if isinstance(obj, list):
+    brackets = ('[', ']')
+  elif isinstance(obj, tuple):
+    brackets = ('(', ')')
+  elif isinstance(obj, set):
+    obj = list(obj)
+    if obj:
+      brackets = ('{', '}')
+    else:
+      brackets = ('set(', ')')
+  else:
+    raise ValueError(f"Unsupported collection type: {type(obj)}")
+
+  if not obj:
+    return brackets[0] + brackets[1]
+
+  indent = ' ' * indent_level
+  next_indent_str = ' ' * (indent_level + indent_delta)
+  elements = []
+  for i, elem in enumerate(obj):
+    if i >= max_items:
+      elements.append(
+          f'{next_indent_str}<... {len(obj) - max_items} more items ...>'
+      )
+      break
+    # Each element starts on a new line, fully indented
+    elements.append(
+        next_indent_str
+        + _pretty_repr(
+            elem,
+            indent_level=indent_level + indent_delta,
+            indent_delta=indent_delta,
+            max_len=max_len,
+            max_items=max_items,
+            depth=depth - 1,
+            visited=visited,
+        )
+    )
+
+  return f'{brackets[0]}\n' + ',\n'.join(elements) + "," + f'\n{indent}{brackets[1]}'
+
+
 class BaseModel(pydantic.BaseModel):

   model_config = pydantic.ConfigDict(
@@ -248,6 +422,12 @@ class BaseModel(pydantic.BaseModel):
       ignored_types=(typing.TypeVar,)
   )

+  def __repr__(self) -> str:
+    try:
+      return _pretty_repr(self)
+    except Exception:
+      return super().__repr__()
+
   @classmethod
   def _from_response(
       cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
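With BaseModel.__repr__ now delegating to _pretty_repr, SDK types render as an indented tree with fields sorted alphabetically and None-valued fields omitted, falling back to pydantic's default repr if anything goes wrong. A minimal sketch of the output shape, using hypothetical models built on the internal _common.BaseModel rather than real SDK types:

from typing import List, Optional

from google.genai import _common  # internal module that defines the shared BaseModel

class Track(_common.BaseModel):  # hypothetical model, not an SDK type
  title: Optional[str] = None
  duration_seconds: Optional[int] = None

class Album(_common.BaseModel):  # hypothetical model, not an SDK type
  name: Optional[str] = None
  tracks: Optional[List[Track]] = None

# repr() (not str(), which still uses pydantic's __str__) produces the indented tree;
# duration_seconds is None and is therefore skipped.
print(repr(Album(name='Demo', tracks=[Track(title='Intro')])))
# Album(
#   name='Demo',
#   tracks=[
#     Track(
#       title='Intro'
#     ),
#   ]
# )
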
@@ -25,6 +25,7 @@ import pydantic

 from . import _common
 from . import _mcp_utils
+from . import _transformers as t
 from . import errors
 from . import types
 from ._adapters import McpToGenAiToolAdapter
@@ -62,11 +63,37 @@ def _create_generate_content_config_model(
   return config


+def _get_gcs_uri(
+    src: Union[str, types.BatchJobSourceOrDict]
+) -> Optional[str]:
+  """Extracts the first GCS URI from the source, if available."""
+  if isinstance(src, str) and src.startswith('gs://'):
+    return src
+  elif isinstance(src, dict) and src.get('gcs_uri'):
+    return src['gcs_uri'][0] if src['gcs_uri'] else None
+  elif isinstance(src, types.BatchJobSource) and src.gcs_uri:
+    return src.gcs_uri[0] if src.gcs_uri else None
+  return None
+
+
+def _get_bigquery_uri(
+    src: Union[str, types.BatchJobSourceOrDict]
+) -> Optional[str]:
+  """Extracts the BigQuery URI from the source, if available."""
+  if isinstance(src, str) and src.startswith('bq://'):
+    return src
+  elif isinstance(src, dict) and src.get('bigquery_uri'):
+    return src['bigquery_uri']
+  elif isinstance(src, types.BatchJobSource) and src.bigquery_uri:
+    return src.bigquery_uri
+  return None
+
+
 def format_destination(
-    src: str,
+    src: Union[str, types.BatchJobSourceOrDict],
     config: Optional[types.CreateBatchJobConfigOrDict] = None,
 ) -> types.CreateBatchJobConfig:
-  """Formats the destination uri based on the source uri."""
+  """Formats the destination uri based on the source uri for Vertex AI."""
   config = (
       types._CreateBatchJobParameters(config=config).config
       or types.CreateBatchJobConfig()
@@ -78,15 +105,14 @@ def format_destination(
     config.display_name = f'genai_batch_job_{unique_name}'

   if not config.dest:
-    if src.startswith('gs://') and src.endswith('.jsonl'):
-      # If source uri is "gs://bucket/path/to/src.jsonl", then the destination
-      # uri prefix will be "gs://bucket/path/to/src/dest".
-      config.dest = f'{src[:-6]}/dest'
-    elif src.startswith('bq://'):
-      # If source uri is "bq://project.dataset.src", then the destination
-      # uri will be "bq://project.dataset.src_dest_TIMESTAMP_UUID".
+    gcs_source_uri = _get_gcs_uri(src)
+    bigquery_source_uri = _get_bigquery_uri(src)
+
+    if gcs_source_uri and gcs_source_uri.endswith('.jsonl'):
+      config.dest = f'{gcs_source_uri[:-6]}/dest'
+    elif bigquery_source_uri:
       unique_name = unique_name or _common.timestamped_unique_name()
-      config.dest = f'{src}_dest_{unique_name}'
+      config.dest = f'{bigquery_source_uri}_dest_{unique_name}'
     else:
       raise ValueError(f'Unsupported source: {src}')
   return config
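format_destination still derives a default destination from the source, but it now routes through _get_gcs_uri and _get_bigquery_uri so BatchJobSource objects and dicts are handled in addition to plain URI strings. The derivation rule itself is unchanged; a standalone sketch of it (URIs and the unique-name suffix are illustrative):

def derive_dest(gcs_uri=None, bigquery_uri=None, unique_name='20240101_abc123'):
  # gs://bucket/path/src.jsonl -> gs://bucket/path/src/dest
  if gcs_uri and gcs_uri.endswith('.jsonl'):
    return f'{gcs_uri[:-len(".jsonl")]}/dest'
  # bq://project.dataset.src -> bq://project.dataset.src_dest_<TIMESTAMP_UUID>
  if bigquery_uri:
    return f'{bigquery_uri}_dest_{unique_name}'
  raise ValueError('Unsupported source')

assert derive_dest(gcs_uri='gs://bucket/path/src.jsonl') == 'gs://bucket/path/src/dest'
assert derive_dest(bigquery_uri='bq://project.dataset.src') == 'bq://project.dataset.src_dest_20240101_abc123'
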
@@ -506,3 +532,15 @@ async def parse_config_for_mcp_sessions(
     parsed_config_copy.tools.append(tool)

   return parsed_config_copy, mcp_to_genai_tool_adapters
+
+
+def append_chunk_contents(
+    contents: Union[types.ContentListUnion, types.ContentListUnionDict],
+    chunk: types.GenerateContentResponse,
+) -> None:
+  """Appends the contents of the chunk to the contents list."""
+  if chunk is not None and chunk.candidates is not None:
+    chunk_content = chunk.candidates[0].content
+    contents = t.t_contents(contents) # type: ignore[assignment]
+    if isinstance(contents, list) and chunk_content is not None:
+      contents.append(chunk_content) # type: ignore[arg-type]
@@ -26,7 +26,7 @@ import sys
 import time
 import types as builtin_types
 import typing
-from typing import Any, GenericAlias, Optional, Sequence, Union # type: ignore[attr-defined]
+from typing import Any, GenericAlias, List, Optional, Sequence, Union # type: ignore[attr-defined]
 from ._mcp_utils import mcp_to_gemini_tool

 if typing.TYPE_CHECKING:
@@ -787,15 +787,25 @@ def process_schema(
 def _process_enum(
     enum: EnumMeta, client: _api_client.BaseApiClient
 ) -> types.Schema:
+  is_integer_enum = False
+
   for member in enum: # type: ignore
-    if not isinstance(member.value, str):
+    if isinstance(member.value, int):
+      is_integer_enum = True
+    elif not isinstance(member.value, str):
       raise TypeError(
-          f'Enum member {member.name} value must be a string, got'
+          f'Enum member {member.name} value must be a string or integer, got'
           f' {type(member.value)}'
       )

+  enum_to_process = enum
+  if is_integer_enum:
+    str_members = [str(member.value) for member in enum] # type: ignore
+    str_enum = Enum(enum.__name__, str_members, type=str) # type: ignore
+    enum_to_process = str_enum
+
   class Placeholder(pydantic.BaseModel):
-    placeholder: enum # type: ignore[valid-type]
+    placeholder: enum_to_process # type: ignore[valid-type]

   enum_schema = Placeholder.model_json_schema()
   process_schema(enum_schema, client)
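_process_enum now tolerates integer-valued enums: instead of raising a TypeError, it builds a throwaway string enum from the stringified values before generating the JSON schema, so the schema still advertises string enum values. A small standalone sketch of that coercion (the Priority enum is hypothetical):

import enum
import pydantic

class Priority(enum.Enum):  # integer values, rejected with a TypeError before this change
  LOW = 1
  HIGH = 2

# Mirror of the coercion above: stringify the values and rebuild as a str-backed enum.
str_members = [str(member.value) for member in Priority]
str_enum = enum.Enum(Priority.__name__, str_members, type=str)

class Placeholder(pydantic.BaseModel):
  placeholder: str_enum  # type: ignore[valid-type]

schema = Placeholder.model_json_schema()
print(schema['$defs']['Priority']['enum'])  # ['1', '2'] -- string values in the schema
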
@@ -944,19 +954,54 @@ def t_cached_content_name(client: _api_client.BaseApiClient, name: str) -> str:
   return _resource_name(client, name, collection_identifier='cachedContents')


-def t_batch_job_source(src: str) -> types.BatchJobSource:
-  if src.startswith('gs://'):
-    return types.BatchJobSource(
-        format='jsonl',
-        gcs_uri=[src],
-    )
-  elif src.startswith('bq://'):
-    return types.BatchJobSource(
-        format='bigquery',
-        bigquery_uri=src,
-    )
-  else:
-    raise ValueError(f'Unsupported source: {src}')
+def t_batch_job_source(
+    client: _api_client.BaseApiClient,
+    src: Union[
+        str, List[types.InlinedRequestOrDict], types.BatchJobSourceOrDict
+    ],
+) -> types.BatchJobSource:
+  if isinstance(src, dict):
+    src = types.BatchJobSource(**src)
+  if isinstance(src, types.BatchJobSource):
+    if client.vertexai:
+      if src.gcs_uri and src.bigquery_uri:
+        raise ValueError(
+            'Only one of `gcs_uri` or `bigquery_uri` can be set.'
+        )
+      elif not src.gcs_uri and not src.bigquery_uri:
+        raise ValueError(
+            'One of `gcs_uri` or `bigquery_uri` must be set.'
+        )
+    else:
+      if src.inlined_requests and src.file_name:
+        raise ValueError(
+            'Only one of `inlined_requests` or `file_name` can be set.'
+        )
+      elif not src.inlined_requests and not src.file_name:
+        raise ValueError(
+            'One of `inlined_requests` or `file_name` must be set.'
+        )
+    return src
+
+  elif isinstance(src, list):
+    return types.BatchJobSource(inlined_requests=src)
+  elif isinstance(src, str):
+    if src.startswith('gs://'):
+      return types.BatchJobSource(
+          format='jsonl',
+          gcs_uri=[src],
+      )
+    elif src.startswith('bq://'):
+      return types.BatchJobSource(
+          format='bigquery',
+          bigquery_uri=src,
+      )
+    elif src.startswith('files/'):
+      return types.BatchJobSource(
+          file_name=src,
+      )
+
+  raise ValueError(f'Unsupported source: {src}')


 def t_batch_job_destination(dest: str) -> types.BatchJobDestination:
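t_batch_job_source now takes the client and accepts richer source shapes: a BatchJobSource (or dict), a list of inlined requests, a gs:// or bq:// URI, or a files/ name, with the allowed fields validated differently for Vertex AI and the Gemini Developer API. A hedged sketch of the source shapes a caller such as client.batches.create might now pass as src (the model name, file name, and request payload are illustrative; which fields are accepted depends on whether the client targets Vertex AI):

from google import genai
from google.genai import types

client = genai.Client()  # Gemini Developer API client; assumes an API key is configured

# 1) A previously uploaded file, referenced by name (Gemini Developer API).
job = client.batches.create(model='gemini-2.0-flash', src='files/my-batch-input')

# 2) Inlined requests as a list (wrapped into BatchJobSource(inlined_requests=...)).
job = client.batches.create(
    model='gemini-2.0-flash',
    src=[{'contents': [{'role': 'user', 'parts': [{'text': 'Hello'}]}]}],
)

# 3) An explicit BatchJobSource; on Vertex AI exactly one of gcs_uri or bigquery_uri
#    must be set, on the Developer API exactly one of inlined_requests or file_name.
src = types.BatchJobSource(file_name='files/my-batch-input')
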
@@ -976,10 +1021,15 @@ def t_batch_job_destination(dest: str) -> types.BatchJobDestination:

 def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
   if not client.vertexai:
-    return name
+    mldev_pattern = r'batches/[^/]+$'
+    if re.match(mldev_pattern, name):
+      return name.split('/')[-1]
+    else:
+      raise ValueError(f'Invalid batch job name: {name}.')
+
+  vertex_pattern = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'

-  pattern = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'
-  if re.match(pattern, name):
+  if re.match(vertex_pattern, name):
     return name.split('/')[-1]
   elif name.isdigit():
     return name
@@ -987,6 +1037,21 @@ def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
     raise ValueError(f'Invalid batch job name: {name}.')


+def t_job_state(state: str) -> str:
+  if state == 'BATCH_STATE_UNSPECIFIED':
+    return 'JOB_STATE_UNSPECIFIED'
+  elif state == 'BATCH_STATE_PENDING':
+    return 'JOB_STATE_PENDING'
+  elif state == 'BATCH_STATE_SUCCEEDED':
+    return 'JOB_STATE_SUCCEEDED'
+  elif state == 'BATCH_STATE_FAILED':
+    return 'JOB_STATE_FAILED'
+  elif state == 'BATCH_STATE_CANCELLED':
+    return 'JOB_STATE_CANCELLED'
+  else:
+    return state
+
+
 LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
 LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
 LRO_POLLING_TIMEOUT_SECONDS = 900.0
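The new t_job_state helper normalizes Gemini Developer API batch states to the Vertex-style job state names used elsewhere in the SDK, passing any unrecognized value through unchanged. A short sketch against the mapping above (assuming these transformers live in google.genai._transformers, as in earlier releases):

from google.genai import _transformers

assert _transformers.t_job_state('BATCH_STATE_PENDING') == 'JOB_STATE_PENDING'
assert _transformers.t_job_state('JOB_STATE_RUNNING') == 'JOB_STATE_RUNNING'  # unmapped values pass through
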