google-genai 1.21.1__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +88 -11
- google/genai/_common.py +181 -1
- google/genai/_extra_utils.py +48 -10
- google/genai/_transformers.py +85 -20
- google/genai/batches.py +4717 -155
- google/genai/caches.py +10 -0
- google/genai/files.py +8 -0
- google/genai/live.py +12 -11
- google/genai/models.py +106 -2
- google/genai/operations.py +4 -0
- google/genai/tunings.py +33 -1
- google/genai/types.py +347 -78
- google/genai/version.py +1 -1
- {google_genai-1.21.1.dist-info → google_genai-1.23.0.dist-info}/METADATA +51 -2
- {google_genai-1.21.1.dist-info → google_genai-1.23.0.dist-info}/RECORD +18 -18
- {google_genai-1.21.1.dist-info → google_genai-1.23.0.dist-info}/WHEEL +0 -0
- {google_genai-1.21.1.dist-info → google_genai-1.23.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.21.1.dist-info → google_genai-1.23.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -60,6 +60,11 @@ from .types import HttpOptionsOrDict
|
|
60
60
|
from .types import HttpResponse as SdkHttpResponse
|
61
61
|
from .types import HttpRetryOptions
|
62
62
|
|
63
|
+
try:
|
64
|
+
from websockets.asyncio.client import connect as ws_connect
|
65
|
+
except ModuleNotFoundError:
|
66
|
+
# This try/except is for TAP, mypy complains about it which is why we have the type: ignore
|
67
|
+
from websockets.client import connect as ws_connect # type: ignore
|
63
68
|
|
64
69
|
has_aiohttp = False
|
65
70
|
try:
|
@@ -227,11 +232,13 @@ class HttpResponse:
|
|
227
232
|
headers: Union[dict[str, str], httpx.Headers, 'CIMultiDictProxy[str]'],
|
228
233
|
response_stream: Union[Any, str] = None,
|
229
234
|
byte_stream: Union[Any, bytes] = None,
|
235
|
+
session: Optional['aiohttp.ClientSession'] = None,
|
230
236
|
):
|
231
237
|
self.status_code: int = 200
|
232
238
|
self.headers = headers
|
233
239
|
self.response_stream = response_stream
|
234
240
|
self.byte_stream = byte_stream
|
241
|
+
self._session = session
|
235
242
|
|
236
243
|
# Async iterator for async streaming.
|
237
244
|
def __aiter__(self) -> 'HttpResponse':
|
@@ -291,16 +298,23 @@ class HttpResponse:
|
|
291
298
|
chunk = chunk[len('data: ') :]
|
292
299
|
yield json.loads(chunk)
|
293
300
|
elif hasattr(self.response_stream, 'content'):
|
294
|
-
|
295
|
-
|
296
|
-
|
301
|
+
# This is aiohttp.ClientResponse.
|
302
|
+
try:
|
303
|
+
while True:
|
304
|
+
chunk = await self.response_stream.content.readline()
|
305
|
+
if not chunk:
|
306
|
+
break
|
297
307
|
# In async streaming mode, the chunk of JSON is prefixed with
|
298
308
|
# "data:" which we must strip before parsing.
|
299
|
-
|
300
|
-
chunk = chunk.decode('utf-8')
|
309
|
+
chunk = chunk.decode('utf-8')
|
301
310
|
if chunk.startswith('data: '):
|
302
311
|
chunk = chunk[len('data: ') :]
|
303
|
-
|
312
|
+
chunk = chunk.strip()
|
313
|
+
if chunk:
|
314
|
+
yield json.loads(chunk)
|
315
|
+
finally:
|
316
|
+
if hasattr(self, '_session') and self._session:
|
317
|
+
await self._session.close()
|
304
318
|
else:
|
305
319
|
raise ValueError('Error parsing streaming response.')
|
306
320
|
|
@@ -538,6 +552,7 @@ class BaseApiClient:
|
|
538
552
|
# Default options for both clients.
|
539
553
|
self._http_options.headers = {'Content-Type': 'application/json'}
|
540
554
|
if self.api_key:
|
555
|
+
self.api_key = self.api_key.strip()
|
541
556
|
if self._http_options.headers is not None:
|
542
557
|
self._http_options.headers['x-goog-api-key'] = self.api_key
|
543
558
|
# Update the http options with the user provided http options.
|
@@ -558,7 +573,10 @@ class BaseApiClient:
|
|
558
573
|
# Do it once at the genai.Client level. Share among all requests.
|
559
574
|
self._async_client_session_request_args = self._ensure_aiohttp_ssl_ctx(
|
560
575
|
self._http_options
|
561
|
-
)
|
576
|
+
)
|
577
|
+
self._websocket_ssl_ctx = self._ensure_websocket_ssl_ctx(
|
578
|
+
self._http_options
|
579
|
+
)
|
562
580
|
|
563
581
|
retry_kwargs = _retry_args(self._http_options.retry_options)
|
564
582
|
self._retry = tenacity.Retrying(**retry_kwargs, reraise=True)
|
@@ -688,6 +706,63 @@ class BaseApiClient:
|
|
688
706
|
|
689
707
|
return _maybe_set(async_args, ctx)
|
690
708
|
|
709
|
+
|
710
|
+
@staticmethod
def _ensure_websocket_ssl_ctx(options: HttpOptions) -> dict[str, Any]:
  """Builds the keyword arguments used when opening a websocket connection.

  Starts from `options.async_client_args`, guarantees that an `ssl` entry is
  present (creating a default SSL context when none is supplied) and drops
  every other key that `ws_connect` does not declare as a parameter.

  Args:
    options: The http options whose `async_client_args` are inspected.

  Returns:
    A new dict containing `ssl` plus the entries of `async_client_args`
    whose keys are parameters of `ws_connect`. The dict stored on `options`
    is not modified.
  """
  verify = 'ssl'  # Key under which the SSL context is stored.
  async_args = options.async_client_args
  ctx = async_args.get(verify) if async_args else None

  if not ctx:
    # No usable context was supplied: build a default one. SSL_CERT_FILE and
    # SSL_CERT_DIR are read explicitly; certifi's bundle is the fallback CA
    # file when SSL_CERT_FILE is unset.
    ctx = ssl.create_default_context(
        cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
        capath=os.environ.get('SSL_CERT_DIR'),
    )

  def _maybe_set(
      args: Optional[dict[str, Any]],
      ctx: ssl.SSLContext,
  ) -> dict[str, Any]:
    """Returns a filtered copy of `args` that always carries an SSL context.

    An SSL context that is already set (and truthy) is never overridden.

    Args:
      args: The client args to check for an SSL context; may be None.
      ctx: The SSL context to use when `args` does not provide one.

    Returns:
      A new dict with the SSL context included and unsupported keys removed.
    """
    merged = (args or {}).copy()
    if not merged.get(verify):
      merged[verify] = ctx
    # Looked up once instead of once per key.
    accepted = inspect.signature(ws_connect).parameters
    # `ssl` is always kept even if it is not a named parameter of ws_connect
    # (presumably it is forwarded through **kwargs -- confirm against the
    # installed websockets version).
    return {
        key: value
        for key, value in merged.items()
        if key == verify or key in accepted
    }

  return _maybe_set(async_args, ctx)
