google-genai 1.21.1__py3-none-any.whl → 1.23.0__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
@@ -60,6 +60,11 @@ from .types import HttpOptionsOrDict
  from .types import HttpResponse as SdkHttpResponse
  from .types import HttpRetryOptions

+ try:
+   from websockets.asyncio.client import connect as ws_connect
+ except ModuleNotFoundError:
+   # This try/except is for TAP; mypy complains about it, which is why we have the type: ignore.
+   from websockets.client import connect as ws_connect  # type: ignore

  has_aiohttp = False
  try:
@@ -227,11 +232,13 @@ class HttpResponse:
      headers: Union[dict[str, str], httpx.Headers, 'CIMultiDictProxy[str]'],
      response_stream: Union[Any, str] = None,
      byte_stream: Union[Any, bytes] = None,
+     session: Optional['aiohttp.ClientSession'] = None,
  ):
    self.status_code: int = 200
    self.headers = headers
    self.response_stream = response_stream
    self.byte_stream = byte_stream
+   self._session = session

  # Async iterator for async streaming.
  def __aiter__(self) -> 'HttpResponse':
@@ -291,16 +298,23 @@ class HttpResponse:
          chunk = chunk[len('data: ') :]
        yield json.loads(chunk)
    elif hasattr(self.response_stream, 'content'):
-     async for chunk in self.response_stream.content.iter_any():
-       # This is aiohttp.ClientResponse.
-       if chunk:
+     # This is aiohttp.ClientResponse.
+     try:
+       while True:
+         chunk = await self.response_stream.content.readline()
+         if not chunk:
+           break
          # In async streaming mode, the chunk of JSON is prefixed with
          # "data:" which we must strip before parsing.
-         if not isinstance(chunk, str):
-           chunk = chunk.decode('utf-8')
+         chunk = chunk.decode('utf-8')
          if chunk.startswith('data: '):
            chunk = chunk[len('data: ') :]
-         yield json.loads(chunk)
+         chunk = chunk.strip()
+         if chunk:
+           yield json.loads(chunk)
+     finally:
+       if hasattr(self, '_session') and self._session:
+         await self._session.close()
    else:
      raise ValueError('Error parsing streaming response.')

@@ -538,6 +552,7 @@ class BaseApiClient:
    # Default options for both clients.
    self._http_options.headers = {'Content-Type': 'application/json'}
    if self.api_key:
+     self.api_key = self.api_key.strip()
      if self._http_options.headers is not None:
        self._http_options.headers['x-goog-api-key'] = self.api_key
    # Update the http options with the user provided http options.
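
Note: the added `strip()` means an API key read from a file or environment variable with stray whitespace no longer produces a malformed `x-goog-api-key` header. A minimal sketch of the effect (the key value is made up):

    api_key = ' AIzaExampleKey\n'      # hypothetical key with stray whitespace
    headers = {'x-goog-api-key': api_key.strip()}
    assert headers['x-goog-api-key'] == 'AIzaExampleKey'
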
@@ -558,7 +573,10 @@ class BaseApiClient:
      # Do it once at the genai.Client level. Share among all requests.
      self._async_client_session_request_args = self._ensure_aiohttp_ssl_ctx(
          self._http_options
-     )
+     )
+     self._websocket_ssl_ctx = self._ensure_websocket_ssl_ctx(
+         self._http_options
+     )

    retry_kwargs = _retry_args(self._http_options.retry_options)
    self._retry = tenacity.Retrying(**retry_kwargs, reraise=True)
@@ -688,6 +706,63 @@ class BaseApiClient:

      return _maybe_set(async_args, ctx)

+
+   @staticmethod
+   def _ensure_websocket_ssl_ctx(options: HttpOptions) -> dict[str, Any]:
+     """Ensures an SSL context is present in the websocket connect args.
+
+     Creates a default SSL context if one is not provided.
+
+     Args:
+       options: The http options to check for SSL context.
+
+     Returns:
+       The websocket connect args with an SSL context included.
+     """
+
+     verify = 'ssl'  # keep it consistent with httpx.
+     async_args = options.async_client_args
+     ctx = async_args.get(verify) if async_args else None
+
+     if not ctx:
+       # Initialize a default SSL context.
+       # Unlike requests, the aiohttp package does not automatically pull in the
+       # environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be
+       # enabled explicitly. Instead of 'verify' at client level in httpx,
+       # aiohttp uses 'ssl' at request level.
+       ctx = ssl.create_default_context(
+           cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
+           capath=os.environ.get('SSL_CERT_DIR'),
+       )
+
+     def _maybe_set(
+         args: Optional[dict[str, Any]],
+         ctx: ssl.SSLContext,
+     ) -> dict[str, Any]:
+       """Sets the SSL context in the client args if not set.
+
+       Does not override the SSL context if it is already set.
+
+       Args:
+         args: The client args to check for SSL context.
+         ctx: The SSL context to set.
+
+       Returns:
+         The client args with the SSL context included.
+       """
+       if not args or not args.get(verify):
+         args = (args or {}).copy()
+         args[verify] = ctx
+       # Drop args that aren't accepted by the websockets connect() signature.
+       copied_args = args.copy()
+       for key in copied_args.copy():
+         if key not in inspect.signature(ws_connect).parameters and key != 'ssl':
+           del copied_args[key]
+       return copied_args
+
+     return _maybe_set(async_args, ctx)
+
+
    def _websocket_base_url(self) -> str:
      url_parts = urlparse(self._http_options.base_url)
      return url_parts._replace(scheme='wss').geturl()  # type: ignore[arg-type, return-value]
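
Note: a standalone sketch of the default-SSL-context pattern used by `_ensure_websocket_ssl_ctx` above, outside the SDK. It honors `SSL_CERT_FILE`/`SSL_CERT_DIR`, falls back to certifi's CA bundle, and keeps only keyword arguments that `websockets`' `connect()` accepts (illustrative code; the `args` dict is made up):

    import inspect
    import os
    import ssl

    import certifi
    from websockets.asyncio.client import connect as ws_connect

    ctx = ssl.create_default_context(
        cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
        capath=os.environ.get('SSL_CERT_DIR'),
    )
    args = {'ssl': ctx, 'timeout': 30}  # hypothetical client args
    allowed = set(inspect.signature(ws_connect).parameters)
    ws_kwargs = {k: v for k, v in args.items() if k in allowed or k == 'ssl'}
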
@@ -882,6 +957,7 @@ class BaseApiClient:
      self, http_request: HttpRequest, stream: bool = False
  ) -> HttpResponse:
    data: Optional[Union[str, bytes]] = None
+
    if self.vertexai and not self.api_key:
      http_request.headers['Authorization'] = (
          f'Bearer {await self._async_access_token()}'
@@ -912,8 +988,9 @@ class BaseApiClient:
          timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
          **self._async_client_session_request_args,
      )
+
      await errors.APIError.raise_for_async_response(response)
-     return HttpResponse(response.headers, response)
+     return HttpResponse(response.headers, response, session=session)
    else:
      # aiohttp is not available. Fall back to httpx.
      httpx_request = self._async_httpx_client.build_request(
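
Note: in the streaming branch the `aiohttp.ClientSession` created for the request is handed to `HttpResponse`, and the `finally:` block added to the segment iterator (earlier hunk) closes it once the stream is exhausted. A minimal sketch of that ownership pattern outside the SDK (URL and names are illustrative):

    import json
    import aiohttp

    async def stream_segments(url: str):
      # The session deliberately outlives this call; the iterator that consumes
      # the response is responsible for closing it, mirroring HttpResponse._session.
      session = aiohttp.ClientSession()
      response = await session.get(url)
      try:
        while True:
          line = await response.content.readline()
          if not line:
            break
          text = line.decode('utf-8')
          if text.startswith('data: '):
            text = text[len('data: '):]
          text = text.strip()
          if text:
            yield json.loads(text)
      finally:
        await session.close()
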
@@ -996,14 +1073,14 @@ class BaseApiClient:
      path: str,
      request_dict: dict[str, object],
      http_options: Optional[HttpOptionsOrDict] = None,
- ) -> Generator[Any, None, None]:
+ ) -> Generator[SdkHttpResponse, None, None]:
    http_request = self._build_request(
        http_method, path, request_dict, http_options
    )

    session_response = self._request(http_request, stream=True)
    for chunk in session_response.segments():
-     yield chunk
+     yield SdkHttpResponse(headers=session_response.headers, body=json.dumps(chunk))

  async def async_request(
      self,
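
Note: with this return-type change, the synchronous stream here (and the async stream in the next hunk) yields `SdkHttpResponse` objects whose `body` is a JSON string, instead of raw parsed dicts. A hedged sketch of the caller side; the method name `request_streamed`, the path, and the request payload are assumptions for illustration, not taken from this diff:

    import json

    # `api_client` is assumed to be a BaseApiClient instance.
    for response in api_client.request_streamed(
        'post',
        'models/some-model:streamGenerateContent?alt=sse',
        {'contents': [{'parts': [{'text': 'hi'}]}]},
    ):
      chunk = json.loads(response.body)  # each item is one parsed SSE payload
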
@@ -1038,7 +1115,7 @@ class BaseApiClient:

    async def async_generator():  # type: ignore[no-untyped-def]
      async for chunk in response:
-       yield chunk
+       yield SdkHttpResponse(headers=response.headers, body=json.dumps(chunk))

    return async_generator()  # type: ignore[no-untyped-call]

google/genai/_common.py CHANGED
@@ -16,12 +16,13 @@
  """Common utilities for the SDK."""

  import base64
+ import collections.abc
  import datetime
  import enum
  import functools
  import logging
  import typing
- from typing import Any, Callable, Optional, Union, get_origin, get_args
+ from typing import Any, Callable, Optional, FrozenSet, Union, get_args, get_origin
  import uuid
  import warnings

@@ -233,6 +234,179 @@ def _remove_extra_fields(
  T = typing.TypeVar('T', bound='BaseModel')


+ def _pretty_repr(
+     obj: Any,
+     *,
+     indent_level: int = 0,
+     indent_delta: int = 2,
+     max_len: int = 100,
+     max_items: int = 5,
+     depth: int = 6,
+     visited: Optional[FrozenSet[int]] = None,
+ ) -> str:
+   """Returns a pretty-printed representation of the given object."""
+   if visited is None:
+     visited = frozenset()
+
+   obj_id = id(obj)
+   if obj_id in visited:
+     return '<... Circular reference ...>'
+
+   if depth < 0:
+     return '<... Max depth ...>'
+
+   visited = frozenset(list(visited) + [obj_id])
+
+   indent = ' ' * indent_level
+   next_indent_str = ' ' * (indent_level + indent_delta)
+
+   if isinstance(obj, pydantic.BaseModel):
+     cls_name = obj.__class__.__name__
+     items = []
+     # Sort fields for consistent output
+     fields = sorted(type(obj).model_fields)
+
+     for field_name in fields:
+       field_info = type(obj).model_fields[field_name]
+       if not field_info.repr:  # Respect Field(repr=False)
+         continue
+
+       try:
+         value = getattr(obj, field_name)
+       except AttributeError:
+         continue
+
+       if value is None:
+         continue
+
+       value_repr = _pretty_repr(
+           value,
+           indent_level=indent_level + indent_delta,
+           indent_delta=indent_delta,
+           max_len=max_len,
+           max_items=max_items,
+           depth=depth - 1,
+           visited=visited,
+       )
+       items.append(f'{next_indent_str}{field_name}={value_repr}')
+
+     if not items:
+       return f'{cls_name}()'
+     return f'{cls_name}(\n' + ',\n'.join(items) + f'\n{indent})'
+   elif isinstance(obj, str):
+     if '\n' in obj:
+       escaped = obj.replace('"""', '\\"\\"\\"')
+       # Render multi-line strings as a triple-quoted block
+       return f'"""{escaped}"""'
+     return repr(obj)
+   elif isinstance(obj, bytes):
+     if len(obj) > max_len:
+       return f"{repr(obj[:max_len-3])[:-1]}...'"
+     return repr(obj)
+   elif isinstance(obj, collections.abc.Mapping):
+     if not obj:
+       return '{}'
+     if len(obj) > max_items:
+       return f'<dict len={len(obj)}>'
+     items = []
+     try:
+       sorted_keys = sorted(obj.keys(), key=str)
+     except TypeError:
+       sorted_keys = list(obj.keys())
+
+     for k in sorted_keys:
+       v = obj[k]
+       k_repr = _pretty_repr(
+           k,
+           indent_level=indent_level + indent_delta,
+           indent_delta=indent_delta,
+           max_len=max_len,
+           max_items=max_items,
+           depth=depth - 1,
+           visited=visited,
+       )
+       v_repr = _pretty_repr(
+           v,
+           indent_level=indent_level + indent_delta,
+           indent_delta=indent_delta,
+           max_len=max_len,
+           max_items=max_items,
+           depth=depth - 1,
+           visited=visited,
+       )
+       items.append(f'{next_indent_str}{k_repr}: {v_repr}')
+     return f'{{\n' + ',\n'.join(items) + f'\n{indent}}}'
+   elif isinstance(obj, (list, tuple, set)):
+     return _format_collection(
+         obj,
+         indent_level=indent_level,
+         indent_delta=indent_delta,
+         max_len=max_len,
+         max_items=max_items,
+         depth=depth,
+         visited=visited,
+     )
+   else:
+     # Fallback to standard repr, indenting subsequent lines only
+     raw_repr = repr(obj)
+     # Replace newlines with newline + indent
+     return raw_repr.replace('\n', f'\n{next_indent_str}')
+
+
+ def _format_collection(
+     obj: Any,
+     *,
+     indent_level: int,
+     indent_delta: int,
+     max_len: int,
+     max_items: int,
+     depth: int,
+     visited: FrozenSet[int],
+ ) -> str:
+   """Formats a collection (list, tuple, set)."""
+   if isinstance(obj, list):
+     brackets = ('[', ']')
+   elif isinstance(obj, tuple):
+     brackets = ('(', ')')
+   elif isinstance(obj, set):
+     obj = list(obj)
+     if obj:
+       brackets = ('{', '}')
+     else:
+       brackets = ('set(', ')')
+   else:
+     raise ValueError(f"Unsupported collection type: {type(obj)}")
+
+   if not obj:
+     return brackets[0] + brackets[1]
+
+   indent = ' ' * indent_level
+   next_indent_str = ' ' * (indent_level + indent_delta)
+   elements = []
+   for i, elem in enumerate(obj):
+     if i >= max_items:
+       elements.append(
+           f'{next_indent_str}<... {len(obj) - max_items} more items ...>'
+       )
+       break
+     # Each element starts on a new line, fully indented
+     elements.append(
+         next_indent_str
+         + _pretty_repr(
+             elem,
+             indent_level=indent_level + indent_delta,
+             indent_delta=indent_delta,
+             max_len=max_len,
+             max_items=max_items,
+             depth=depth - 1,
+             visited=visited,
+         )
+     )
+
+   return f'{brackets[0]}\n' + ',\n'.join(elements) + "," + f'\n{indent}{brackets[1]}'
+
+
  class BaseModel(pydantic.BaseModel):

    model_config = pydantic.ConfigDict(
@@ -248,6 +422,12 @@ class BaseModel(pydantic.BaseModel):
        ignored_types=(typing.TypeVar,)
    )

+   def __repr__(self) -> str:
+     try:
+       return _pretty_repr(self)
+     except Exception:
+       return super().__repr__()
+
    @classmethod
    def _from_response(
        cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
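
Note: the combined effect of `_pretty_repr` and the `__repr__` override is that SDK models print as an indented tree with `None` fields omitted, falling back to pydantic's default repr if anything goes wrong. An illustrative (not captured) example of the kind of output to expect:

    from google.genai import types

    print(repr(types.Part(text='hello')))
    # Roughly:
    # Part(
    #   text='hello'
    # )
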
@@ -25,6 +25,7 @@ import pydantic

  from . import _common
  from . import _mcp_utils
+ from . import _transformers as t
  from . import errors
  from . import types
  from ._adapters import McpToGenAiToolAdapter
@@ -62,11 +63,37 @@ def _create_generate_content_config_model(
    return config


+ def _get_gcs_uri(
+     src: Union[str, types.BatchJobSourceOrDict]
+ ) -> Optional[str]:
+   """Extracts the first GCS URI from the source, if available."""
+   if isinstance(src, str) and src.startswith('gs://'):
+     return src
+   elif isinstance(src, dict) and src.get('gcs_uri'):
+     return src['gcs_uri'][0] if src['gcs_uri'] else None
+   elif isinstance(src, types.BatchJobSource) and src.gcs_uri:
+     return src.gcs_uri[0] if src.gcs_uri else None
+   return None
+
+
+ def _get_bigquery_uri(
+     src: Union[str, types.BatchJobSourceOrDict]
+ ) -> Optional[str]:
+   """Extracts the BigQuery URI from the source, if available."""
+   if isinstance(src, str) and src.startswith('bq://'):
+     return src
+   elif isinstance(src, dict) and src.get('bigquery_uri'):
+     return src['bigquery_uri']
+   elif isinstance(src, types.BatchJobSource) and src.bigquery_uri:
+     return src.bigquery_uri
+   return None
+
+
  def format_destination(
-     src: str,
+     src: Union[str, types.BatchJobSourceOrDict],
      config: Optional[types.CreateBatchJobConfigOrDict] = None,
  ) -> types.CreateBatchJobConfig:
-   """Formats the destination uri based on the source uri."""
+   """Formats the destination uri based on the source uri for Vertex AI."""
    config = (
        types._CreateBatchJobParameters(config=config).config
        or types.CreateBatchJobConfig()
@@ -78,15 +105,14 @@ format_destination(
      config.display_name = f'genai_batch_job_{unique_name}'

    if not config.dest:
-     if src.startswith('gs://') and src.endswith('.jsonl'):
-       # If source uri is "gs://bucket/path/to/src.jsonl", then the destination
-       # uri prefix will be "gs://bucket/path/to/src/dest".
-       config.dest = f'{src[:-6]}/dest'
-     elif src.startswith('bq://'):
-       # If source uri is "bq://project.dataset.src", then the destination
-       # uri will be "bq://project.dataset.src_dest_TIMESTAMP_UUID".
+     gcs_source_uri = _get_gcs_uri(src)
+     bigquery_source_uri = _get_bigquery_uri(src)
+
+     if gcs_source_uri and gcs_source_uri.endswith('.jsonl'):
+       config.dest = f'{gcs_source_uri[:-6]}/dest'
+     elif bigquery_source_uri:
        unique_name = unique_name or _common.timestamped_unique_name()
-       config.dest = f'{src}_dest_{unique_name}'
+       config.dest = f'{bigquery_source_uri}_dest_{unique_name}'
      else:
        raise ValueError(f'Unsupported source: {src}')
    return config
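
Note: with the helpers above, `format_destination` derives the same default destinations whether the source arrives as a URI string, a `BatchJobSource`, or its dict form. The derivation rules themselves are unchanged and easy to check in isolation (bucket, dataset, and table names are made up):

    # GCS: 'gs://bucket/path/src.jsonl' -> dest prefix 'gs://bucket/path/src/dest'
    src = 'gs://bucket/path/src.jsonl'
    assert f'{src[:-6]}/dest' == 'gs://bucket/path/src/dest'

    # BigQuery: 'bq://proj.ds.tbl' -> 'bq://proj.ds.tbl_dest_<TIMESTAMP>_<UUID>'
    src = 'bq://proj.ds.tbl'
    unique_name = '20240101_abcdef'  # stands in for _common.timestamped_unique_name()
    assert f'{src}_dest_{unique_name}' == 'bq://proj.ds.tbl_dest_20240101_abcdef'
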
@@ -506,3 +532,15 @@ async def parse_config_for_mcp_sessions(
        parsed_config_copy.tools.append(tool)

    return parsed_config_copy, mcp_to_genai_tool_adapters
+
+
+ def append_chunk_contents(
+     contents: Union[types.ContentListUnion, types.ContentListUnionDict],
+     chunk: types.GenerateContentResponse,
+ ) -> None:
+   """Appends the contents of the chunk to the contents list."""
+   if chunk is not None and chunk.candidates is not None:
+     chunk_content = chunk.candidates[0].content
+     contents = t.t_contents(contents)  # type: ignore[assignment]
+     if isinstance(contents, list) and chunk_content is not None:
+       contents.append(chunk_content)  # type: ignore[arg-type]
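
Note: `append_chunk_contents` folds the first candidate of each streamed chunk back into a running `contents` list, which is what a caller needs in order to keep the model's streamed reply in the conversation history. A hedged usage sketch, assuming `client` is a configured `genai.Client`, the helper is importable from the module shown in this diff, and the model name is illustrative:

    from google.genai import types

    contents = [types.Content(role='user', parts=[types.Part(text='Hi there')])]
    for chunk in client.models.generate_content_stream(
        model='gemini-2.0-flash', contents=contents
    ):
      append_chunk_contents(contents, chunk)  # streamed reply accumulates in `contents`
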
@@ -26,7 +26,7 @@ import sys
  import time
  import types as builtin_types
  import typing
- from typing import Any, GenericAlias, Optional, Sequence, Union  # type: ignore[attr-defined]
+ from typing import Any, GenericAlias, List, Optional, Sequence, Union  # type: ignore[attr-defined]
  from ._mcp_utils import mcp_to_gemini_tool

  if typing.TYPE_CHECKING:
@@ -787,15 +787,25 @@ def process_schema(
  def _process_enum(
      enum: EnumMeta, client: _api_client.BaseApiClient
  ) -> types.Schema:
+   is_integer_enum = False
+
    for member in enum:  # type: ignore
-     if not isinstance(member.value, str):
+     if isinstance(member.value, int):
+       is_integer_enum = True
+     elif not isinstance(member.value, str):
        raise TypeError(
-           f'Enum member {member.name} value must be a string, got'
+           f'Enum member {member.name} value must be a string or integer, got'
            f' {type(member.value)}'
        )

+   enum_to_process = enum
+   if is_integer_enum:
+     str_members = [str(member.value) for member in enum]  # type: ignore
+     str_enum = Enum(enum.__name__, str_members, type=str)  # type: ignore
+     enum_to_process = str_enum
+
    class Placeholder(pydantic.BaseModel):
-     placeholder: enum  # type: ignore[valid-type]
+     placeholder: enum_to_process  # type: ignore[valid-type]

    enum_schema = Placeholder.model_json_schema()
    process_schema(enum_schema, client)
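
Note: with the change above, integer-valued enum members are converted to string stand-ins before the JSON schema is generated, so an `IntEnum` no longer raises `TypeError` during schema processing. A self-contained sketch of the two shapes that are now both accepted:

    import enum

    class Priority(enum.IntEnum):   # integer values: now allowed
      LOW = 1
      HIGH = 2

    class Color(enum.Enum):         # string values: allowed as before
      RED = 'red'
      BLUE = 'blue'

    # Per the hunk, Priority is processed as if it were
    # Enum('Priority', ['1', '2'], type=str), so its schema enumerates '1' and '2'.
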
@@ -944,19 +954,54 @@ def t_cached_content_name(client: _api_client.BaseApiClient, name: str) -> str:
    return _resource_name(client, name, collection_identifier='cachedContents')


- def t_batch_job_source(src: str) -> types.BatchJobSource:
-   if src.startswith('gs://'):
-     return types.BatchJobSource(
-         format='jsonl',
-         gcs_uri=[src],
-     )
-   elif src.startswith('bq://'):
-     return types.BatchJobSource(
-         format='bigquery',
-         bigquery_uri=src,
-     )
-   else:
-     raise ValueError(f'Unsupported source: {src}')
+ def t_batch_job_source(
+     client: _api_client.BaseApiClient,
+     src: Union[
+         str, List[types.InlinedRequestOrDict], types.BatchJobSourceOrDict
+     ],
+ ) -> types.BatchJobSource:
+   if isinstance(src, dict):
+     src = types.BatchJobSource(**src)
+   if isinstance(src, types.BatchJobSource):
+     if client.vertexai:
+       if src.gcs_uri and src.bigquery_uri:
+         raise ValueError(
+             'Only one of `gcs_uri` or `bigquery_uri` can be set.'
+         )
+       elif not src.gcs_uri and not src.bigquery_uri:
+         raise ValueError(
+             'One of `gcs_uri` or `bigquery_uri` must be set.'
+         )
+     else:
+       if src.inlined_requests and src.file_name:
+         raise ValueError(
+             'Only one of `inlined_requests` or `file_name` can be set.'
+         )
+       elif not src.inlined_requests and not src.file_name:
+         raise ValueError(
+             'One of `inlined_requests` or `file_name` must be set.'
+         )
+     return src
+
+   elif isinstance(src, list):
+     return types.BatchJobSource(inlined_requests=src)
+   elif isinstance(src, str):
+     if src.startswith('gs://'):
+       return types.BatchJobSource(
+           format='jsonl',
+           gcs_uri=[src],
+       )
+     elif src.startswith('bq://'):
+       return types.BatchJobSource(
+           format='bigquery',
+           bigquery_uri=src,
+       )
+     elif src.startswith('files/'):
+       return types.BatchJobSource(
+           file_name=src,
+       )
+
+   raise ValueError(f'Unsupported source: {src}')


  def t_batch_job_destination(dest: str) -> types.BatchJobDestination:
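
Note: the rewritten `t_batch_job_source` accepts four source shapes. A compact summary of what each maps to (identifiers and values below are illustrative; `client` stands for a `BaseApiClient`):

    t_batch_job_source(client, 'gs://bucket/input.jsonl')  # gcs_uri=['gs://bucket/input.jsonl'], format='jsonl'
    t_batch_job_source(client, 'bq://proj.ds.table')       # bigquery_uri='bq://proj.ds.table', format='bigquery'
    t_batch_job_source(client, 'files/abc123')             # file_name='files/abc123'
    t_batch_job_source(client, [{'contents': [{'parts': [{'text': 'hi'}]}]}])  # inlined_requests=[...]
    # BatchJobSource / dict sources are validated (exactly one of the mutually
    # exclusive fields must be set) and returned as-is.
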
@@ -976,10 +1021,15 @@ def t_batch_job_destination(dest: str) -> types.BatchJobDestination:

  def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
    if not client.vertexai:
-     return name
+     mldev_pattern = r'batches/[^/]+$'
+     if re.match(mldev_pattern, name):
+       return name.split('/')[-1]
+     else:
+       raise ValueError(f'Invalid batch job name: {name}.')
+
+   vertex_pattern = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'

-   pattern = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'
-   if re.match(pattern, name):
+   if re.match(vertex_pattern, name):
      return name.split('/')[-1]
    elif name.isdigit():
      return name
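
Note: batch job name normalization after this hunk, for both API surfaces (the IDs are illustrative):

    # Gemini Developer API client: only 'batches/<id>' is accepted; the bare id is returned.
    #   t_batch_job_name(mldev_client, 'batches/abc123')  -> 'abc123'
    # Vertex AI client: full resource names and bare numeric ids are accepted.
    #   t_batch_job_name(vertex_client, 'projects/p/locations/us-central1/batchPredictionJobs/123')  -> '123'
    #   t_batch_job_name(vertex_client, '123')  -> '123'
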
@@ -987,6 +1037,21 @@ def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
    raise ValueError(f'Invalid batch job name: {name}.')


+ def t_job_state(state: str) -> str:
+   if state == 'BATCH_STATE_UNSPECIFIED':
+     return 'JOB_STATE_UNSPECIFIED'
+   elif state == 'BATCH_STATE_PENDING':
+     return 'JOB_STATE_PENDING'
+   elif state == 'BATCH_STATE_SUCCEEDED':
+     return 'JOB_STATE_SUCCEEDED'
+   elif state == 'BATCH_STATE_FAILED':
+     return 'JOB_STATE_FAILED'
+   elif state == 'BATCH_STATE_CANCELLED':
+     return 'JOB_STATE_CANCELLED'
+   else:
+     return state
+
+
  LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
  LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
  LRO_POLLING_TIMEOUT_SECONDS = 900.0
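
Note: `t_job_state` appears to normalize the Gemini Developer API's `BATCH_STATE_*` values onto the `JOB_STATE_*` names used elsewhere in the SDK, passing unknown values through unchanged. Assuming the function is imported from its module (the file isn't named in this diff):

    assert t_job_state('BATCH_STATE_PENDING') == 'JOB_STATE_PENDING'
    assert t_job_state('BATCH_STATE_SUCCEEDED') == 'JOB_STATE_SUCCEEDED'
    assert t_job_state('JOB_STATE_RUNNING') == 'JOB_STATE_RUNNING'  # passthrough
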