google-genai 1.30.0__tar.gz → 1.31.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_genai-1.30.0/google_genai.egg-info → google_genai-1.31.0}/PKG-INFO +1 -1
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_api_client.py +32 -32
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_automatic_function_calling_util.py +12 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_live_converters.py +1 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_tokens_converters.py +1 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/batches.py +141 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/caches.py +1 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/files.py +1 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/models.py +374 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/operations.py +1 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/tunings.py +1 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/types.py +469 -180
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/version.py +1 -1
- {google_genai-1.30.0 → google_genai-1.31.0/google_genai.egg-info}/PKG-INFO +1 -1
- {google_genai-1.30.0 → google_genai-1.31.0}/pyproject.toml +1 -1
- {google_genai-1.30.0 → google_genai-1.31.0}/LICENSE +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/MANIFEST.in +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/README.md +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/__init__.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_adapters.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_api_module.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_base_url.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_common.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_extra_utils.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_mcp_utils.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_replay_api_client.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_test_api_client.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_transformers.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/chats.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/client.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/errors.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/live.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/live_music.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/pagers.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/py.typed +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google/genai/tokens.py +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google_genai.egg-info/SOURCES.txt +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google_genai.egg-info/dependency_links.txt +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google_genai.egg-info/requires.txt +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/google_genai.egg-info/top_level.txt +0 -0
- {google_genai-1.30.0 → google_genai-1.31.0}/setup.cfg +0 -0
@@ -91,7 +91,7 @@ class EphemeralTokenAPIKeyError(ValueError):
|
|
91
91
|
|
92
92
|
# This method checks for the API key in the environment variables. Google API
|
93
93
|
# key is precedenced over Gemini API key.
|
94
|
-
def
|
94
|
+
def get_env_api_key() -> Optional[str]:
|
95
95
|
"""Gets the API key from environment variables, prioritizing GOOGLE_API_KEY.
|
96
96
|
|
97
97
|
Returns:
|
@@ -108,7 +108,7 @@ def _get_env_api_key() -> Optional[str]:
|
|
108
108
|
return env_google_api_key or env_gemini_api_key or None
|
109
109
|
|
110
110
|
|
111
|
-
def
|
111
|
+
def append_library_version_headers(headers: dict[str, str]) -> None:
|
112
112
|
"""Appends the telemetry header to the headers dict."""
|
113
113
|
library_label = f'google-genai-sdk/{version.__version__}'
|
114
114
|
language_label = 'gl-python/' + sys.version.split()[0]
|
@@ -131,7 +131,7 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
|
|
131
131
|
headers['x-goog-api-client'] = version_header_value
|
132
132
|
|
133
133
|
|
134
|
-
def
|
134
|
+
def patch_http_options(
|
135
135
|
options: HttpOptions, patch_options: HttpOptions
|
136
136
|
) -> HttpOptions:
|
137
137
|
copy_option = options.model_copy()
|
@@ -155,11 +155,11 @@ def _patch_http_options(
|
|
155
155
|
setattr(copy_option, key, getattr(options, key))
|
156
156
|
|
157
157
|
if copy_option.headers is not None:
|
158
|
-
|
158
|
+
append_library_version_headers(copy_option.headers)
|
159
159
|
return copy_option
|
160
160
|
|
161
161
|
|
162
|
-
def
|
162
|
+
def populate_server_timeout_header(
|
163
163
|
headers: dict[str, str], timeout_in_seconds: Optional[Union[float, int]]
|
164
164
|
) -> None:
|
165
165
|
"""Populates the server timeout header in the headers dict."""
|
@@ -167,7 +167,7 @@ def _populate_server_timeout_header(
|
|
167
167
|
headers['X-Server-Timeout'] = str(math.ceil(timeout_in_seconds))
|
168
168
|
|
169
169
|
|
170
|
-
def
|
170
|
+
def join_url_path(base_url: str, path: str) -> str:
|
171
171
|
parsed_base = urlparse(base_url)
|
172
172
|
base_path = (
|
173
173
|
parsed_base.path[:-1]
|
@@ -178,7 +178,7 @@ def _join_url_path(base_url: str, path: str) -> str:
|
|
178
178
|
return urlunparse(parsed_base._replace(path=base_path + '/' + path))
|
179
179
|
|
180
180
|
|
181
|
-
def
|
181
|
+
def load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
|
182
182
|
"""Loads google auth credentials and project id."""
|
183
183
|
credentials, loaded_project_id = google.auth.default( # type: ignore[no-untyped-call]
|
184
184
|
scopes=['https://www.googleapis.com/auth/cloud-platform'],
|
@@ -195,12 +195,12 @@ def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
|
|
195
195
|
return credentials, project
|
196
196
|
|
197
197
|
|
198
|
-
def
|
198
|
+
def refresh_auth(credentials: Credentials) -> Credentials:
|
199
199
|
credentials.refresh(Request()) # type: ignore[no-untyped-call]
|
200
200
|
return credentials
|
201
201
|
|
202
202
|
|
203
|
-
def
|
203
|
+
def get_timeout_in_seconds(
|
204
204
|
timeout: Optional[Union[float, int]],
|
205
205
|
) -> Optional[float]:
|
206
206
|
"""Converts the timeout to seconds."""
|
@@ -454,7 +454,7 @@ _RETRY_HTTP_STATUS_CODES = (
|
|
454
454
|
)
|
455
455
|
|
456
456
|
|
457
|
-
def
|
457
|
+
def retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
|
458
458
|
"""Returns the retry args for the given http retry options.
|
459
459
|
|
460
460
|
Args:
|
@@ -574,7 +574,7 @@ class BaseApiClient:
|
|
574
574
|
# Retrieve implicitly set values from the environment.
|
575
575
|
env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
|
576
576
|
env_location = os.environ.get('GOOGLE_CLOUD_LOCATION', None)
|
577
|
-
env_api_key =
|
577
|
+
env_api_key = get_env_api_key()
|
578
578
|
self.project = project or env_project
|
579
579
|
self.location = location or env_location
|
580
580
|
self.api_key = api_key or env_api_key
|
@@ -631,7 +631,7 @@ class BaseApiClient:
|
|
631
631
|
and not self.api_key
|
632
632
|
and not validated_http_options.base_url
|
633
633
|
):
|
634
|
-
credentials, self.project =
|
634
|
+
credentials, self.project = load_auth(project=None)
|
635
635
|
if not self._credentials:
|
636
636
|
self._credentials = credentials
|
637
637
|
|
@@ -670,12 +670,12 @@ class BaseApiClient:
|
|
670
670
|
self._http_options.headers['x-goog-api-key'] = self.api_key
|
671
671
|
# Update the http options with the user provided http options.
|
672
672
|
if http_options:
|
673
|
-
self._http_options =
|
673
|
+
self._http_options = patch_http_options(
|
674
674
|
self._http_options, validated_http_options
|
675
675
|
)
|
676
676
|
else:
|
677
677
|
if self._http_options.headers is not None:
|
678
|
-
|
678
|
+
append_library_version_headers(self._http_options.headers)
|
679
679
|
|
680
680
|
client_args, async_client_args = self._ensure_httpx_ssl_ctx(
|
681
681
|
self._http_options
|
@@ -689,7 +689,7 @@ class BaseApiClient:
|
|
689
689
|
)
|
690
690
|
self._websocket_ssl_ctx = self._ensure_websocket_ssl_ctx(self._http_options)
|
691
691
|
|
692
|
-
retry_kwargs =
|
692
|
+
retry_kwargs = retry_args(self._http_options.retry_options)
|
693
693
|
self._retry = tenacity.Retrying(**retry_kwargs)
|
694
694
|
self._async_retry = tenacity.AsyncRetrying(**retry_kwargs)
|
695
695
|
|
@@ -889,14 +889,14 @@ class BaseApiClient:
|
|
889
889
|
"""Retrieves the access token for the credentials."""
|
890
890
|
with self._sync_auth_lock:
|
891
891
|
if not self._credentials:
|
892
|
-
self._credentials, project =
|
892
|
+
self._credentials, project = load_auth(project=self.project)
|
893
893
|
if not self.project:
|
894
894
|
self.project = project
|
895
895
|
|
896
896
|
if self._credentials:
|
897
897
|
if self._credentials.expired or not self._credentials.token:
|
898
898
|
# Only refresh when it needs to. Default expiration is 3600 seconds.
|
899
|
-
|
899
|
+
refresh_auth(self._credentials)
|
900
900
|
if not self._credentials.token:
|
901
901
|
raise RuntimeError('Could not resolve API token from the environment')
|
902
902
|
return self._credentials.token # type: ignore[no-any-return]
|
@@ -912,7 +912,7 @@ class BaseApiClient:
|
|
912
912
|
if not self._credentials:
|
913
913
|
# Double check that the credentials are not set before loading them.
|
914
914
|
self._credentials, project = await asyncio.to_thread(
|
915
|
-
|
915
|
+
load_auth, project=self.project
|
916
916
|
)
|
917
917
|
if not self.project:
|
918
918
|
self.project = project
|
@@ -923,7 +923,7 @@ class BaseApiClient:
|
|
923
923
|
async with self._async_auth_lock:
|
924
924
|
if self._credentials.expired or not self._credentials.token:
|
925
925
|
# Double check that the credentials expired before refreshing.
|
926
|
-
await asyncio.to_thread(
|
926
|
+
await asyncio.to_thread(refresh_auth, self._credentials)
|
927
927
|
|
928
928
|
if not self._credentials.token:
|
929
929
|
raise RuntimeError('Could not resolve API token from the environment')
|
@@ -946,12 +946,12 @@ class BaseApiClient:
|
|
946
946
|
# patch the http options with the user provided settings.
|
947
947
|
if http_options:
|
948
948
|
if isinstance(http_options, HttpOptions):
|
949
|
-
patched_http_options =
|
949
|
+
patched_http_options = patch_http_options(
|
950
950
|
self._http_options,
|
951
951
|
http_options,
|
952
952
|
)
|
953
953
|
else:
|
954
|
-
patched_http_options =
|
954
|
+
patched_http_options = patch_http_options(
|
955
955
|
self._http_options, HttpOptions.model_validate(http_options)
|
956
956
|
)
|
957
957
|
else:
|
@@ -993,7 +993,7 @@ class BaseApiClient:
|
|
993
993
|
request_dict, patched_http_options.extra_body
|
994
994
|
)
|
995
995
|
|
996
|
-
url =
|
996
|
+
url = join_url_path(
|
997
997
|
base_url,
|
998
998
|
versioned_path,
|
999
999
|
)
|
@@ -1003,11 +1003,11 @@ class BaseApiClient:
|
|
1003
1003
|
'Ephemeral tokens can only be used with the live API.'
|
1004
1004
|
)
|
1005
1005
|
|
1006
|
-
timeout_in_seconds =
|
1006
|
+
timeout_in_seconds = get_timeout_in_seconds(patched_http_options.timeout)
|
1007
1007
|
|
1008
1008
|
if patched_http_options.headers is None:
|
1009
1009
|
raise ValueError('Request headers must be set.')
|
1010
|
-
|
1010
|
+
populate_server_timeout_header(
|
1011
1011
|
patched_http_options.headers, timeout_in_seconds
|
1012
1012
|
)
|
1013
1013
|
return HttpRequest(
|
@@ -1079,7 +1079,7 @@ class BaseApiClient:
|
|
1079
1079
|
)
|
1080
1080
|
# Support per request retry options.
|
1081
1081
|
if parameter_model.retry_options:
|
1082
|
-
retry_kwargs =
|
1082
|
+
retry_kwargs = retry_args(parameter_model.retry_options)
|
1083
1083
|
retry = tenacity.Retrying(**retry_kwargs)
|
1084
1084
|
return retry(self._request_once, http_request, stream) # type: ignore[no-any-return]
|
1085
1085
|
|
@@ -1239,7 +1239,7 @@ class BaseApiClient:
|
|
1239
1239
|
)
|
1240
1240
|
# Support per request retry options.
|
1241
1241
|
if parameter_model.retry_options:
|
1242
|
-
retry_kwargs =
|
1242
|
+
retry_kwargs = retry_args(parameter_model.retry_options)
|
1243
1243
|
retry = tenacity.AsyncRetrying(**retry_kwargs)
|
1244
1244
|
return await retry(self._async_request_once, http_request, stream) # type: ignore[no-any-return]
|
1245
1245
|
return await self._async_retry( # type: ignore[no-any-return]
|
@@ -1398,13 +1398,13 @@ class BaseApiClient:
|
|
1398
1398
|
if isinstance(self._http_options, dict)
|
1399
1399
|
else self._http_options.timeout
|
1400
1400
|
)
|
1401
|
-
timeout_in_seconds =
|
1401
|
+
timeout_in_seconds = get_timeout_in_seconds(timeout)
|
1402
1402
|
upload_headers = {
|
1403
1403
|
'X-Goog-Upload-Command': upload_command,
|
1404
1404
|
'X-Goog-Upload-Offset': str(offset),
|
1405
1405
|
'Content-Length': str(chunk_size),
|
1406
1406
|
}
|
1407
|
-
|
1407
|
+
populate_server_timeout_header(upload_headers, timeout_in_seconds)
|
1408
1408
|
retry_count = 0
|
1409
1409
|
while retry_count < MAX_RETRY_COUNT:
|
1410
1410
|
response = self._httpx_client.request(
|
@@ -1558,13 +1558,13 @@ class BaseApiClient:
|
|
1558
1558
|
if isinstance(self._http_options, dict)
|
1559
1559
|
else self._http_options.timeout
|
1560
1560
|
)
|
1561
|
-
timeout_in_seconds =
|
1561
|
+
timeout_in_seconds = get_timeout_in_seconds(timeout)
|
1562
1562
|
upload_headers = {
|
1563
1563
|
'X-Goog-Upload-Command': upload_command,
|
1564
1564
|
'X-Goog-Upload-Offset': str(offset),
|
1565
1565
|
'Content-Length': str(chunk_size),
|
1566
1566
|
}
|
1567
|
-
|
1567
|
+
populate_server_timeout_header(upload_headers, timeout_in_seconds)
|
1568
1568
|
|
1569
1569
|
retry_count = 0
|
1570
1570
|
response = None
|
@@ -1634,13 +1634,13 @@ class BaseApiClient:
|
|
1634
1634
|
if isinstance(self._http_options, dict)
|
1635
1635
|
else self._http_options.timeout
|
1636
1636
|
)
|
1637
|
-
timeout_in_seconds =
|
1637
|
+
timeout_in_seconds = get_timeout_in_seconds(timeout)
|
1638
1638
|
upload_headers = {
|
1639
1639
|
'X-Goog-Upload-Command': upload_command,
|
1640
1640
|
'X-Goog-Upload-Offset': str(offset),
|
1641
1641
|
'Content-Length': str(chunk_size),
|
1642
1642
|
}
|
1643
|
-
|
1643
|
+
populate_server_timeout_header(upload_headers, timeout_in_seconds)
|
1644
1644
|
|
1645
1645
|
retry_count = 0
|
1646
1646
|
client_response = None
|
{google_genai-1.30.0 → google_genai-1.31.0}/google/genai/_automatic_function_calling_util.py
RENAMED
@@ -30,6 +30,18 @@ if sys.version_info >= (3, 10):
|
|
30
30
|
else:
|
31
31
|
VersionedUnionType = typing._UnionGenericAlias # type: ignore[attr-defined]
|
32
32
|
|
33
|
+
|
34
|
+
__all__ = [
|
35
|
+
'_py_builtin_type_to_schema_type',
|
36
|
+
'_raise_for_unsupported_param',
|
37
|
+
'_handle_params_as_deferred_annotations',
|
38
|
+
'_add_unevaluated_items_to_fixed_len_tuple_schema',
|
39
|
+
'_is_builtin_primitive_or_compound',
|
40
|
+
'_is_default_value_compatible',
|
41
|
+
'_parse_schema_from_parameter',
|
42
|
+
'_get_required_fields',
|
43
|
+
]
|
44
|
+
|
33
45
|
_py_builtin_type_to_schema_type = {
|
34
46
|
str: types.Type.STRING,
|
35
47
|
int: types.Type.INTEGER,
|
@@ -30,6 +30,7 @@ from ._common import get_value_by_path as getv
|
|
30
30
|
from ._common import set_value_by_path as setv
|
31
31
|
from .pagers import AsyncPager, Pager
|
32
32
|
|
33
|
+
|
33
34
|
logger = logging.getLogger('google_genai.batches')
|
34
35
|
|
35
36
|
|
@@ -2257,6 +2258,17 @@ class Batches(_api_module.BaseModule):
|
|
2257
2258
|
)
|
2258
2259
|
print(batch_job.state)
|
2259
2260
|
"""
|
2261
|
+
parameter_model = types._CreateBatchJobParameters(
|
2262
|
+
model=model,
|
2263
|
+
src=src,
|
2264
|
+
config=config,
|
2265
|
+
)
|
2266
|
+
http_options: Optional[types.HttpOptions] = None
|
2267
|
+
if (
|
2268
|
+
parameter_model.config is not None
|
2269
|
+
and parameter_model.config.http_options is not None
|
2270
|
+
):
|
2271
|
+
http_options = parameter_model.config.http_options
|
2260
2272
|
if self._api_client.vertexai:
|
2261
2273
|
if isinstance(src, list):
|
2262
2274
|
raise ValueError(
|
@@ -2265,6 +2277,65 @@ class Batches(_api_module.BaseModule):
|
|
2265
2277
|
)
|
2266
2278
|
|
2267
2279
|
config = _extra_utils.format_destination(src, config)
|
2280
|
+
else:
|
2281
|
+
if isinstance(parameter_model.src, list) or (
|
2282
|
+
not isinstance(parameter_model.src, str)
|
2283
|
+
and parameter_model.src
|
2284
|
+
and parameter_model.src.inlined_requests
|
2285
|
+
):
|
2286
|
+
# Handle system instruction in InlinedRequests.
|
2287
|
+
request_url_dict: Optional[dict[str, str]]
|
2288
|
+
request_dict: dict[str, Any] = _CreateBatchJobParameters_to_mldev(
|
2289
|
+
self._api_client, parameter_model
|
2290
|
+
)
|
2291
|
+
request_url_dict = request_dict.get('_url')
|
2292
|
+
if request_url_dict:
|
2293
|
+
path = '{model}:batchGenerateContent'.format_map(request_url_dict)
|
2294
|
+
else:
|
2295
|
+
path = '{model}:batchGenerateContent'
|
2296
|
+
query_params = request_dict.get('_query')
|
2297
|
+
if query_params:
|
2298
|
+
path = f'{path}?{urlencode(query_params)}'
|
2299
|
+
request_dict.pop('config', None)
|
2300
|
+
|
2301
|
+
request_dict = _common.convert_to_dict(request_dict)
|
2302
|
+
request_dict = _common.encode_unserializable_types(request_dict)
|
2303
|
+
# Move system instruction to 'request':
|
2304
|
+
# {'systemInstruction': system_instruction}
|
2305
|
+
requests = []
|
2306
|
+
batch_dict = request_dict.get('batch')
|
2307
|
+
if batch_dict and isinstance(batch_dict, dict):
|
2308
|
+
input_config_dict = batch_dict.get('inputConfig')
|
2309
|
+
if input_config_dict and isinstance(input_config_dict, dict):
|
2310
|
+
requests_dict = input_config_dict.get('requests')
|
2311
|
+
if requests_dict and isinstance(requests_dict, dict):
|
2312
|
+
requests = requests_dict.get('requests')
|
2313
|
+
new_requests = []
|
2314
|
+
if requests:
|
2315
|
+
for req in requests:
|
2316
|
+
if req.get('systemInstruction'):
|
2317
|
+
value = req.pop('systemInstruction')
|
2318
|
+
req['request'].update({'systemInstruction': value})
|
2319
|
+
new_requests.append(req)
|
2320
|
+
request_dict['batch']['inputConfig']['requests'][ # type: ignore
|
2321
|
+
'requests'
|
2322
|
+
] = new_requests
|
2323
|
+
|
2324
|
+
response = self._api_client.request(
|
2325
|
+
'post', path, request_dict, http_options
|
2326
|
+
)
|
2327
|
+
|
2328
|
+
response_dict = '' if not response.body else json.loads(response.body)
|
2329
|
+
|
2330
|
+
response_dict = _BatchJob_from_mldev(response_dict)
|
2331
|
+
|
2332
|
+
return_value = types.BatchJob._from_response(
|
2333
|
+
response=response_dict, kwargs=parameter_model.model_dump()
|
2334
|
+
)
|
2335
|
+
|
2336
|
+
self._api_client._verify_response(return_value)
|
2337
|
+
return return_value
|
2338
|
+
|
2268
2339
|
return self._create(model=model, src=src, config=config)
|
2269
2340
|
|
2270
2341
|
def list(
|
@@ -2691,6 +2762,17 @@ class AsyncBatches(_api_module.BaseModule):
|
|
2691
2762
|
src="gs://path/to/input/data",
|
2692
2763
|
)
|
2693
2764
|
"""
|
2765
|
+
parameter_model = types._CreateBatchJobParameters(
|
2766
|
+
model=model,
|
2767
|
+
src=src,
|
2768
|
+
config=config,
|
2769
|
+
)
|
2770
|
+
http_options: Optional[types.HttpOptions] = None
|
2771
|
+
if (
|
2772
|
+
parameter_model.config is not None
|
2773
|
+
and parameter_model.config.http_options is not None
|
2774
|
+
):
|
2775
|
+
http_options = parameter_model.config.http_options
|
2694
2776
|
if self._api_client.vertexai:
|
2695
2777
|
if isinstance(src, list):
|
2696
2778
|
raise ValueError(
|
@@ -2699,6 +2781,65 @@ class AsyncBatches(_api_module.BaseModule):
|
|
2699
2781
|
)
|
2700
2782
|
|
2701
2783
|
config = _extra_utils.format_destination(src, config)
|
2784
|
+
else:
|
2785
|
+
if isinstance(parameter_model.src, list) or (
|
2786
|
+
not isinstance(parameter_model.src, str)
|
2787
|
+
and parameter_model.src
|
2788
|
+
and parameter_model.src.inlined_requests
|
2789
|
+
):
|
2790
|
+
# Handle system instruction in InlinedRequests.
|
2791
|
+
request_url_dict: Optional[dict[str, str]]
|
2792
|
+
request_dict: dict[str, Any] = _CreateBatchJobParameters_to_mldev(
|
2793
|
+
self._api_client, parameter_model
|
2794
|
+
)
|
2795
|
+
request_url_dict = request_dict.get('_url')
|
2796
|
+
if request_url_dict:
|
2797
|
+
path = '{model}:batchGenerateContent'.format_map(request_url_dict)
|
2798
|
+
else:
|
2799
|
+
path = '{model}:batchGenerateContent'
|
2800
|
+
query_params = request_dict.get('_query')
|
2801
|
+
if query_params:
|
2802
|
+
path = f'{path}?{urlencode(query_params)}'
|
2803
|
+
request_dict.pop('config', None)
|
2804
|
+
|
2805
|
+
request_dict = _common.convert_to_dict(request_dict)
|
2806
|
+
request_dict = _common.encode_unserializable_types(request_dict)
|
2807
|
+
# Move system instruction to 'request':
|
2808
|
+
# {'systemInstruction': system_instruction}
|
2809
|
+
requests = []
|
2810
|
+
batch_dict = request_dict.get('batch')
|
2811
|
+
if batch_dict and isinstance(batch_dict, dict):
|
2812
|
+
input_config_dict = batch_dict.get('inputConfig')
|
2813
|
+
if input_config_dict and isinstance(input_config_dict, dict):
|
2814
|
+
requests_dict = input_config_dict.get('requests')
|
2815
|
+
if requests_dict and isinstance(requests_dict, dict):
|
2816
|
+
requests = requests_dict.get('requests')
|
2817
|
+
new_requests = []
|
2818
|
+
if requests:
|
2819
|
+
for req in requests:
|
2820
|
+
if req.get('systemInstruction'):
|
2821
|
+
value = req.pop('systemInstruction')
|
2822
|
+
req['request'].update({'systemInstruction': value})
|
2823
|
+
new_requests.append(req)
|
2824
|
+
request_dict['batch']['inputConfig']['requests'][ # type: ignore
|
2825
|
+
'requests'
|
2826
|
+
] = new_requests
|
2827
|
+
|
2828
|
+
response = await self._api_client.async_request(
|
2829
|
+
'post', path, request_dict, http_options
|
2830
|
+
)
|
2831
|
+
|
2832
|
+
response_dict = '' if not response.body else json.loads(response.body)
|
2833
|
+
|
2834
|
+
response_dict = _BatchJob_from_mldev(response_dict)
|
2835
|
+
|
2836
|
+
return_value = types.BatchJob._from_response(
|
2837
|
+
response=response_dict, kwargs=parameter_model.model_dump()
|
2838
|
+
)
|
2839
|
+
|
2840
|
+
self._api_client._verify_response(return_value)
|
2841
|
+
return return_value
|
2842
|
+
|
2702
2843
|
return await self._create(model=model, src=src, config=config)
|
2703
2844
|
|
2704
2845
|
async def list(
|