|
764
|
+
|
765
|
+
|
691
766
|
def _websocket_base_url(self) -> str:
|
692
767
|
url_parts = urlparse(self._http_options.base_url)
|
693
768
|
return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]
|
@@ -882,6 +957,7 @@ class BaseApiClient:
|
|
882
957
|
self, http_request: HttpRequest, stream: bool = False
|
883
958
|
) -> HttpResponse:
|
884
959
|
data: Optional[Union[str, bytes]] = None
|
960
|
+
|
885
961
|
if self.vertexai and not self.api_key:
|
886
962
|
http_request.headers['Authorization'] = (
|
887
963
|
f'Bearer {await self._async_access_token()}'
|
@@ -912,8 +988,9 @@ class BaseApiClient:
|
|
912
988
|
timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
|
913
989
|
**self._async_client_session_request_args,
|
914
990
|
)
|
991
|
+
|
915
992
|
await errors.APIError.raise_for_async_response(response)
|
916
|
-
return HttpResponse(response.headers, response)
|
993
|
+
return HttpResponse(response.headers, response, session=session)
|
917
994
|
else:
|
918
995
|
# aiohttp is not available. Fall back to httpx.
|
919
996
|
httpx_request = self._async_httpx_client.build_request(
|
@@ -996,14 +1073,14 @@ class BaseApiClient:
|
|
996
1073
|
path: str,
|
997
1074
|
request_dict: dict[str, object],
|
998
1075
|
http_options: Optional[HttpOptionsOrDict] = None,
|
999
|
-
) -> Generator[
|
1076
|
+
) -> Generator[SdkHttpResponse, None, None]:
|
1000
1077
|
http_request = self._build_request(
|
1001
1078
|
http_method, path, request_dict, http_options
|
1002
1079
|
)
|
1003
1080
|
|
1004
1081
|
session_response = self._request(http_request, stream=True)
|
1005
1082
|
for chunk in session_response.segments():
|
1006
|
-
yield chunk
|
1083
|
+
yield SdkHttpResponse(headers=session_response.headers, body=json.dumps(chunk))
|
1007
1084
|
|
1008
1085
|
async def async_request(
|
1009
1086
|
self,
|
@@ -1038,7 +1115,7 @@ class BaseApiClient:
|
|
1038
1115
|
|
1039
1116
|
async def async_generator(): # type: ignore[no-untyped-def]
|
1040
1117
|
async for chunk in response:
|
1041
|
-
yield chunk
|
1118
|
+
yield SdkHttpResponse(headers=response.headers, body=json.dumps(chunk))
|
1042
1119
|
|
1043
1120
|
return async_generator() # type: ignore[no-untyped-call]
|
1044
1121
|
|
google/genai/_common.py
CHANGED
@@ -16,12 +16,13 @@
|
|
16
16
|
"""Common utilities for the SDK."""
|
17
17
|
|
18
18
|
import base64
|
19
|
+
import collections.abc
|
19
20
|
import datetime
|
20
21
|
import enum
|
21
22
|
import functools
|
22
23
|
import logging
|
23
24
|
import typing
|
24
|
-
from typing import Any, Callable, Optional, Union,
|
25
|
+
from typing import Any, Callable, Optional, FrozenSet, Union, get_args, get_origin
|
25
26
|
import uuid
|
26
27
|
import warnings
|
27
28
|
|
@@ -233,6 +234,179 @@ def _remove_extra_fields(
|
|
233
234
|
T = typing.TypeVar('T', bound='BaseModel')
|
234
235
|
|
235
236
|
|
237
|
+
def _pretty_repr(
    obj: Any,
    *,
    indent_level: int = 0,
    indent_delta: int = 2,
    max_len: int = 100,
    max_items: int = 5,
    depth: int = 6,
    visited: Optional[FrozenSet[int]] = None,
) -> str:
  """Returns an indented, size-limited representation of the given object.

  Pydantic models, mappings and list/tuple/set values are rendered one entry
  per line and recursed into; everything else falls back to `repr()`.

  Args:
    obj: The object to render.
    indent_level: Number of spaces the closing bracket of `obj` is indented
      by; nested entries use `indent_level + indent_delta`.
    indent_delta: Number of spaces added per nesting level.
    max_len: Byte strings longer than this are cut and suffixed with `...`.
    max_items: Mappings with more entries are summarized as
      `<dict len=N>`; the limit is also forwarded to `_format_collection`.
    depth: Remaining nesting budget; a negative value yields
      `<... Max depth ...>`.
    visited: ids of the objects currently being rendered further up the
      call chain, used to detect reference cycles.

  Returns:
    The string representation.
  """
  if visited is None:
    visited = frozenset()

  obj_id = id(obj)
  if obj_id in visited:
    return '<... Circular reference ...>'

  if depth < 0:
    return '<... Max depth ...>'

  # A fresh frozenset per call: each recursive call only sees its own
  # ancestors, so the same object may legitimately appear in two siblings.
  visited = frozenset((*visited, obj_id))

  indent = ' ' * indent_level
  next_indent_str = ' ' * (indent_level + indent_delta)

  def _child_repr(child: Any) -> str:
    """Renders a nested value one level deeper."""
    return _pretty_repr(
        child,
        indent_level=indent_level + indent_delta,
        indent_delta=indent_delta,
        max_len=max_len,
        max_items=max_items,
        depth=depth - 1,
        visited=visited,
    )

  if isinstance(obj, pydantic.BaseModel):
    cls_name = obj.__class__.__name__
    model_fields = type(obj).model_fields
    items = []
    # Sort fields for consistent output.
    for field_name in sorted(model_fields):
      if not model_fields[field_name].repr:  # Respect Field(repr=False)
        continue

      try:
        value = getattr(obj, field_name)
      except AttributeError:
        continue

      # Unset (None) fields are omitted to keep the output short.
      if value is None:
        continue

      items.append(f'{next_indent_str}{field_name}={_child_repr(value)}')

    if not items:
      return f'{cls_name}()'
    return f'{cls_name}(\n' + ',\n'.join(items) + f'\n{indent})'
  elif isinstance(obj, str):
    if '\n' in obj:
      # Multi-line strings are shown as a triple-quoted block; embedded
      # triple quotes are escaped so the block stays well formed.
      escaped = obj.replace('"""', '\\"\\"\\"')
      return f'"""{escaped}"""'
    return repr(obj)
  elif isinstance(obj, bytes):
    if len(obj) > max_len:
      head = repr(obj[: max_len - 3])
      # repr() closes with a double quote when the bytes contain an
      # apostrophe, so reuse the actual closing quote instead of assuming "'".
      return f'{head[:-1]}...{head[-1]}'
    return repr(obj)
  elif isinstance(obj, collections.abc.Mapping):
    if not obj:
      return '{}'
    if len(obj) > max_items:
      return f'<dict len={len(obj)}>'
    try:
      sorted_keys = sorted(obj.keys(), key=str)
    except TypeError:
      sorted_keys = list(obj.keys())

    items = [
        f'{next_indent_str}{_child_repr(k)}: {_child_repr(obj[k])}'
        for k in sorted_keys
    ]
    return '{\n' + ',\n'.join(items) + f'\n{indent}}}'
  elif isinstance(obj, (list, tuple, set)):
    # The collection formatter indents and decrements depth per element
    # itself, so the current level and depth are passed through unchanged.
    return _format_collection(
        obj,
        indent_level=indent_level,
        indent_delta=indent_delta,
        max_len=max_len,
        max_items=max_items,
        depth=depth,
        visited=visited,
    )
  else:
    # Fallback to standard repr, indenting subsequent lines only.
    return repr(obj).replace('\n', f'\n{next_indent_str}')
|
354
|
+
|
355
|
+
|
356
|
+
|
357
|
+
def _format_collection(
|
358
|
+
obj: Any,
|
359
|
+
*,
|
360
|
+
indent_level: int,
|
361
|
+
indent_delta: int,
|
362
|
+
max_len: int,
|
363
|
+
max_items: int,
|
364
|
+
depth: int,
|
365
|
+
visited: FrozenSet[int],
|
366
|
+
) -> str:
|
367
|
+
"""Formats a collection (list, tuple, set)."""
|
368
|
+
if isinstance(obj, list):
|
369
|
+
brackets = ('[', ']')
|
370
|
+
elif isinstance(obj, tuple):
|
371
|
+
brackets = ('(', ')')
|
372
|
+
elif isinstance(obj, set):
|
373
|
+
obj = list(obj)
|
374
|
+
if obj:
|
375
|
+
brackets = ('{', '}')
|
376
|
+
else:
|
377
|
+
brackets = ('set(', ')')
|
378
|
+
else:
|
379
|
+
raise ValueError(f"Unsupported collection type: {type(obj)}")
|
380
|
+
|
381
|
+
if not obj:
|
382
|
+
return brackets[0] + brackets[1]
|
383
|
+
|
384
|
+
indent = ' ' * indent_level
|
385
|
+
next_indent_str = ' ' * (indent_level + indent_delta)
|
386
|
+
elements = []
|
387
|
+
for i, elem in enumerate(obj):
|
388
|
+
if i >= max_items:
|
389
|
+
elements.append(
|
390
|
+
f'{next_indent_str}<... {len(obj) - max_items} more items ...>'
|
391
|
+
)
|
392
|
+
break
|
393
|
+
# Each element starts on a new line, fully indented
|
394
|
+
elements.append(
|
395
|
+
next_indent_str
|
396
|
+
+ _pretty_repr(
|
397
|
+
elem,
|
398
|
+
indent_level=indent_level + indent_delta,
|
399
|
+
indent_delta=indent_delta,
|
400
|
+
max_len=max_len,
|
401
|
+
max_items=max_items,
|
402
|
+
depth=depth - 1,
|
403
|
+
visited=visited,
|
404
|
+
)
|
405
|
+
)
|
406
|
+
|
407
|
+
return f'{brackets[0]}\n' + ',\n'.join(elements) + "," + f'\n{indent}{brackets[1]}'
|
408
|
+
|
409
|
+
|
236
410
|
class BaseModel(pydantic.BaseModel):
|
237
411
|
|
238
412
|
model_config = pydantic.ConfigDict(
|
@@ -248,6 +422,12 @@ class BaseModel(pydantic.BaseModel):
|
|
248
422
|
ignored_types=(typing.TypeVar,)
|
249
423
|
)
|
250
424
|
|
425
|
+
def __repr__(self) -> str:
  """Returns the `_pretty_repr` rendering of this model.

  If `_pretty_repr` raises, the parent class's `__repr__` is used instead.
  """
  try:
    rendered = _pretty_repr(self)
  except Exception:
    # Deliberately broad: a failure in the pretty printer must not make
    # repr() itself fail.
    rendered = super().__repr__()
  return rendered
|
430
|
+
|
251
431
|
@classmethod
|
252
432
|
def _from_response(
|
253
433
|
cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
|
google/genai/_extra_utils.py
CHANGED
@@ -25,6 +25,7 @@ import pydantic
|
|
25
25
|
|
26
26
|
from . import _common
|
27
27
|
from . import _mcp_utils
|
28
|
+
from . import _transformers as t
|
28
29
|
from . import errors
|
29
30
|
from . import types
|
30
31
|
from ._adapters import McpToGenAiToolAdapter
|
@@ -62,11 +63,37 @@ def _create_generate_content_config_model(
|
|
62
63
|
return config
|
63
64
|
|
64
65
|
|
66
|
+
def _get_gcs_uri(
|
67
|
+
src: Union[str, types.BatchJobSourceOrDict]
|
68
|
+
) -> Optional[str]:
|
69
|
+
"""Extracts the first GCS URI from the source, if available."""
|
70
|
+
if isinstance(src, str) and src.startswith('gs://'):
|
71
|
+
return src
|
72
|
+
elif isinstance(src, dict) and src.get('gcs_uri'):
|
73
|
+
return src['gcs_uri'][0] if src['gcs_uri'] else None
|
74
|
+
elif isinstance(src, types.BatchJobSource) and src.gcs_uri:
|
75
|
+
return src.gcs_uri[0] if src.gcs_uri else None
|
76
|
+
return None
|
77
|
+
|
78
|
+
|
79
|
+
def _get_bigquery_uri(
|
80
|
+
src: Union[str, types.BatchJobSourceOrDict]
|
81
|
+
) -> Optional[str]:
|
82
|
+
"""Extracts the BigQuery URI from the source, if available."""
|
83
|
+
if isinstance(src, str) and src.startswith('bq://'):
|
84
|
+
return src
|
85
|
+
elif isinstance(src, dict) and src.get('bigquery_uri'):
|
86
|
+
return src['bigquery_uri']
|
87
|
+
elif isinstance(src, types.BatchJobSource) and src.bigquery_uri:
|
88
|
+
return src.bigquery_uri
|
89
|
+
return None
|
90
|
+
|
91
|
+
|
65
92
|
def format_destination(
|
66
|
-
src: str,
|
93
|
+
src: Union[str, types.BatchJobSourceOrDict],
|
67
94
|
config: Optional[types.CreateBatchJobConfigOrDict] = None,
|
68
95
|
) -> types.CreateBatchJobConfig:
|
69
|
-
"""Formats the destination uri based on the source uri."""
|
96
|
+
"""Formats the destination uri based on the source uri for Vertex AI."""
|
70
97
|
config = (
|
71
98
|
types._CreateBatchJobParameters(config=config).config
|
72
99
|
or types.CreateBatchJobConfig()
|
@@ -78,15 +105,14 @@ def format_destination(
|
|
78
105
|
config.display_name = f'genai_batch_job_{unique_name}'
|
79
106
|
|
80
107
|
if not config.dest:
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
# uri will be "bq://project.dataset.src_dest_TIMESTAMP_UUID".
|
108
|
+
gcs_source_uri = _get_gcs_uri(src)
|
109
|
+
bigquery_source_uri = _get_bigquery_uri(src)
|
110
|
+
|
111
|
+
if gcs_source_uri and gcs_source_uri.endswith('.jsonl'):
|
112
|
+
config.dest = f'{gcs_source_uri[:-6]}/dest'
|
113
|
+
elif bigquery_source_uri:
|
88
114
|
unique_name = unique_name or _common.timestamped_unique_name()
|
89
|
-
config.dest = f'{
|
115
|
+
config.dest = f'{bigquery_source_uri}_dest_{unique_name}'
|
90
116
|
else:
|
91
117
|
raise ValueError(f'Unsupported source: {src}')
|
92
118
|
return config
|
@@ -506,3 +532,15 @@ async def parse_config_for_mcp_sessions(
|
|
506
532
|
parsed_config_copy.tools.append(tool)
|
507
533
|
|
508
534
|
return parsed_config_copy, mcp_to_genai_tool_adapters
|
535
|
+
|
536
|
+
|
537
|
+
def append_chunk_contents(
    contents: Union[types.ContentListUnion, types.ContentListUnionDict],
    chunk: types.GenerateContentResponse,
) -> Union[types.ContentListUnion, types.ContentListUnionDict]:
  """Appends the first candidate's content of `chunk` to `contents`.

  `contents` is first normalized with `t.t_contents`; the chunk content is
  appended to that normalized value when it is a list.

  Args:
    contents: The conversation contents accumulated so far.
    chunk: The streamed response chunk. May be None, and may carry no
      candidates.

  Returns:
    The normalized contents with the chunk content appended, or `contents`
    unchanged when `chunk` is None or has no candidates. Callers should use
    the returned value: NOTE(review) `t.t_contents` may build a new list, in
    which case the object passed in is not the one that was appended to --
    confirm against `_transformers.t_contents`.
  """
  # `not chunk.candidates` also covers an empty list, for which indexing
  # `[0]` would raise IndexError.
  if chunk is None or not chunk.candidates:
    return contents

  chunk_content = chunk.candidates[0].content
  normalized = t.t_contents(contents)  # type: ignore[arg-type]
  if isinstance(normalized, list) and chunk_content is not None:
    normalized.append(chunk_content)  # type: ignore[arg-type]
  return normalized  # type: ignore[return-value]
|
google/genai/_transformers.py
CHANGED
@@ -26,7 +26,7 @@ import sys
|
|
26
26
|
import time
|
27
27
|
import types as builtin_types
|
28
28
|
import typing
|
29
|
-
from typing import Any, GenericAlias, Optional, Sequence, Union # type: ignore[attr-defined]
|
29
|
+
from typing import Any, GenericAlias, List, Optional, Sequence, Union # type: ignore[attr-defined]
|
30
30
|
from ._mcp_utils import mcp_to_gemini_tool
|
31
31
|
|
32
32
|
if typing.TYPE_CHECKING:
|
@@ -787,15 +787,25 @@ def process_schema(
|
|
787
787
|
def _process_enum(
|
788
788
|
enum: EnumMeta, client: _api_client.BaseApiClient
|
789
789
|
) -> types.Schema:
|
790
|
+
is_integer_enum = False
|
791
|
+
|
790
792
|
for member in enum: # type: ignore
|
791
|
-
if
|
793
|
+
if isinstance(member.value, int):
|
794
|
+
is_integer_enum = True
|
795
|
+
elif not isinstance(member.value, str):
|
792
796
|
raise TypeError(
|
793
|
-
f'Enum member {member.name} value must be a string, got'
|
797
|
+
f'Enum member {member.name} value must be a string or integer, got'
|
794
798
|
f' {type(member.value)}'
|
795
799
|
)
|
796
800
|
|
801
|
+
enum_to_process = enum
|
802
|
+
if is_integer_enum:
|
803
|
+
str_members = [str(member.value) for member in enum] # type: ignore
|
804
|
+
str_enum = Enum(enum.__name__, str_members, type=str) # type: ignore
|
805
|
+
enum_to_process = str_enum
|
806
|
+
|
797
807
|
class Placeholder(pydantic.BaseModel):
|
798
|
-
placeholder:
|
808
|
+
placeholder: enum_to_process # type: ignore[valid-type]
|
799
809
|
|
800
810
|
enum_schema = Placeholder.model_json_schema()
|
801
811
|
process_schema(enum_schema, client)
|
@@ -944,19 +954,54 @@ def t_cached_content_name(client: _api_client.BaseApiClient, name: str) -> str:
|
|
944
954
|
return _resource_name(client, name, collection_identifier='cachedContents')
|
945
955
|
|
946
956
|
|
947
|
-
def t_batch_job_source(
|
948
|
-
|
949
|
-
|
950
|
-
|
951
|
-
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
957
|
+
def t_batch_job_source(
    client: _api_client.BaseApiClient,
    src: Union[
        str, List[types.InlinedRequestOrDict], types.BatchJobSourceOrDict
    ],
) -> types.BatchJobSource:
  """Converts the accepted batch source forms into a `types.BatchJobSource`.

  Args:
    client: The API client; `client.vertexai` selects which pair of fields
      is validated on an explicit source.
    src: A `BatchJobSource` (or dict with its fields), a list of inlined
      requests, or a string starting with `gs://`, `bq://` or `files/`.

  Returns:
    The validated or newly built `types.BatchJobSource`.

  Raises:
    ValueError: If an explicit source sets both or neither of the two
      fields checked for the client type, or if `src` has an unsupported
      type or string prefix.
  """
  if isinstance(src, dict):
    src = types.BatchJobSource(**src)

  if isinstance(src, types.BatchJobSource):
    # Exactly one of the two fields relevant to the client type must be set.
    if client.vertexai:
      has_gcs = bool(src.gcs_uri)
      has_bigquery = bool(src.bigquery_uri)
      if has_gcs and has_bigquery:
        raise ValueError(
            'Only one of `gcs_uri` or `bigquery_uri` can be set.'
        )
      if not (has_gcs or has_bigquery):
        raise ValueError(
            'One of `gcs_uri` or `bigquery_uri` must be set.'
        )
    else:
      has_inlined = bool(src.inlined_requests)
      has_file = bool(src.file_name)
      if has_inlined and has_file:
        raise ValueError(
            'Only one of `inlined_requests` or `file_name` can be set.'
        )
      if not (has_inlined or has_file):
        raise ValueError(
            'One of `inlined_requests` or `file_name` must be set.'
        )
    return src

  if isinstance(src, list):
    return types.BatchJobSource(inlined_requests=src)

  if isinstance(src, str):
    # The string prefix decides which source field is populated.
    if src.startswith('gs://'):
      return types.BatchJobSource(
          format='jsonl',
          gcs_uri=[src],
      )
    if src.startswith('bq://'):
      return types.BatchJobSource(
          format='bigquery',
          bigquery_uri=src,
      )
    if src.startswith('files/'):
      return types.BatchJobSource(
          file_name=src,
      )

  raise ValueError(f'Unsupported source: {src}')
|
960
1005
|
|
961
1006
|
|
962
1007
|
def t_batch_job_destination(dest: str) -> types.BatchJobDestination:
|
@@ -976,10 +1021,15 @@ def t_batch_job_destination(dest: str) -> types.BatchJobDestination:
|
|
976
1021
|
|
977
1022
|
def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
|
978
1023
|
if not client.vertexai:
|
979
|
-
|
1024
|
+
mldev_pattern = r'batches/[^/]+$'
|
1025
|
+
if re.match(mldev_pattern, name):
|
1026
|
+
return name.split('/')[-1]
|
1027
|
+
else:
|
1028
|
+
raise ValueError(f'Invalid batch job name: {name}.')
|
1029
|
+
|
1030
|
+
vertex_pattern = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'
|
980
1031
|
|
981
|
-
|
982
|
-
if re.match(pattern, name):
|
1032
|
+
if re.match(vertex_pattern, name):
|
983
1033
|
return name.split('/')[-1]
|
984
1034
|
elif name.isdigit():
|
985
1035
|
return name
|
@@ -987,6 +1037,21 @@ def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
|
|
987
1037
|
raise ValueError(f'Invalid batch job name: {name}.')
|
988
1038
|
|
989
1039
|
|
1040
|
+
def t_job_state(state: str) -> str:
  """Translates a `BATCH_STATE_*` name into the matching `JOB_STATE_*` name.

  Args:
    state: The state name to translate.

  Returns:
    The `JOB_STATE_*` equivalent for a known `BATCH_STATE_*` value; any
    other value is returned unchanged.
  """
  # RUNNING and EXPIRED are included so that every batch state has a
  # job-state counterpart instead of leaking through untranslated.
  state_pairs = (
      ('BATCH_STATE_UNSPECIFIED', 'JOB_STATE_UNSPECIFIED'),
      ('BATCH_STATE_PENDING', 'JOB_STATE_PENDING'),
      ('BATCH_STATE_RUNNING', 'JOB_STATE_RUNNING'),
      ('BATCH_STATE_SUCCEEDED', 'JOB_STATE_SUCCEEDED'),
      ('BATCH_STATE_FAILED', 'JOB_STATE_FAILED'),
      ('BATCH_STATE_CANCELLED', 'JOB_STATE_CANCELLED'),
      ('BATCH_STATE_EXPIRED', 'JOB_STATE_EXPIRED'),
  )
  # Compared with `==` (not a dict lookup) so that any value comparing equal
  # to the state string is translated, and unhashable input is passed through.
  for batch_state, job_state in state_pairs:
    if state == batch_state:
      return job_state
  return state
|
1053
|
+
|
1054
|
+
|
990
1055
|
LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
|
991
1056
|
LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
|
992
1057
|
LRO_POLLING_TIMEOUT_SECONDS = 900.0
|