google_genai-1.3.0-py3-none-any.whl → google_genai-1.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +143 -69
- google/genai/_api_module.py +1 -1
- google/genai/_automatic_function_calling_util.py +15 -15
- google/genai/_common.py +6 -3
- google/genai/_extra_utils.py +62 -46
- google/genai/_replay_api_client.py +73 -4
- google/genai/_test_api_client.py +8 -8
- google/genai/_transformers.py +194 -66
- google/genai/batches.py +180 -134
- google/genai/caches.py +316 -216
- google/genai/chats.py +179 -35
- google/genai/client.py +3 -3
- google/genai/errors.py +1 -2
- google/genai/files.py +175 -119
- google/genai/live.py +73 -64
- google/genai/models.py +898 -637
- google/genai/operations.py +96 -66
- google/genai/pagers.py +16 -7
- google/genai/tunings.py +172 -112
- google/genai/types.py +228 -178
- google/genai/version.py +1 -1
- {google_genai-1.3.0.dist-info → google_genai-1.5.0.dist-info}/METADATA +8 -1
- google_genai-1.5.0.dist-info/RECORD +27 -0
- {google_genai-1.3.0.dist-info → google_genai-1.5.0.dist-info}/WHEEL +1 -1
- google_genai-1.3.0.dist-info/RECORD +0 -27
- {google_genai-1.3.0.dist-info → google_genai-1.5.0.dist-info}/LICENSE +0 -0
- {google_genai-1.3.0.dist-info → google_genai-1.5.0.dist-info}/top_level.txt +0 -0
google/genai/_transformers.py
CHANGED
```diff
@@ -24,9 +24,12 @@ import logging
 import re
 import sys
 import time
+import types as builtin_types
 import typing
 from typing import Any, GenericAlias, Optional, Union
 
+import types as builtin_types
+
 if typing.TYPE_CHECKING:
   import PIL.Image
 
@@ -38,15 +41,15 @@ from . import types
 logger = logging.getLogger('google_genai._transformers')
 
 if sys.version_info >= (3, 10):
-  VersionedUnionType =
-  _UNION_TYPES = (typing.Union,
+  VersionedUnionType = builtin_types.UnionType
+  _UNION_TYPES = (typing.Union, builtin_types.UnionType)
 else:
   VersionedUnionType = typing._UnionGenericAlias
   _UNION_TYPES = (typing.Union,)
 
 
 def _resource_name(
-    client: _api_client.
+    client: _api_client.BaseApiClient,
     resource_name: str,
     *,
     collection_identifier: str,
@@ -138,7 +141,7 @@ def _resource_name(
   return resource_name
 
 
-def t_model(client: _api_client.ApiClient, model: str):
+def t_model(client: _api_client.BaseApiClient, model: str):
   if not model:
     raise ValueError('model is required.')
   if client.vertexai:
@@ -162,7 +165,7 @@ def t_model(client: _api_client.ApiClient, model: str):
     return f'models/{model}'
 
 
-def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
+def t_models_url(api_client: _api_client.BaseApiClient, base_models: bool) -> str:
   if api_client.vertexai:
     if base_models:
       return 'publishers/google/models'
@@ -176,8 +179,8 @@ def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
 
 
 def t_extract_models(
-    api_client: _api_client.
-) -> list[types.
+    api_client: _api_client.BaseApiClient, response: dict[str, list[types.ModelDict]]
+) -> Optional[list[types.ModelDict]]:
   if not response:
     return []
   elif response.get('models') is not None:
@@ -197,7 +200,7 @@ def t_extract_models(
     return []
 
 
-def t_caches_model(api_client: _api_client.ApiClient, model: str):
+def t_caches_model(api_client: _api_client.BaseApiClient, model: str):
   model = t_model(api_client, model)
   if not model:
     return None
@@ -213,6 +216,7 @@ def t_caches_model(api_client: _api_client.ApiClient, model: str):
 
 
 def pil_to_blob(img) -> types.Blob:
+  PngImagePlugin: Optional[builtin_types.ModuleType]
   try:
     import PIL.PngImagePlugin
 
@@ -236,10 +240,9 @@ def pil_to_blob(img) -> types.Blob:
   return types.Blob(mime_type=mime_type, data=data)
 
 
-
-
-
-def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
+def t_part(
+    part: Optional[types.PartUnionDict]
+) -> types.Part:
   try:
     import PIL.Image
 
@@ -247,7 +250,7 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
   except ImportError:
     PIL_Image = None
 
-  if
+  if part is None:
     raise ValueError('content part is required.')
   if isinstance(part, str):
     return types.Part(text=part)
@@ -257,25 +260,29 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
     if not part.uri or not part.mime_type:
       raise ValueError('file uri and mime_type are required.')
     return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)
-
+  if isinstance(part, dict):
+    return types.Part.model_validate(part)
+  if isinstance(part, types.Part):
     return part
+  raise ValueError(f'Unsupported content part type: {type(part)}')
 
 
 def t_parts(
-
+    parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict]],
 ) -> list[types.Part]:
-
+  #
+  if parts is None or (isinstance(parts, list) and not parts):
     raise ValueError('content parts are required.')
   if isinstance(parts, list):
-    return [t_part(
+    return [t_part(part) for part in parts]
   else:
-    return [t_part(
+    return [t_part(parts)]
 
 
 def t_image_predictions(
-    client: _api_client.
+    client: _api_client.BaseApiClient,
    predictions: Optional[Iterable[Mapping[str, Any]]],
-) -> list[types.GeneratedImage]:
+) -> Optional[list[types.GeneratedImage]]:
   if not predictions:
     return None
   images = []
@@ -292,24 +299,38 @@ def t_image_predictions(
   return images
 
 
-ContentType = Union[types.Content, types.ContentDict,
+ContentType = Union[types.Content, types.ContentDict, types.PartUnionDict]
 
 
 def t_content(
-    client: _api_client.
-    content: ContentType,
-):
-  if
+    client: _api_client.BaseApiClient,
+    content: Optional[ContentType],
+) -> types.Content:
+  if content is None:
     raise ValueError('content is required.')
   if isinstance(content, types.Content):
     return content
   if isinstance(content, dict):
-
-
+    try:
+      return types.Content.model_validate(content)
+    except pydantic.ValidationError:
+      possible_part = types.Part.model_validate(content)
+      return (
+          types.ModelContent(parts=[possible_part])
+          if possible_part.function_call
+          else types.UserContent(parts=[possible_part])
+      )
+  if isinstance(content, types.Part):
+    return (
+        types.ModelContent(parts=[content])
+        if content.function_call
+        else types.UserContent(parts=[content])
+    )
+  return types.UserContent(parts=content)
 
 
 def t_contents_for_embed(
-    client: _api_client.
+    client: _api_client.BaseApiClient,
     contents: Union[list[types.Content], list[types.ContentDict], ContentType],
 ):
   if client.vertexai and isinstance(contents, list):
@@ -324,16 +345,105 @@ def t_contents_for_embed(
 
 
 def t_contents(
-    client: _api_client.
-    contents:
-
-
+    client: _api_client.BaseApiClient,
+    contents: Optional[
+        Union[types.ContentListUnion, types.ContentListUnionDict]
+    ],
+) -> list[types.Content]:
+  if contents is None or (isinstance(contents, list) and not contents):
     raise ValueError('contents are required.')
-  if isinstance(contents, list):
-    return [t_content(client, content) for content in contents]
-  else:
+  if not isinstance(contents, list):
     return [t_content(client, contents)]
 
+  try:
+    import PIL.Image
+
+    PIL_Image = PIL.Image.Image
+  except ImportError:
+    PIL_Image = None
+
+  result: list[types.Content] = []
+  accumulated_parts: list[types.Part] = []
+
+  def _is_part(part: types.PartUnionDict) -> bool:
+    if (
+        isinstance(part, str)
+        or isinstance(part, types.File)
+        or (PIL_Image is not None and isinstance(part, PIL_Image))
+        or isinstance(part, types.Part)
+    ):
+      return True
+
+    if isinstance(part, dict):
+      try:
+        types.Part.model_validate(part)
+        return True
+      except pydantic.ValidationError:
+        return False
+
+    return False
+
+  def _is_user_part(part: types.Part) -> bool:
+    return not part.function_call
+
+  def _are_user_parts(parts: list[types.Part]) -> bool:
+    return all(_is_user_part(part) for part in parts)
+
+  def _append_accumulated_parts_as_content(
+      result: list[types.Content],
+      accumulated_parts: list[types.Part],
+  ):
+    if not accumulated_parts:
+      return
+    result.append(
+        types.UserContent(parts=accumulated_parts)
+        if _are_user_parts(accumulated_parts)
+        else types.ModelContent(parts=accumulated_parts)
+    )
+    accumulated_parts[:] = []
+
+  def _handle_current_part(
+      result: list[types.Content],
+      accumulated_parts: list[types.Part],
+      current_part: types.PartUnionDict,
+  ):
+    current_part = t_part(current_part)
+    if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
+      accumulated_parts.append(current_part)
+    else:
+      _append_accumulated_parts_as_content(result, accumulated_parts)
+      accumulated_parts[:] = [current_part]
+
+  # iterator over contents
+  # if content type or content dict, append to result
+  # if consecutive part(s),
+  # group consecutive user part(s) to a UserContent
+  # group consecutive model part(s) to a ModelContent
+  # append to result
+  # if list, we only accept a list of types.PartUnion
+  for content in contents:
+    if (
+        isinstance(content, types.Content)
+        # only allowed inner list is a list of types.PartUnion
+        or isinstance(content, list)
+    ):
+      _append_accumulated_parts_as_content(result, accumulated_parts)
+      if isinstance(content, list):
+        result.append(types.UserContent(parts=content))
+      else:
+        result.append(content)
+    elif (_is_part(content)):  # type: ignore
+      _handle_current_part(result, accumulated_parts, content)  # type: ignore
+    elif isinstance(content, dict):
+      # PactDict is already handled in _is_part
+      result.append(types.Content.model_validate(content))
+    else:
+      raise ValueError(f'Unsupported content type: {type(content)}')
+
+  _append_accumulated_parts_as_content(result, accumulated_parts)
+
+  return result
+
 
 def handle_null_fields(schema: dict[str, Any]):
   """Process null fields in the schema so it is compatible with OpenAPI.
@@ -396,7 +506,7 @@ def handle_null_fields(schema: dict[str, Any]):
 
 def process_schema(
     schema: dict[str, Any],
-    client:
+    client: _api_client.BaseApiClient,
     defs: Optional[dict[str, Any]] = None,
     *,
     order_properties: bool = True,
@@ -459,7 +569,7 @@ def process_schema(
       'type': 'array'
    }
  """
-  if
+  if not client.vertexai:
    schema.pop('title', None)
 
  if schema.get('default') is not None:
@@ -485,15 +595,16 @@ def process_schema(
  # After removing null fields, Optional fields with only one possible type
  # will have a $ref key that needs to be flattened
  # For example: {'default': None, 'description': 'Name of the person', 'nullable': True, '$ref': '#/$defs/TestPerson'}
-
-
+  schema_ref = schema.get('$ref', None)
+  if schema_ref is not None:
+    ref = defs[schema_ref.split('defs/')[-1]]
    for schema_key in list(ref.keys()):
      schema[schema_key] = ref[schema_key]
    del schema['$ref']
 
  any_of = schema.get('anyOf', None)
  if any_of is not None:
-    if not client.vertexai:
+    if client and not client.vertexai:
      raise ValueError(
          'AnyOf is not supported in the response schema for the Gemini API.'
      )
@@ -559,9 +670,9 @@ def process_schema(
 
 
 def _process_enum(
-    enum: EnumMeta, client: Optional[_api_client.
+    enum: EnumMeta, client: Optional[_api_client.BaseApiClient] = None  # type: ignore
 ) -> types.Schema:
-  for member in enum:
+  for member in enum:  # type: ignore
   if not isinstance(member.value, str):
      raise TypeError(
          f'Enum member {member.name} value must be a string, got'
@@ -578,7 +689,7 @@ def _process_enum(
 
 
 def t_schema(
-    client: _api_client.
+    client: _api_client.BaseApiClient, origin: Union[types.SchemaUnionDict, Any]
 ) -> Optional[types.Schema]:
  if not origin:
    return None
@@ -624,7 +735,7 @@ def t_schema(
 
 
 def t_speech_config(
-    _: _api_client.
+    _: _api_client.BaseApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
 ) -> Optional[types.SpeechConfig]:
  if not origin:
    return None
@@ -639,7 +750,10 @@ def t_speech_config(
  if (
      isinstance(origin, dict)
      and 'voice_config' in origin
+      and origin['voice_config'] is not None
      and 'prebuilt_voice_config' in origin['voice_config']
+      and origin['voice_config']['prebuilt_voice_config'] is not None
+      and 'voice_name' in origin['voice_config']['prebuilt_voice_config']
  ):
    return types.SpeechConfig(
        voice_config=types.VoiceConfig(
@@ -653,7 +767,7 @@ def t_speech_config(
  raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
 
 
-def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
+def t_tool(client: _api_client.BaseApiClient, origin) -> Optional[types.Tool]:
  if not origin:
    return None
  if inspect.isfunction(origin) or inspect.ismethod(origin):
@@ -670,7 +784,7 @@ def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
 
 # Only support functions now.
 def t_tools(
-    client: _api_client.
+    client: _api_client.BaseApiClient, origin: list[Any]
 ) -> list[types.Tool]:
  if not origin:
    return []
@@ -679,22 +793,23 @@ def t_tools(
  for tool in origin:
    transformed_tool = t_tool(client, tool)
    # All functions should be merged into one tool.
-    if transformed_tool
-
-
-
-
-
+    if transformed_tool is not None:
+      if transformed_tool.function_declarations:
+        function_tool.function_declarations += (
+            transformed_tool.function_declarations
+        )
+      else:
+        tools.append(transformed_tool)
  if function_tool.function_declarations:
    tools.append(function_tool)
  return tools
 
 
-def t_cached_content_name(client: _api_client.
+def t_cached_content_name(client: _api_client.BaseApiClient, name: str):
  return _resource_name(client, name, collection_identifier='cachedContents')
 
 
-def t_batch_job_source(client: _api_client.ApiClient, src: str):
+def t_batch_job_source(client: _api_client.BaseApiClient, src: str):
  if src.startswith('gs://'):
    return types.BatchJobSource(
        format='jsonl',
@@ -709,7 +824,7 @@ def t_batch_job_source(client: _api_client.ApiClient, src: str):
  raise ValueError(f'Unsupported source: {src}')
 
 
-def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
+def t_batch_job_destination(client: _api_client.BaseApiClient, dest: str):
  if dest.startswith('gs://'):
    return types.BatchJobDestination(
        format='jsonl',
@@ -724,7 +839,7 @@ def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
  raise ValueError(f'Unsupported destination: {dest}')
 
 
-def t_batch_job_name(client: _api_client.
+def t_batch_job_name(client: _api_client.BaseApiClient, name: str):
  if not client.vertexai:
    return name
 
@@ -743,7 +858,7 @@ LRO_POLLING_TIMEOUT_SECONDS = 900.0
 LRO_POLLING_MULTIPLIER = 1.5
 
 
-def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
+def t_resolve_operation(api_client: _api_client.BaseApiClient, struct: dict):
  if (name := struct.get('name')) and '/operations/' in name:
    operation: dict[str, Any] = struct
    total_seconds = 0.0
@@ -752,7 +867,7 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
      if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
        raise RuntimeError(f'Operation {name} timed out.\n{operation}')
      # TODO(b/374433890): Replace with LRO module once it's available.
-      operation
+      operation = api_client.request(
          http_method='GET', path=name, request_dict={}
      )
      time.sleep(delay_seconds)
@@ -772,15 +887,25 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
 
 
 def t_file_name(
-    api_client: _api_client.
+    api_client: _api_client.BaseApiClient,
+    name: Optional[Union[str, types.File, types.Video, types.GeneratedVideo]],
 ):
  # Remove the files/ prefix since it's added to the url path.
  if isinstance(name, types.File):
    name = name.name
+  elif isinstance(name, types.Video):
+    name = name.uri
+  elif isinstance(name, types.GeneratedVideo):
+    name = name.video.uri
 
  if name is None:
    raise ValueError('File name is required.')
 
+  if not isinstance(name, str):
+    raise ValueError(
+        f'Could not convert object of type `{type(name)}` to a file name.'
+    )
+
  if name.startswith('https://'):
    suffix = name.split('files/')[1]
    match = re.match('[a-z0-9]+', suffix)
@@ -794,17 +919,20 @@ def t_file_name(
 
 
 def t_tuning_job_status(
-    api_client: _api_client.
-) -> types.JobState:
+    api_client: _api_client.BaseApiClient, status: str
+) -> Union[types.JobState, str]:
  if status == 'STATE_UNSPECIFIED':
-    return
+    return types.JobState.JOB_STATE_UNSPECIFIED
  elif status == 'CREATING':
-    return
+    return types.JobState.JOB_STATE_RUNNING
  elif status == 'ACTIVE':
-    return
+    return types.JobState.JOB_STATE_SUCCEEDED
  elif status == 'FAILED':
-    return
+    return types.JobState.JOB_STATE_FAILED
  else:
+    for state in types.JobState:
+      if str(state.value) == status:
+        return state
    return status
 
 
@@ -812,7 +940,7 @@ def t_tuning_job_status(
 # We shouldn't use this transformer if the backend adhere to Cloud Type
 # format https://cloud.google.com/docs/discovery/type-format.
 # TODO(b/389133914,b/390320301): Remove the hack after backend fix the issue.
-def t_bytes(api_client: _api_client.
+def t_bytes(api_client: _api_client.BaseApiClient, data: bytes) -> str:
  if not isinstance(data, bytes):
    return data
  return base64.b64encode(data).decode('ascii')
```