google-genai 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +160 -59
- google/genai/_api_module.py +6 -1
- google/genai/_automatic_function_calling_util.py +12 -12
- google/genai/_common.py +14 -2
- google/genai/_extra_utils.py +14 -8
- google/genai/_replay_api_client.py +35 -3
- google/genai/_test_api_client.py +8 -8
- google/genai/_transformers.py +169 -48
- google/genai/batches.py +176 -127
- google/genai/caches.py +315 -214
- google/genai/chats.py +179 -35
- google/genai/client.py +16 -6
- google/genai/errors.py +19 -5
- google/genai/files.py +161 -115
- google/genai/live.py +137 -105
- google/genai/models.py +1553 -734
- google/genai/operations.py +635 -0
- google/genai/pagers.py +5 -5
- google/genai/tunings.py +166 -103
- google/genai/types.py +590 -142
- google/genai/version.py +1 -1
- {google_genai-1.2.0.dist-info → google_genai-1.4.0.dist-info}/METADATA +94 -12
- google_genai-1.4.0.dist-info/RECORD +27 -0
- {google_genai-1.2.0.dist-info → google_genai-1.4.0.dist-info}/WHEEL +1 -1
- google/genai/_operations.py +0 -365
- google_genai-1.2.0.dist-info/RECORD +0 -27
- {google_genai-1.2.0.dist-info → google_genai-1.4.0.dist-info}/LICENSE +0 -0
- {google_genai-1.2.0.dist-info → google_genai-1.4.0.dist-info}/top_level.txt +0 -0
@@ -29,7 +29,7 @@ import google.auth
|
|
29
29
|
from requests.exceptions import HTTPError
|
30
30
|
|
31
31
|
from . import errors
|
32
|
-
from ._api_client import
|
32
|
+
from ._api_client import BaseApiClient
|
33
33
|
from ._api_client import HttpOptions
|
34
34
|
from ._api_client import HttpRequest
|
35
35
|
from ._api_client import HttpResponse
|
@@ -60,6 +60,10 @@ def _redact_request_headers(headers):
|
|
60
60
|
redacted_headers[header_name] = _redact_language_label(
|
61
61
|
_redact_version_numbers(header_value)
|
62
62
|
)
|
63
|
+
elif header_name.lower() == 'x-goog-user-project':
|
64
|
+
continue
|
65
|
+
elif header_name.lower() == 'authorization':
|
66
|
+
continue
|
63
67
|
else:
|
64
68
|
redacted_headers[header_name] = header_value
|
65
69
|
return redacted_headers
|
@@ -175,7 +179,7 @@ class ReplayFile(BaseModel):
|
|
175
179
|
interactions: list[ReplayInteraction]
|
176
180
|
|
177
181
|
|
178
|
-
class ReplayApiClient(
|
182
|
+
class ReplayApiClient(BaseApiClient):
|
179
183
|
"""For integration testing, sends a recorded response or records a response."""
|
180
184
|
|
181
185
|
def __init__(
|
@@ -409,6 +413,34 @@ class ReplayApiClient(ApiClient):
|
|
409
413
|
else:
|
410
414
|
return self._build_response_from_replay(http_request)
|
411
415
|
|
416
|
+
async def _async_request(
|
417
|
+
self,
|
418
|
+
http_request: HttpRequest,
|
419
|
+
stream: bool = False,
|
420
|
+
) -> HttpResponse:
|
421
|
+
self._initialize_replay_session_if_not_loaded()
|
422
|
+
if self._should_call_api():
|
423
|
+
_debug_print('api mode request: %s' % http_request)
|
424
|
+
try:
|
425
|
+
result = await super()._async_request(http_request, stream)
|
426
|
+
except errors.APIError as e:
|
427
|
+
self._record_interaction(http_request, e)
|
428
|
+
raise e
|
429
|
+
if stream:
|
430
|
+
result_segments = []
|
431
|
+
async for segment in result.async_segments():
|
432
|
+
result_segments.append(json.dumps(segment))
|
433
|
+
result = HttpResponse(result.headers, result_segments)
|
434
|
+
self._record_interaction(http_request, result)
|
435
|
+
# Need to return a RecordedResponse that rebuilds the response
|
436
|
+
# segments since the stream has been consumed.
|
437
|
+
else:
|
438
|
+
self._record_interaction(http_request, result)
|
439
|
+
_debug_print('api mode result: %s' % result.json)
|
440
|
+
return result
|
441
|
+
else:
|
442
|
+
return self._build_response_from_replay(http_request)
|
443
|
+
|
412
444
|
def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int):
|
413
445
|
if isinstance(file_path, io.IOBase):
|
414
446
|
offset = file_path.tell()
|
@@ -425,6 +457,7 @@ class ReplayApiClient(ApiClient):
|
|
425
457
|
method='POST', url='', data={'file_path': file_path}, headers={}
|
426
458
|
)
|
427
459
|
if self._should_call_api():
|
460
|
+
result: Union[str, HttpResponse]
|
428
461
|
try:
|
429
462
|
result = super().upload_file(file_path, upload_url, upload_size)
|
430
463
|
except HTTPError as e:
|
@@ -453,4 +486,3 @@ class ReplayApiClient(ApiClient):
|
|
453
486
|
return result
|
454
487
|
else:
|
455
488
|
return self._build_response_from_replay(request)
|
456
|
-
|
google/genai/_test_api_client.py
CHANGED
@@ -17,13 +17,13 @@ import asyncio
|
|
17
17
|
import time
|
18
18
|
from unittest.mock import MagicMock, patch
|
19
19
|
import pytest
|
20
|
-
from .api_client import
|
20
|
+
from .api_client import BaseApiClient
|
21
21
|
|
22
22
|
|
23
|
-
@patch('genai.api_client.
|
24
|
-
@patch('genai.api_client.
|
23
|
+
@patch('genai.api_client.BaseApiClient._build_request')
|
24
|
+
@patch('genai.api_client.BaseApiClient._request')
|
25
25
|
def test_request_streamed_non_blocking(mock_request, mock_build_request):
|
26
|
-
api_client =
|
26
|
+
api_client = BaseApiClient(api_key='test_api_key')
|
27
27
|
http_method = 'GET'
|
28
28
|
path = 'test/path'
|
29
29
|
request_dict = {'key': 'value'}
|
@@ -56,8 +56,8 @@ def test_request_streamed_non_blocking(mock_request, mock_build_request):
|
|
56
56
|
assert end_time - start_time > 0.3
|
57
57
|
|
58
58
|
|
59
|
-
@patch('genai.api_client.
|
60
|
-
@patch('genai.api_client.
|
59
|
+
@patch('genai.api_client.BaseApiClient._build_request')
|
60
|
+
@patch('genai.api_client.BaseApiClient._async_request')
|
61
61
|
@pytest.mark.asyncio
|
62
62
|
async def test_async_request(mock_async_request, mock_build_request):
|
63
63
|
api_client = ApiClient(api_key='test_api_key')
|
@@ -99,8 +99,8 @@ async def test_async_request(mock_async_request, mock_build_request):
|
|
99
99
|
assert 0.1 <= end_time - start_time < 0.15
|
100
100
|
|
101
101
|
|
102
|
-
@patch('genai.api_client.
|
103
|
-
@patch('genai.api_client.
|
102
|
+
@patch('genai.api_client.BaseApiClient._build_request')
|
103
|
+
@patch('genai.api_client.BaseApiClient._async_request')
|
104
104
|
@pytest.mark.asyncio
|
105
105
|
async def test_async_request_streamed_non_blocking(
|
106
106
|
mock_async_request, mock_build_request
|
google/genai/_transformers.py
CHANGED
@@ -20,12 +20,16 @@ from collections.abc import Iterable, Mapping
|
|
20
20
|
from enum import Enum, EnumMeta
|
21
21
|
import inspect
|
22
22
|
import io
|
23
|
+
import logging
|
23
24
|
import re
|
24
25
|
import sys
|
25
26
|
import time
|
27
|
+
import types as builtin_types
|
26
28
|
import typing
|
27
29
|
from typing import Any, GenericAlias, Optional, Union
|
28
30
|
|
31
|
+
import types as builtin_types
|
32
|
+
|
29
33
|
if typing.TYPE_CHECKING:
|
30
34
|
import PIL.Image
|
31
35
|
|
@@ -34,16 +38,18 @@ import pydantic
|
|
34
38
|
from . import _api_client
|
35
39
|
from . import types
|
36
40
|
|
41
|
+
logger = logging.getLogger('google_genai._transformers')
|
42
|
+
|
37
43
|
if sys.version_info >= (3, 10):
|
38
|
-
VersionedUnionType =
|
39
|
-
_UNION_TYPES = (typing.Union,
|
44
|
+
VersionedUnionType = builtin_types.UnionType
|
45
|
+
_UNION_TYPES = (typing.Union, builtin_types.UnionType)
|
40
46
|
else:
|
41
47
|
VersionedUnionType = typing._UnionGenericAlias
|
42
48
|
_UNION_TYPES = (typing.Union,)
|
43
49
|
|
44
50
|
|
45
51
|
def _resource_name(
|
46
|
-
client: _api_client.
|
52
|
+
client: _api_client.BaseApiClient,
|
47
53
|
resource_name: str,
|
48
54
|
*,
|
49
55
|
collection_identifier: str,
|
@@ -135,7 +141,7 @@ def _resource_name(
|
|
135
141
|
return resource_name
|
136
142
|
|
137
143
|
|
138
|
-
def t_model(client: _api_client.
|
144
|
+
def t_model(client: _api_client.BaseApiClient, model: str):
|
139
145
|
if not model:
|
140
146
|
raise ValueError('model is required.')
|
141
147
|
if client.vertexai:
|
@@ -159,7 +165,7 @@ def t_model(client: _api_client.ApiClient, model: str):
|
|
159
165
|
return f'models/{model}'
|
160
166
|
|
161
167
|
|
162
|
-
def t_models_url(api_client: _api_client.
|
168
|
+
def t_models_url(api_client: _api_client.BaseApiClient, base_models: bool) -> str:
|
163
169
|
if api_client.vertexai:
|
164
170
|
if base_models:
|
165
171
|
return 'publishers/google/models'
|
@@ -173,7 +179,7 @@ def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
|
|
173
179
|
|
174
180
|
|
175
181
|
def t_extract_models(
|
176
|
-
api_client: _api_client.
|
182
|
+
api_client: _api_client.BaseApiClient, response: dict
|
177
183
|
) -> list[types.Model]:
|
178
184
|
if not response:
|
179
185
|
return []
|
@@ -183,11 +189,18 @@ def t_extract_models(
|
|
183
189
|
return response.get('tunedModels')
|
184
190
|
elif response.get('publisherModels') is not None:
|
185
191
|
return response.get('publisherModels')
|
192
|
+
elif (
|
193
|
+
response.get('httpHeaders') is not None
|
194
|
+
and response.get('jsonPayload') is None
|
195
|
+
):
|
196
|
+
return []
|
186
197
|
else:
|
187
|
-
|
198
|
+
logger.warning('Cannot determine the models type.')
|
199
|
+
logger.debug('Cannot determine the models type for response: %s', response)
|
200
|
+
return []
|
188
201
|
|
189
202
|
|
190
|
-
def t_caches_model(api_client: _api_client.
|
203
|
+
def t_caches_model(api_client: _api_client.BaseApiClient, model: str):
|
191
204
|
model = t_model(api_client, model)
|
192
205
|
if not model:
|
193
206
|
return None
|
@@ -203,6 +216,7 @@ def t_caches_model(api_client: _api_client.ApiClient, model: str):
|
|
203
216
|
|
204
217
|
|
205
218
|
def pil_to_blob(img) -> types.Blob:
|
219
|
+
PngImagePlugin: Optional[builtin_types.ModuleType]
|
206
220
|
try:
|
207
221
|
import PIL.PngImagePlugin
|
208
222
|
|
@@ -226,10 +240,9 @@ def pil_to_blob(img) -> types.Blob:
|
|
226
240
|
return types.Blob(mime_type=mime_type, data=data)
|
227
241
|
|
228
242
|
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
|
243
|
+
def t_part(
|
244
|
+
client: _api_client.BaseApiClient, part: Optional[types.PartUnionDict]
|
245
|
+
) -> types.Part:
|
233
246
|
try:
|
234
247
|
import PIL.Image
|
235
248
|
|
@@ -237,7 +250,7 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
|
|
237
250
|
except ImportError:
|
238
251
|
PIL_Image = None
|
239
252
|
|
240
|
-
if
|
253
|
+
if part is None:
|
241
254
|
raise ValueError('content part is required.')
|
242
255
|
if isinstance(part, str):
|
243
256
|
return types.Part(text=part)
|
@@ -247,14 +260,19 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
|
|
247
260
|
if not part.uri or not part.mime_type:
|
248
261
|
raise ValueError('file uri and mime_type are required.')
|
249
262
|
return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)
|
250
|
-
|
263
|
+
if isinstance(part, dict):
|
264
|
+
return types.Part.model_validate(part)
|
265
|
+
if isinstance(part, types.Part):
|
251
266
|
return part
|
267
|
+
raise ValueError(f'Unsupported content part type: {type(part)}')
|
252
268
|
|
253
269
|
|
254
270
|
def t_parts(
|
255
|
-
client: _api_client.
|
271
|
+
client: _api_client.BaseApiClient,
|
272
|
+
parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict]],
|
256
273
|
) -> list[types.Part]:
|
257
|
-
|
274
|
+
#
|
275
|
+
if parts is None or (isinstance(parts, list) and not parts):
|
258
276
|
raise ValueError('content parts are required.')
|
259
277
|
if isinstance(parts, list):
|
260
278
|
return [t_part(client, part) for part in parts]
|
@@ -263,7 +281,7 @@ def t_parts(
|
|
263
281
|
|
264
282
|
|
265
283
|
def t_image_predictions(
|
266
|
-
client: _api_client.
|
284
|
+
client: _api_client.BaseApiClient,
|
267
285
|
predictions: Optional[Iterable[Mapping[str, Any]]],
|
268
286
|
) -> list[types.GeneratedImage]:
|
269
287
|
if not predictions:
|
@@ -282,24 +300,38 @@ def t_image_predictions(
|
|
282
300
|
return images
|
283
301
|
|
284
302
|
|
285
|
-
ContentType = Union[types.Content, types.ContentDict,
|
303
|
+
ContentType = Union[types.Content, types.ContentDict, types.PartUnionDict]
|
286
304
|
|
287
305
|
|
288
306
|
def t_content(
|
289
|
-
client: _api_client.
|
290
|
-
content: ContentType,
|
291
|
-
):
|
292
|
-
if
|
307
|
+
client: _api_client.BaseApiClient,
|
308
|
+
content: Optional[ContentType],
|
309
|
+
) -> types.Content:
|
310
|
+
if content is None:
|
293
311
|
raise ValueError('content is required.')
|
294
312
|
if isinstance(content, types.Content):
|
295
313
|
return content
|
296
314
|
if isinstance(content, dict):
|
297
|
-
|
298
|
-
|
315
|
+
try:
|
316
|
+
return types.Content.model_validate(content)
|
317
|
+
except pydantic.ValidationError:
|
318
|
+
possible_part = types.Part.model_validate(content)
|
319
|
+
return (
|
320
|
+
types.ModelContent(parts=[possible_part])
|
321
|
+
if possible_part.function_call
|
322
|
+
else types.UserContent(parts=[possible_part])
|
323
|
+
)
|
324
|
+
if isinstance(content, types.Part):
|
325
|
+
return (
|
326
|
+
types.ModelContent(parts=[content])
|
327
|
+
if content.function_call
|
328
|
+
else types.UserContent(parts=[content])
|
329
|
+
)
|
330
|
+
return types.UserContent(parts=content)
|
299
331
|
|
300
332
|
|
301
333
|
def t_contents_for_embed(
|
302
|
-
client: _api_client.
|
334
|
+
client: _api_client.BaseApiClient,
|
303
335
|
contents: Union[list[types.Content], list[types.ContentDict], ContentType],
|
304
336
|
):
|
305
337
|
if client.vertexai and isinstance(contents, list):
|
@@ -314,16 +346,105 @@ def t_contents_for_embed(
|
|
314
346
|
|
315
347
|
|
316
348
|
def t_contents(
|
317
|
-
client: _api_client.
|
318
|
-
contents:
|
319
|
-
|
320
|
-
|
349
|
+
client: _api_client.BaseApiClient,
|
350
|
+
contents: Optional[
|
351
|
+
Union[types.ContentListUnion, types.ContentListUnionDict]
|
352
|
+
],
|
353
|
+
) -> list[types.Content]:
|
354
|
+
if contents is None or (isinstance(contents, list) and not contents):
|
321
355
|
raise ValueError('contents are required.')
|
322
|
-
if isinstance(contents, list):
|
323
|
-
return [t_content(client, content) for content in contents]
|
324
|
-
else:
|
356
|
+
if not isinstance(contents, list):
|
325
357
|
return [t_content(client, contents)]
|
326
358
|
|
359
|
+
try:
|
360
|
+
import PIL.Image
|
361
|
+
|
362
|
+
PIL_Image = PIL.Image.Image
|
363
|
+
except ImportError:
|
364
|
+
PIL_Image = None
|
365
|
+
|
366
|
+
result: list[types.Content] = []
|
367
|
+
accumulated_parts: list[types.Part] = []
|
368
|
+
|
369
|
+
def _is_part(part: types.PartUnionDict) -> bool:
|
370
|
+
if (
|
371
|
+
isinstance(part, str)
|
372
|
+
or isinstance(part, types.File)
|
373
|
+
or (PIL_Image is not None and isinstance(part, PIL_Image))
|
374
|
+
or isinstance(part, types.Part)
|
375
|
+
):
|
376
|
+
return True
|
377
|
+
|
378
|
+
if isinstance(part, dict):
|
379
|
+
try:
|
380
|
+
types.Part.model_validate(part)
|
381
|
+
return True
|
382
|
+
except pydantic.ValidationError:
|
383
|
+
return False
|
384
|
+
|
385
|
+
return False
|
386
|
+
|
387
|
+
def _is_user_part(part: types.Part) -> bool:
|
388
|
+
return not part.function_call
|
389
|
+
|
390
|
+
def _are_user_parts(parts: list[types.Part]) -> bool:
|
391
|
+
return all(_is_user_part(part) for part in parts)
|
392
|
+
|
393
|
+
def _append_accumulated_parts_as_content(
|
394
|
+
result: list[types.Content],
|
395
|
+
accumulated_parts: list[types.Part],
|
396
|
+
):
|
397
|
+
if not accumulated_parts:
|
398
|
+
return
|
399
|
+
result.append(
|
400
|
+
types.UserContent(parts=accumulated_parts)
|
401
|
+
if _are_user_parts(accumulated_parts)
|
402
|
+
else types.ModelContent(parts=accumulated_parts)
|
403
|
+
)
|
404
|
+
accumulated_parts[:] = []
|
405
|
+
|
406
|
+
def _handle_current_part(
|
407
|
+
result: list[types.Content],
|
408
|
+
accumulated_parts: list[types.Part],
|
409
|
+
current_part: types.PartUnionDict,
|
410
|
+
):
|
411
|
+
current_part = t_part(client, current_part)
|
412
|
+
if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
|
413
|
+
accumulated_parts.append(current_part)
|
414
|
+
else:
|
415
|
+
_append_accumulated_parts_as_content(result, accumulated_parts)
|
416
|
+
accumulated_parts[:] = [current_part]
|
417
|
+
|
418
|
+
# iterate over contents
|
419
|
+
# if content type or content dict, append to result
|
420
|
+
# if consecutive part(s),
|
421
|
+
# group consecutive user part(s) to a UserContent
|
422
|
+
# group consecutive model part(s) to a ModelContent
|
423
|
+
# append to result
|
424
|
+
# if list, we only accept a list of types.PartUnion
|
425
|
+
for content in contents:
|
426
|
+
if (
|
427
|
+
isinstance(content, types.Content)
|
428
|
+
# only allowed inner list is a list of types.PartUnion
|
429
|
+
or isinstance(content, list)
|
430
|
+
):
|
431
|
+
_append_accumulated_parts_as_content(result, accumulated_parts)
|
432
|
+
if isinstance(content, list):
|
433
|
+
result.append(types.UserContent(parts=content))
|
434
|
+
else:
|
435
|
+
result.append(content)
|
436
|
+
elif (_is_part(content)): # type: ignore
|
437
|
+
_handle_current_part(result, accumulated_parts, content) # type: ignore
|
438
|
+
elif isinstance(content, dict):
|
439
|
+
# PartDict is already handled in _is_part
|
440
|
+
result.append(types.Content.model_validate(content))
|
441
|
+
else:
|
442
|
+
raise ValueError(f'Unsupported content type: {type(content)}')
|
443
|
+
|
444
|
+
_append_accumulated_parts_as_content(result, accumulated_parts)
|
445
|
+
|
446
|
+
return result
|
447
|
+
|
327
448
|
|
328
449
|
def handle_null_fields(schema: dict[str, Any]):
|
329
450
|
"""Process null fields in the schema so it is compatible with OpenAPI.
|
@@ -386,7 +507,7 @@ def handle_null_fields(schema: dict[str, Any]):
|
|
386
507
|
|
387
508
|
def process_schema(
|
388
509
|
schema: dict[str, Any],
|
389
|
-
client: Optional[_api_client.
|
510
|
+
client: Optional[_api_client.BaseApiClient] = None,
|
390
511
|
defs: Optional[dict[str, Any]] = None,
|
391
512
|
*,
|
392
513
|
order_properties: bool = True,
|
@@ -549,9 +670,9 @@ def process_schema(
|
|
549
670
|
|
550
671
|
|
551
672
|
def _process_enum(
|
552
|
-
enum: EnumMeta, client: Optional[_api_client.
|
673
|
+
enum: EnumMeta, client: Optional[_api_client.BaseApiClient] = None # type: ignore
|
553
674
|
) -> types.Schema:
|
554
|
-
for member in enum:
|
675
|
+
for member in enum: # type: ignore
|
555
676
|
if not isinstance(member.value, str):
|
556
677
|
raise TypeError(
|
557
678
|
f'Enum member {member.name} value must be a string, got'
|
@@ -568,7 +689,7 @@ def _process_enum(
|
|
568
689
|
|
569
690
|
|
570
691
|
def t_schema(
|
571
|
-
client: _api_client.
|
692
|
+
client: _api_client.BaseApiClient, origin: Union[types.SchemaUnionDict, Any]
|
572
693
|
) -> Optional[types.Schema]:
|
573
694
|
if not origin:
|
574
695
|
return None
|
@@ -614,7 +735,7 @@ def t_schema(
|
|
614
735
|
|
615
736
|
|
616
737
|
def t_speech_config(
|
617
|
-
_: _api_client.
|
738
|
+
_: _api_client.BaseApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
|
618
739
|
) -> Optional[types.SpeechConfig]:
|
619
740
|
if not origin:
|
620
741
|
return None
|
@@ -643,7 +764,7 @@ def t_speech_config(
|
|
643
764
|
raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
|
644
765
|
|
645
766
|
|
646
|
-
def t_tool(client: _api_client.
|
767
|
+
def t_tool(client: _api_client.BaseApiClient, origin) -> types.Tool:
|
647
768
|
if not origin:
|
648
769
|
return None
|
649
770
|
if inspect.isfunction(origin) or inspect.ismethod(origin):
|
@@ -660,7 +781,7 @@ def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
|
|
660
781
|
|
661
782
|
# Only support functions now.
|
662
783
|
def t_tools(
|
663
|
-
client: _api_client.
|
784
|
+
client: _api_client.BaseApiClient, origin: list[Any]
|
664
785
|
) -> list[types.Tool]:
|
665
786
|
if not origin:
|
666
787
|
return []
|
@@ -680,11 +801,11 @@ def t_tools(
|
|
680
801
|
return tools
|
681
802
|
|
682
803
|
|
683
|
-
def t_cached_content_name(client: _api_client.
|
804
|
+
def t_cached_content_name(client: _api_client.BaseApiClient, name: str):
|
684
805
|
return _resource_name(client, name, collection_identifier='cachedContents')
|
685
806
|
|
686
807
|
|
687
|
-
def t_batch_job_source(client: _api_client.
|
808
|
+
def t_batch_job_source(client: _api_client.BaseApiClient, src: str):
|
688
809
|
if src.startswith('gs://'):
|
689
810
|
return types.BatchJobSource(
|
690
811
|
format='jsonl',
|
@@ -699,7 +820,7 @@ def t_batch_job_source(client: _api_client.ApiClient, src: str):
|
|
699
820
|
raise ValueError(f'Unsupported source: {src}')
|
700
821
|
|
701
822
|
|
702
|
-
def t_batch_job_destination(client: _api_client.
|
823
|
+
def t_batch_job_destination(client: _api_client.BaseApiClient, dest: str):
|
703
824
|
if dest.startswith('gs://'):
|
704
825
|
return types.BatchJobDestination(
|
705
826
|
format='jsonl',
|
@@ -714,7 +835,7 @@ def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
|
|
714
835
|
raise ValueError(f'Unsupported destination: {dest}')
|
715
836
|
|
716
837
|
|
717
|
-
def t_batch_job_name(client: _api_client.
|
838
|
+
def t_batch_job_name(client: _api_client.BaseApiClient, name: str):
|
718
839
|
if not client.vertexai:
|
719
840
|
return name
|
720
841
|
|
@@ -733,7 +854,7 @@ LRO_POLLING_TIMEOUT_SECONDS = 900.0
|
|
733
854
|
LRO_POLLING_MULTIPLIER = 1.5
|
734
855
|
|
735
856
|
|
736
|
-
def t_resolve_operation(api_client: _api_client.
|
857
|
+
def t_resolve_operation(api_client: _api_client.BaseApiClient, struct: dict):
|
737
858
|
if (name := struct.get('name')) and '/operations/' in name:
|
738
859
|
operation: dict[str, Any] = struct
|
739
860
|
total_seconds = 0.0
|
@@ -742,7 +863,7 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
|
|
742
863
|
if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
|
743
864
|
raise RuntimeError(f'Operation {name} timed out.\n{operation}')
|
744
865
|
# TODO(b/374433890): Replace with LRO module once it's available.
|
745
|
-
operation
|
866
|
+
operation = api_client.request(
|
746
867
|
http_method='GET', path=name, request_dict={}
|
747
868
|
)
|
748
869
|
time.sleep(delay_seconds)
|
@@ -762,7 +883,7 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
|
|
762
883
|
|
763
884
|
|
764
885
|
def t_file_name(
|
765
|
-
api_client: _api_client.
|
886
|
+
api_client: _api_client.BaseApiClient, name: Optional[Union[str, types.File]]
|
766
887
|
):
|
767
888
|
# Remove the files/ prefix since it's added to the url path.
|
768
889
|
if isinstance(name, types.File):
|
@@ -784,7 +905,7 @@ def t_file_name(
|
|
784
905
|
|
785
906
|
|
786
907
|
def t_tuning_job_status(
|
787
|
-
api_client: _api_client.
|
908
|
+
api_client: _api_client.BaseApiClient, status: str
|
788
909
|
) -> types.JobState:
|
789
910
|
if status == 'STATE_UNSPECIFIED':
|
790
911
|
return 'JOB_STATE_UNSPECIFIED'
|
@@ -802,7 +923,7 @@ def t_tuning_job_status(
|
|
802
923
|
# We shouldn't use this transformer if the backend adheres to the Cloud Type
|
803
924
|
# format https://cloud.google.com/docs/discovery/type-format.
|
804
925
|
# TODO(b/389133914,b/390320301): Remove the hack after the backend fixes the issue.
|
805
|
-
def t_bytes(api_client: _api_client.
|
926
|
+
def t_bytes(api_client: _api_client.BaseApiClient, data: bytes) -> str:
|
806
927
|
if not isinstance(data, bytes):
|
807
928
|
return data
|
808
929
|
return base64.b64encode(data).decode('ascii')
|