google-genai 0.7.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +26 -25
- google/genai/_automatic_function_calling_util.py +19 -20
- google/genai/_common.py +33 -2
- google/genai/_extra_utils.py +12 -6
- google/genai/_operations.py +365 -0
- google/genai/_replay_api_client.py +7 -0
- google/genai/_transformers.py +32 -14
- google/genai/errors.py +4 -0
- google/genai/files.py +79 -71
- google/genai/live.py +5 -0
- google/genai/models.py +344 -35
- google/genai/tunings.py +288 -61
- google/genai/types.py +191 -20
- google/genai/version.py +1 -1
- {google_genai-0.7.0.dist-info → google_genai-1.0.0.dist-info}/METADATA +90 -48
- google_genai-1.0.0.dist-info/RECORD +27 -0
- google_genai-0.7.0.dist-info/RECORD +0 -26
- {google_genai-0.7.0.dist-info → google_genai-1.0.0.dist-info}/LICENSE +0 -0
- {google_genai-0.7.0.dist-info → google_genai-1.0.0.dist-info}/WHEEL +0 -0
- {google_genai-0.7.0.dist-info → google_genai-1.0.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -99,6 +99,19 @@ class HttpRequest:
|
|
99
99
|
timeout: Optional[float] = None
|
100
100
|
|
101
101
|
|
102
|
+
# TODO(b/394358912): Update this class to use a SDKResponse class that can be
|
103
|
+
# generated and used for all languages.
|
104
|
+
@dataclass
|
105
|
+
class BaseResponse:
|
106
|
+
http_headers: dict[str, str]
|
107
|
+
|
108
|
+
@property
|
109
|
+
def dict(self) -> dict[str, Any]:
|
110
|
+
if isinstance(self, dict):
|
111
|
+
return self
|
112
|
+
return {'httpHeaders': self.http_headers}
|
113
|
+
|
114
|
+
|
102
115
|
class HttpResponse:
|
103
116
|
|
104
117
|
def __init__(
|
@@ -255,7 +268,7 @@ class ApiClient:
|
|
255
268
|
'Project and location or API key must be set when using the Vertex '
|
256
269
|
'AI API.'
|
257
270
|
)
|
258
|
-
if self.api_key:
|
271
|
+
if self.api_key or self.location == 'global':
|
259
272
|
self._http_options['base_url'] = (
|
260
273
|
f'https://aiplatform.googleapis.com/'
|
261
274
|
)
|
@@ -273,7 +286,7 @@ class ApiClient:
|
|
273
286
|
self._http_options['api_version'] = 'v1beta'
|
274
287
|
# Default options for both clients.
|
275
288
|
self._http_options['headers'] = {'Content-Type': 'application/json'}
|
276
|
-
if self.api_key
|
289
|
+
if self.api_key:
|
277
290
|
self._http_options['headers']['x-goog-api-key'] = self.api_key
|
278
291
|
# Update the http options with the user provided http options.
|
279
292
|
if http_options:
|
@@ -323,8 +336,6 @@ class ApiClient:
|
|
323
336
|
and not self.api_key
|
324
337
|
):
|
325
338
|
path = f'projects/{self.project}/locations/{self.location}/' + path
|
326
|
-
elif self.vertexai and self.api_key:
|
327
|
-
path = f'{path}?key={self.api_key}'
|
328
339
|
url = _join_url_path(
|
329
340
|
patched_http_options['base_url'],
|
330
341
|
patched_http_options['api_version'] + '/' + path,
|
@@ -436,18 +447,12 @@ class ApiClient:
|
|
436
447
|
http_method, path, request_dict, http_options
|
437
448
|
)
|
438
449
|
response = self._request(http_request, stream=False)
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
elif (
|
446
|
-
isinstance(http_options, dict)
|
447
|
-
and 'deprecated_response_payload' in http_options
|
448
|
-
):
|
449
|
-
response._copy_to_dict(http_options['deprecated_response_payload'])
|
450
|
-
return response.json
|
450
|
+
json_response = response.json
|
451
|
+
if not json_response:
|
452
|
+
base_response = BaseResponse(response.headers).dict
|
453
|
+
return base_response
|
454
|
+
|
455
|
+
return json_response
|
451
456
|
|
452
457
|
def request_streamed(
|
453
458
|
self,
|
@@ -461,10 +466,6 @@ class ApiClient:
|
|
461
466
|
)
|
462
467
|
|
463
468
|
session_response = self._request(http_request, stream=True)
|
464
|
-
if http_options and 'deprecated_response_payload' in http_options:
|
465
|
-
session_response._copy_to_dict(
|
466
|
-
http_options['deprecated_response_payload']
|
467
|
-
)
|
468
469
|
for chunk in session_response.segments():
|
469
470
|
yield chunk
|
470
471
|
|
@@ -480,9 +481,11 @@ class ApiClient:
|
|
480
481
|
)
|
481
482
|
|
482
483
|
result = await self._async_request(http_request=http_request, stream=False)
|
483
|
-
|
484
|
-
|
485
|
-
|
484
|
+
json_response = result.json
|
485
|
+
if not json_response:
|
486
|
+
base_response = BaseResponse(result.headers).dict
|
487
|
+
return base_response
|
488
|
+
return json_response
|
486
489
|
|
487
490
|
async def async_request_streamed(
|
488
491
|
self,
|
@@ -497,8 +500,6 @@ class ApiClient:
|
|
497
500
|
|
498
501
|
response = await self._async_request(http_request=http_request, stream=True)
|
499
502
|
|
500
|
-
if http_options and 'deprecated_response_payload' in http_options:
|
501
|
-
response._copy_to_dict(http_options['deprecated_response_payload'])
|
502
503
|
async def async_generator():
|
503
504
|
async for chunk in response:
|
504
505
|
yield chunk
|
@@ -17,14 +17,18 @@ import inspect
|
|
17
17
|
import sys
|
18
18
|
import types as builtin_types
|
19
19
|
import typing
|
20
|
-
from typing import Any, Callable,
|
20
|
+
from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Union
|
21
|
+
|
21
22
|
import pydantic
|
23
|
+
|
24
|
+
from . import _extra_utils
|
22
25
|
from . import types
|
23
26
|
|
27
|
+
|
24
28
|
if sys.version_info >= (3, 10):
|
25
|
-
|
29
|
+
VersionedUnionType = builtin_types.UnionType
|
26
30
|
else:
|
27
|
-
|
31
|
+
VersionedUnionType = typing._UnionGenericAlias
|
28
32
|
|
29
33
|
_py_builtin_type_to_schema_type = {
|
30
34
|
str: 'STRING',
|
@@ -45,7 +49,8 @@ def _is_builtin_primitive_or_compound(
|
|
45
49
|
def _raise_for_any_of_if_mldev(schema: types.Schema):
|
46
50
|
if schema.any_of:
|
47
51
|
raise ValueError(
|
48
|
-
'AnyOf is not supported in function declaration schema for
|
52
|
+
'AnyOf is not supported in function declaration schema for'
|
53
|
+
' the Gemini API.'
|
49
54
|
)
|
50
55
|
|
51
56
|
|
@@ -53,15 +58,7 @@ def _raise_for_default_if_mldev(schema: types.Schema):
|
|
53
58
|
if schema.default is not None:
|
54
59
|
raise ValueError(
|
55
60
|
'Default value is not supported in function declaration schema for'
|
56
|
-
'
|
57
|
-
)
|
58
|
-
|
59
|
-
|
60
|
-
def _raise_for_nullable_if_mldev(schema: types.Schema):
|
61
|
-
if schema.nullable:
|
62
|
-
raise ValueError(
|
63
|
-
'Nullable is not supported in function declaration schema for'
|
64
|
-
' Google AI.'
|
61
|
+
' the Gemini API.'
|
65
62
|
)
|
66
63
|
|
67
64
|
|
@@ -69,7 +66,6 @@ def _raise_if_schema_unsupported(client, schema: types.Schema):
|
|
69
66
|
if not client.vertexai:
|
70
67
|
_raise_for_any_of_if_mldev(schema)
|
71
68
|
_raise_for_default_if_mldev(schema)
|
72
|
-
_raise_for_nullable_if_mldev(schema)
|
73
69
|
|
74
70
|
|
75
71
|
def _is_default_value_compatible(
|
@@ -82,10 +78,10 @@ def _is_default_value_compatible(
|
|
82
78
|
if (
|
83
79
|
isinstance(annotation, _GenericAlias)
|
84
80
|
or isinstance(annotation, builtin_types.GenericAlias)
|
85
|
-
or isinstance(annotation,
|
81
|
+
or isinstance(annotation, VersionedUnionType)
|
86
82
|
):
|
87
83
|
origin = get_origin(annotation)
|
88
|
-
if origin in (Union,
|
84
|
+
if origin in (Union, VersionedUnionType):
|
89
85
|
return any(
|
90
86
|
_is_default_value_compatible(default_value, arg)
|
91
87
|
for arg in get_args(annotation)
|
@@ -141,7 +137,7 @@ def _parse_schema_from_parameter(
|
|
141
137
|
_raise_if_schema_unsupported(client, schema)
|
142
138
|
return schema
|
143
139
|
if (
|
144
|
-
isinstance(param.annotation,
|
140
|
+
isinstance(param.annotation, VersionedUnionType)
|
145
141
|
# only parse simple UnionType, example int | str | float | bool
|
146
142
|
# complex UnionType will be invoked in raise branch
|
147
143
|
and all(
|
@@ -229,7 +225,11 @@ def _parse_schema_from_parameter(
|
|
229
225
|
schema.type = 'OBJECT'
|
230
226
|
unique_types = set()
|
231
227
|
for arg in args:
|
232
|
-
|
228
|
+
# The first check is for NoneType in Python 3.9, since the __name__
|
229
|
+
# attribute is not available in Python 3.9
|
230
|
+
if type(arg) is type(None) or (
|
231
|
+
hasattr(arg, '__name__') and arg.__name__ == 'NoneType'
|
232
|
+
): # Optional type
|
233
233
|
schema.nullable = True
|
234
234
|
continue
|
235
235
|
schema_in_any_of = _parse_schema_from_parameter(
|
@@ -272,9 +272,8 @@ def _parse_schema_from_parameter(
|
|
272
272
|
return schema
|
273
273
|
# all other generic alias will be invoked in raise branch
|
274
274
|
if (
|
275
|
-
inspect.isclass(param.annotation)
|
276
275
|
# for user defined class, we only support pydantic model
|
277
|
-
|
276
|
+
_extra_utils.is_annotation_pydantic_model(param.annotation)
|
278
277
|
):
|
279
278
|
if (
|
280
279
|
param.default is not inspect.Parameter.empty
|
google/genai/_common.py
CHANGED
@@ -18,14 +18,17 @@
|
|
18
18
|
import base64
|
19
19
|
import datetime
|
20
20
|
import enum
|
21
|
+
import functools
|
21
22
|
import typing
|
22
23
|
from typing import Union
|
23
24
|
import uuid
|
25
|
+
import warnings
|
24
26
|
|
25
27
|
import pydantic
|
26
28
|
from pydantic import alias_generators
|
27
29
|
|
28
30
|
from . import _api_client
|
31
|
+
from . import errors
|
29
32
|
|
30
33
|
|
31
34
|
def set_value_by_path(data, keys, value):
|
@@ -213,8 +216,16 @@ class CaseInSensitiveEnum(str, enum.Enum):
|
|
213
216
|
except KeyError:
|
214
217
|
try:
|
215
218
|
return cls[value.lower()] # Try to access directly with lowercase
|
216
|
-
except KeyError
|
217
|
-
|
219
|
+
except KeyError:
|
220
|
+
warnings.warn(f"{value} is not a valid {cls.__name__}")
|
221
|
+
try:
|
222
|
+
# Creating a enum instance based on the value
|
223
|
+
unknown_enum_val = cls._new_member_(cls) # pylint: disable=protected-access,attribute-error
|
224
|
+
unknown_enum_val._name_ = str(value) # pylint: disable=protected-access
|
225
|
+
unknown_enum_val._value_ = value # pylint: disable=protected-access
|
226
|
+
return unknown_enum_val
|
227
|
+
except:
|
228
|
+
return None
|
218
229
|
|
219
230
|
|
220
231
|
def timestamped_unique_name() -> str:
|
@@ -264,3 +275,23 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
|
|
264
275
|
else:
|
265
276
|
processed_data[key] = value
|
266
277
|
return processed_data
|
278
|
+
|
279
|
+
|
280
|
+
def experimental_warning(message: str):
|
281
|
+
"""Experimental warning, only warns once."""
|
282
|
+
def decorator(func):
|
283
|
+
warning_done = False
|
284
|
+
@functools.wraps(func)
|
285
|
+
def wrapper(*args, **kwargs):
|
286
|
+
nonlocal warning_done
|
287
|
+
if not warning_done:
|
288
|
+
warning_done = True
|
289
|
+
warnings.warn(
|
290
|
+
message=message,
|
291
|
+
category=errors.ExperimentalWarning,
|
292
|
+
stacklevel=2,
|
293
|
+
)
|
294
|
+
return func(*args, **kwargs)
|
295
|
+
return wrapper
|
296
|
+
return decorator
|
297
|
+
|
google/genai/_extra_utils.py
CHANGED
@@ -108,16 +108,22 @@ def convert_number_values_for_function_call_args(
|
|
108
108
|
return args
|
109
109
|
|
110
110
|
|
111
|
-
def
|
112
|
-
|
113
|
-
|
114
|
-
|
111
|
+
def is_annotation_pydantic_model(annotation: Any) -> bool:
|
112
|
+
try:
|
113
|
+
return inspect.isclass(annotation) and issubclass(
|
114
|
+
annotation, pydantic.BaseModel
|
115
|
+
)
|
116
|
+
# for python 3.10 and below, inspect.isclass(annotation) has inconsistent
|
117
|
+
# results with versions above. for example, inspect.isclass(dict[str, int]) is
|
118
|
+
# True in 3.10 and below but False in 3.11 and above.
|
119
|
+
except TypeError:
|
120
|
+
return False
|
115
121
|
|
116
122
|
|
117
123
|
def convert_if_exist_pydantic_model(
|
118
124
|
value: Any, annotation: Any, param_name: str, func_name: str
|
119
125
|
) -> Any:
|
120
|
-
if isinstance(value, dict) and
|
126
|
+
if isinstance(value, dict) and is_annotation_pydantic_model(annotation):
|
121
127
|
try:
|
122
128
|
return annotation(**value)
|
123
129
|
except pydantic.ValidationError as e:
|
@@ -146,7 +152,7 @@ def convert_if_exist_pydantic_model(
|
|
146
152
|
if (
|
147
153
|
(get_args(arg) and get_origin(arg) is list)
|
148
154
|
or isinstance(value, arg)
|
149
|
-
or (isinstance(value, dict) and
|
155
|
+
or (isinstance(value, dict) and is_annotation_pydantic_model(arg))
|
150
156
|
):
|
151
157
|
try:
|
152
158
|
return convert_if_exist_pydantic_model(
|
@@ -0,0 +1,365 @@
|
|
1
|
+
# Copyright 2024 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
|
16
|
+
# Code generated by the Google Gen AI SDK generator DO NOT EDIT.
|
17
|
+
|
18
|
+
from typing import Optional, Union
|
19
|
+
from urllib.parse import urlencode
|
20
|
+
from . import _api_module
|
21
|
+
from . import _common
|
22
|
+
from . import types
|
23
|
+
from ._api_client import ApiClient
|
24
|
+
from ._common import get_value_by_path as getv
|
25
|
+
from ._common import set_value_by_path as setv
|
26
|
+
|
27
|
+
|
28
|
+
def _GetOperationParameters_to_mldev(
|
29
|
+
api_client: ApiClient,
|
30
|
+
from_object: Union[dict, object],
|
31
|
+
parent_object: dict = None,
|
32
|
+
) -> dict:
|
33
|
+
to_object = {}
|
34
|
+
if getv(from_object, ['operation_name']) is not None:
|
35
|
+
setv(
|
36
|
+
to_object,
|
37
|
+
['_url', 'operationName'],
|
38
|
+
getv(from_object, ['operation_name']),
|
39
|
+
)
|
40
|
+
|
41
|
+
if getv(from_object, ['config']) is not None:
|
42
|
+
setv(to_object, ['config'], getv(from_object, ['config']))
|
43
|
+
|
44
|
+
return to_object
|
45
|
+
|
46
|
+
|
47
|
+
def _GetOperationParameters_to_vertex(
|
48
|
+
api_client: ApiClient,
|
49
|
+
from_object: Union[dict, object],
|
50
|
+
parent_object: dict = None,
|
51
|
+
) -> dict:
|
52
|
+
to_object = {}
|
53
|
+
if getv(from_object, ['operation_name']) is not None:
|
54
|
+
setv(
|
55
|
+
to_object,
|
56
|
+
['_url', 'operationName'],
|
57
|
+
getv(from_object, ['operation_name']),
|
58
|
+
)
|
59
|
+
|
60
|
+
if getv(from_object, ['config']) is not None:
|
61
|
+
setv(to_object, ['config'], getv(from_object, ['config']))
|
62
|
+
|
63
|
+
return to_object
|
64
|
+
|
65
|
+
|
66
|
+
def _FetchPredictOperationParameters_to_mldev(
|
67
|
+
api_client: ApiClient,
|
68
|
+
from_object: Union[dict, object],
|
69
|
+
parent_object: dict = None,
|
70
|
+
) -> dict:
|
71
|
+
to_object = {}
|
72
|
+
if getv(from_object, ['operation_name']) is not None:
|
73
|
+
raise ValueError('operation_name parameter is not supported in Gemini API.')
|
74
|
+
|
75
|
+
if getv(from_object, ['resource_name']) is not None:
|
76
|
+
raise ValueError('resource_name parameter is not supported in Gemini API.')
|
77
|
+
|
78
|
+
if getv(from_object, ['config']) is not None:
|
79
|
+
raise ValueError('config parameter is not supported in Gemini API.')
|
80
|
+
|
81
|
+
return to_object
|
82
|
+
|
83
|
+
|
84
|
+
def _FetchPredictOperationParameters_to_vertex(
|
85
|
+
api_client: ApiClient,
|
86
|
+
from_object: Union[dict, object],
|
87
|
+
parent_object: dict = None,
|
88
|
+
) -> dict:
|
89
|
+
to_object = {}
|
90
|
+
if getv(from_object, ['operation_name']) is not None:
|
91
|
+
setv(to_object, ['operationName'], getv(from_object, ['operation_name']))
|
92
|
+
|
93
|
+
if getv(from_object, ['resource_name']) is not None:
|
94
|
+
setv(
|
95
|
+
to_object,
|
96
|
+
['_url', 'resourceName'],
|
97
|
+
getv(from_object, ['resource_name']),
|
98
|
+
)
|
99
|
+
|
100
|
+
if getv(from_object, ['config']) is not None:
|
101
|
+
setv(to_object, ['config'], getv(from_object, ['config']))
|
102
|
+
|
103
|
+
return to_object
|
104
|
+
|
105
|
+
|
106
|
+
def _Operation_from_mldev(
|
107
|
+
api_client: ApiClient,
|
108
|
+
from_object: Union[dict, object],
|
109
|
+
parent_object: dict = None,
|
110
|
+
) -> dict:
|
111
|
+
to_object = {}
|
112
|
+
if getv(from_object, ['name']) is not None:
|
113
|
+
setv(to_object, ['name'], getv(from_object, ['name']))
|
114
|
+
|
115
|
+
if getv(from_object, ['metadata']) is not None:
|
116
|
+
setv(to_object, ['metadata'], getv(from_object, ['metadata']))
|
117
|
+
|
118
|
+
if getv(from_object, ['done']) is not None:
|
119
|
+
setv(to_object, ['done'], getv(from_object, ['done']))
|
120
|
+
|
121
|
+
if getv(from_object, ['error']) is not None:
|
122
|
+
setv(to_object, ['error'], getv(from_object, ['error']))
|
123
|
+
|
124
|
+
if getv(from_object, ['response']) is not None:
|
125
|
+
setv(to_object, ['response'], getv(from_object, ['response']))
|
126
|
+
|
127
|
+
return to_object
|
128
|
+
|
129
|
+
|
130
|
+
def _Operation_from_vertex(
|
131
|
+
api_client: ApiClient,
|
132
|
+
from_object: Union[dict, object],
|
133
|
+
parent_object: dict = None,
|
134
|
+
) -> dict:
|
135
|
+
to_object = {}
|
136
|
+
if getv(from_object, ['name']) is not None:
|
137
|
+
setv(to_object, ['name'], getv(from_object, ['name']))
|
138
|
+
|
139
|
+
if getv(from_object, ['metadata']) is not None:
|
140
|
+
setv(to_object, ['metadata'], getv(from_object, ['metadata']))
|
141
|
+
|
142
|
+
if getv(from_object, ['done']) is not None:
|
143
|
+
setv(to_object, ['done'], getv(from_object, ['done']))
|
144
|
+
|
145
|
+
if getv(from_object, ['error']) is not None:
|
146
|
+
setv(to_object, ['error'], getv(from_object, ['error']))
|
147
|
+
|
148
|
+
if getv(from_object, ['response']) is not None:
|
149
|
+
setv(to_object, ['response'], getv(from_object, ['response']))
|
150
|
+
|
151
|
+
return to_object
|
152
|
+
|
153
|
+
|
154
|
+
class _operations(_api_module.BaseModule):
|
155
|
+
|
156
|
+
def _get_operation(
|
157
|
+
self,
|
158
|
+
*,
|
159
|
+
operation_name: str,
|
160
|
+
config: Optional[types.GetOperationConfigOrDict] = None,
|
161
|
+
) -> types.Operation:
|
162
|
+
parameter_model = types._GetOperationParameters(
|
163
|
+
operation_name=operation_name,
|
164
|
+
config=config,
|
165
|
+
)
|
166
|
+
|
167
|
+
if self._api_client.vertexai:
|
168
|
+
request_dict = _GetOperationParameters_to_vertex(
|
169
|
+
self._api_client, parameter_model
|
170
|
+
)
|
171
|
+
path = '{operationName}'.format_map(request_dict.get('_url'))
|
172
|
+
else:
|
173
|
+
request_dict = _GetOperationParameters_to_mldev(
|
174
|
+
self._api_client, parameter_model
|
175
|
+
)
|
176
|
+
path = '{operationName}'.format_map(request_dict.get('_url'))
|
177
|
+
query_params = request_dict.get('_query')
|
178
|
+
if query_params:
|
179
|
+
path = f'{path}?{urlencode(query_params)}'
|
180
|
+
# TODO: remove the hack that pops config.
|
181
|
+
request_dict.pop('config', None)
|
182
|
+
|
183
|
+
http_options = None
|
184
|
+
if isinstance(config, dict):
|
185
|
+
http_options = config.get('http_options', None)
|
186
|
+
elif hasattr(config, 'http_options'):
|
187
|
+
http_options = config.http_options
|
188
|
+
|
189
|
+
request_dict = _common.convert_to_dict(request_dict)
|
190
|
+
request_dict = _common.encode_unserializable_types(request_dict)
|
191
|
+
|
192
|
+
response_dict = self._api_client.request(
|
193
|
+
'get', path, request_dict, http_options
|
194
|
+
)
|
195
|
+
|
196
|
+
if self._api_client.vertexai:
|
197
|
+
response_dict = _Operation_from_vertex(self._api_client, response_dict)
|
198
|
+
else:
|
199
|
+
response_dict = _Operation_from_mldev(self._api_client, response_dict)
|
200
|
+
|
201
|
+
return_value = types.Operation._from_response(
|
202
|
+
response=response_dict, kwargs=parameter_model
|
203
|
+
)
|
204
|
+
self._api_client._verify_response(return_value)
|
205
|
+
return return_value
|
206
|
+
|
207
|
+
def _fetch_predict_operation(
|
208
|
+
self,
|
209
|
+
*,
|
210
|
+
operation_name: str,
|
211
|
+
resource_name: str,
|
212
|
+
config: Optional[types.FetchPredictOperationConfigOrDict] = None,
|
213
|
+
) -> types.Operation:
|
214
|
+
parameter_model = types._FetchPredictOperationParameters(
|
215
|
+
operation_name=operation_name,
|
216
|
+
resource_name=resource_name,
|
217
|
+
config=config,
|
218
|
+
)
|
219
|
+
|
220
|
+
if not self._api_client.vertexai:
|
221
|
+
raise ValueError('This method is only supported in the Vertex AI client.')
|
222
|
+
else:
|
223
|
+
request_dict = _FetchPredictOperationParameters_to_vertex(
|
224
|
+
self._api_client, parameter_model
|
225
|
+
)
|
226
|
+
path = '{resourceName}:fetchPredictOperation'.format_map(
|
227
|
+
request_dict.get('_url')
|
228
|
+
)
|
229
|
+
|
230
|
+
query_params = request_dict.get('_query')
|
231
|
+
if query_params:
|
232
|
+
path = f'{path}?{urlencode(query_params)}'
|
233
|
+
# TODO: remove the hack that pops config.
|
234
|
+
request_dict.pop('config', None)
|
235
|
+
|
236
|
+
http_options = None
|
237
|
+
if isinstance(config, dict):
|
238
|
+
http_options = config.get('http_options', None)
|
239
|
+
elif hasattr(config, 'http_options'):
|
240
|
+
http_options = config.http_options
|
241
|
+
|
242
|
+
request_dict = _common.convert_to_dict(request_dict)
|
243
|
+
request_dict = _common.encode_unserializable_types(request_dict)
|
244
|
+
|
245
|
+
response_dict = self._api_client.request(
|
246
|
+
'post', path, request_dict, http_options
|
247
|
+
)
|
248
|
+
|
249
|
+
if self._api_client.vertexai:
|
250
|
+
response_dict = _Operation_from_vertex(self._api_client, response_dict)
|
251
|
+
else:
|
252
|
+
response_dict = _Operation_from_mldev(self._api_client, response_dict)
|
253
|
+
|
254
|
+
return_value = types.Operation._from_response(
|
255
|
+
response=response_dict, kwargs=parameter_model
|
256
|
+
)
|
257
|
+
self._api_client._verify_response(return_value)
|
258
|
+
return return_value
|
259
|
+
|
260
|
+
|
261
|
+
class Async_operations(_api_module.BaseModule):
|
262
|
+
|
263
|
+
async def _get_operation(
|
264
|
+
self,
|
265
|
+
*,
|
266
|
+
operation_name: str,
|
267
|
+
config: Optional[types.GetOperationConfigOrDict] = None,
|
268
|
+
) -> types.Operation:
|
269
|
+
parameter_model = types._GetOperationParameters(
|
270
|
+
operation_name=operation_name,
|
271
|
+
config=config,
|
272
|
+
)
|
273
|
+
|
274
|
+
if self._api_client.vertexai:
|
275
|
+
request_dict = _GetOperationParameters_to_vertex(
|
276
|
+
self._api_client, parameter_model
|
277
|
+
)
|
278
|
+
path = '{operationName}'.format_map(request_dict.get('_url'))
|
279
|
+
else:
|
280
|
+
request_dict = _GetOperationParameters_to_mldev(
|
281
|
+
self._api_client, parameter_model
|
282
|
+
)
|
283
|
+
path = '{operationName}'.format_map(request_dict.get('_url'))
|
284
|
+
query_params = request_dict.get('_query')
|
285
|
+
if query_params:
|
286
|
+
path = f'{path}?{urlencode(query_params)}'
|
287
|
+
# TODO: remove the hack that pops config.
|
288
|
+
request_dict.pop('config', None)
|
289
|
+
|
290
|
+
http_options = None
|
291
|
+
if isinstance(config, dict):
|
292
|
+
http_options = config.get('http_options', None)
|
293
|
+
elif hasattr(config, 'http_options'):
|
294
|
+
http_options = config.http_options
|
295
|
+
|
296
|
+
request_dict = _common.convert_to_dict(request_dict)
|
297
|
+
request_dict = _common.encode_unserializable_types(request_dict)
|
298
|
+
|
299
|
+
response_dict = await self._api_client.async_request(
|
300
|
+
'get', path, request_dict, http_options
|
301
|
+
)
|
302
|
+
|
303
|
+
if self._api_client.vertexai:
|
304
|
+
response_dict = _Operation_from_vertex(self._api_client, response_dict)
|
305
|
+
else:
|
306
|
+
response_dict = _Operation_from_mldev(self._api_client, response_dict)
|
307
|
+
|
308
|
+
return_value = types.Operation._from_response(
|
309
|
+
response=response_dict, kwargs=parameter_model
|
310
|
+
)
|
311
|
+
self._api_client._verify_response(return_value)
|
312
|
+
return return_value
|
313
|
+
|
314
|
+
async def _fetch_predict_operation(
|
315
|
+
self,
|
316
|
+
*,
|
317
|
+
operation_name: str,
|
318
|
+
resource_name: str,
|
319
|
+
config: Optional[types.FetchPredictOperationConfigOrDict] = None,
|
320
|
+
) -> types.Operation:
|
321
|
+
parameter_model = types._FetchPredictOperationParameters(
|
322
|
+
operation_name=operation_name,
|
323
|
+
resource_name=resource_name,
|
324
|
+
config=config,
|
325
|
+
)
|
326
|
+
|
327
|
+
if not self._api_client.vertexai:
|
328
|
+
raise ValueError('This method is only supported in the Vertex AI client.')
|
329
|
+
else:
|
330
|
+
request_dict = _FetchPredictOperationParameters_to_vertex(
|
331
|
+
self._api_client, parameter_model
|
332
|
+
)
|
333
|
+
path = '{resourceName}:fetchPredictOperation'.format_map(
|
334
|
+
request_dict.get('_url')
|
335
|
+
)
|
336
|
+
|
337
|
+
query_params = request_dict.get('_query')
|
338
|
+
if query_params:
|
339
|
+
path = f'{path}?{urlencode(query_params)}'
|
340
|
+
# TODO: remove the hack that pops config.
|
341
|
+
request_dict.pop('config', None)
|
342
|
+
|
343
|
+
http_options = None
|
344
|
+
if isinstance(config, dict):
|
345
|
+
http_options = config.get('http_options', None)
|
346
|
+
elif hasattr(config, 'http_options'):
|
347
|
+
http_options = config.http_options
|
348
|
+
|
349
|
+
request_dict = _common.convert_to_dict(request_dict)
|
350
|
+
request_dict = _common.encode_unserializable_types(request_dict)
|
351
|
+
|
352
|
+
response_dict = await self._api_client.async_request(
|
353
|
+
'post', path, request_dict, http_options
|
354
|
+
)
|
355
|
+
|
356
|
+
if self._api_client.vertexai:
|
357
|
+
response_dict = _Operation_from_vertex(self._api_client, response_dict)
|
358
|
+
else:
|
359
|
+
response_dict = _Operation_from_mldev(self._api_client, response_dict)
|
360
|
+
|
361
|
+
return_value = types.Operation._from_response(
|
362
|
+
response=response_dict, kwargs=parameter_model
|
363
|
+
)
|
364
|
+
self._api_client._verify_response(return_value)
|
365
|
+
return return_value
|
@@ -78,6 +78,11 @@ def _redact_request_url(url: str) -> str:
|
|
78
78
|
'{VERTEX_URL_PREFIX}/',
|
79
79
|
result,
|
80
80
|
)
|
81
|
+
result = re.sub(
|
82
|
+
r'.*aiplatform.googleapis.com/[^/]+/',
|
83
|
+
'{VERTEX_URL_PREFIX}/',
|
84
|
+
result,
|
85
|
+
)
|
81
86
|
result = re.sub(
|
82
87
|
r'https://generativelanguage.googleapis.com/[^/]+',
|
83
88
|
'{MLDEV_URL_PREFIX}',
|
@@ -357,6 +362,8 @@ class ReplayApiClient(ApiClient):
|
|
357
362
|
if self._should_update_replay():
|
358
363
|
if isinstance(response_model, list):
|
359
364
|
response_model = response_model[0]
|
365
|
+
if response_model and 'http_headers' in response_model.model_fields:
|
366
|
+
response_model.http_headers.pop('Date', None)
|
360
367
|
interaction.response.sdk_response_segments.append(
|
361
368
|
response_model.model_dump(exclude_none=True)
|
362
369
|
)
|