google-genai 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +234 -131
- google/genai/_api_module.py +24 -0
- google/genai/_automatic_function_calling_util.py +43 -22
- google/genai/_common.py +37 -12
- google/genai/_extra_utils.py +25 -19
- google/genai/_replay_api_client.py +47 -35
- google/genai/_test_api_client.py +1 -1
- google/genai/_transformers.py +301 -51
- google/genai/batches.py +204 -165
- google/genai/caches.py +127 -144
- google/genai/chats.py +22 -18
- google/genai/client.py +32 -37
- google/genai/errors.py +1 -1
- google/genai/files.py +333 -165
- google/genai/live.py +16 -6
- google/genai/models.py +601 -283
- google/genai/tunings.py +91 -428
- google/genai/types.py +1190 -955
- google/genai/version.py +1 -1
- google_genai-0.7.0.dist-info/METADATA +1021 -0
- google_genai-0.7.0.dist-info/RECORD +26 -0
- google_genai-0.5.0.dist-info/METADATA +0 -888
- google_genai-0.5.0.dist-info/RECORD +0 -25
- {google_genai-0.5.0.dist-info → google_genai-0.7.0.dist-info}/LICENSE +0 -0
- {google_genai-0.5.0.dist-info → google_genai-0.7.0.dist-info}/WHEEL +0 -0
- {google_genai-0.5.0.dist-info → google_genai-0.7.0.dist-info}/top_level.txt +0 -0
@@ -14,11 +14,18 @@
|
|
14
14
|
#
|
15
15
|
|
16
16
|
import inspect
|
17
|
-
import
|
17
|
+
import sys
|
18
|
+
import types as builtin_types
|
19
|
+
import typing
|
18
20
|
from typing import Any, Callable, Literal, Union, _GenericAlias, get_args, get_origin
|
19
21
|
import pydantic
|
20
22
|
from . import types
|
21
23
|
|
24
|
+
if sys.version_info >= (3, 10):
|
25
|
+
UnionType = builtin_types.UnionType
|
26
|
+
else:
|
27
|
+
UnionType = typing._UnionGenericAlias
|
28
|
+
|
22
29
|
_py_builtin_type_to_schema_type = {
|
23
30
|
str: 'STRING',
|
24
31
|
int: 'INTEGER',
|
@@ -58,8 +65,8 @@ def _raise_for_nullable_if_mldev(schema: types.Schema):
|
|
58
65
|
)
|
59
66
|
|
60
67
|
|
61
|
-
def _raise_if_schema_unsupported(
|
62
|
-
if not
|
68
|
+
def _raise_if_schema_unsupported(client, schema: types.Schema):
|
69
|
+
if not client.vertexai:
|
63
70
|
_raise_for_any_of_if_mldev(schema)
|
64
71
|
_raise_for_default_if_mldev(schema)
|
65
72
|
_raise_for_nullable_if_mldev(schema)
|
@@ -74,11 +81,11 @@ def _is_default_value_compatible(
|
|
74
81
|
|
75
82
|
if (
|
76
83
|
isinstance(annotation, _GenericAlias)
|
77
|
-
or isinstance(annotation,
|
78
|
-
or isinstance(annotation,
|
84
|
+
or isinstance(annotation, builtin_types.GenericAlias)
|
85
|
+
or isinstance(annotation, UnionType)
|
79
86
|
):
|
80
87
|
origin = get_origin(annotation)
|
81
|
-
if origin in (Union,
|
88
|
+
if origin in (Union, UnionType):
|
82
89
|
return any(
|
83
90
|
_is_default_value_compatible(default_value, arg)
|
84
91
|
for arg in get_args(annotation)
|
@@ -107,12 +114,13 @@ def _is_default_value_compatible(
|
|
107
114
|
return default_value in get_args(annotation)
|
108
115
|
|
109
116
|
# return False for any other unrecognized annotation
|
110
|
-
# let caller handle the raise
|
111
117
|
return False
|
112
118
|
|
113
119
|
|
114
120
|
def _parse_schema_from_parameter(
|
115
|
-
|
121
|
+
client,
|
122
|
+
param: inspect.Parameter,
|
123
|
+
func_name: str,
|
116
124
|
) -> types.Schema:
|
117
125
|
"""parse schema from parameter.
|
118
126
|
|
@@ -130,12 +138,12 @@ def _parse_schema_from_parameter(
|
|
130
138
|
raise ValueError(default_value_error_msg)
|
131
139
|
schema.default = param.default
|
132
140
|
schema.type = _py_builtin_type_to_schema_type[param.annotation]
|
133
|
-
_raise_if_schema_unsupported(
|
141
|
+
_raise_if_schema_unsupported(client, schema)
|
134
142
|
return schema
|
135
143
|
if (
|
136
|
-
isinstance(param.annotation,
|
144
|
+
isinstance(param.annotation, UnionType)
|
137
145
|
# only parse simple UnionType, example int | str | float | bool
|
138
|
-
# complex
|
146
|
+
# complex UnionType will be invoked in raise branch
|
139
147
|
and all(
|
140
148
|
(_is_builtin_primitive_or_compound(arg) or arg is type(None))
|
141
149
|
for arg in get_args(param.annotation)
|
@@ -149,7 +157,7 @@ def _parse_schema_from_parameter(
|
|
149
157
|
schema.nullable = True
|
150
158
|
continue
|
151
159
|
schema_in_any_of = _parse_schema_from_parameter(
|
152
|
-
|
160
|
+
client,
|
153
161
|
inspect.Parameter(
|
154
162
|
'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
|
155
163
|
),
|
@@ -171,10 +179,10 @@ def _parse_schema_from_parameter(
|
|
171
179
|
if not _is_default_value_compatible(param.default, param.annotation):
|
172
180
|
raise ValueError(default_value_error_msg)
|
173
181
|
schema.default = param.default
|
174
|
-
_raise_if_schema_unsupported(
|
182
|
+
_raise_if_schema_unsupported(client, schema)
|
175
183
|
return schema
|
176
184
|
if isinstance(param.annotation, _GenericAlias) or isinstance(
|
177
|
-
param.annotation,
|
185
|
+
param.annotation, builtin_types.GenericAlias
|
178
186
|
):
|
179
187
|
origin = get_origin(param.annotation)
|
180
188
|
args = get_args(param.annotation)
|
@@ -184,7 +192,7 @@ def _parse_schema_from_parameter(
|
|
184
192
|
if not _is_default_value_compatible(param.default, param.annotation):
|
185
193
|
raise ValueError(default_value_error_msg)
|
186
194
|
schema.default = param.default
|
187
|
-
_raise_if_schema_unsupported(
|
195
|
+
_raise_if_schema_unsupported(client, schema)
|
188
196
|
return schema
|
189
197
|
if origin is Literal:
|
190
198
|
if not all(isinstance(arg, str) for arg in args):
|
@@ -197,12 +205,12 @@ def _parse_schema_from_parameter(
|
|
197
205
|
if not _is_default_value_compatible(param.default, param.annotation):
|
198
206
|
raise ValueError(default_value_error_msg)
|
199
207
|
schema.default = param.default
|
200
|
-
_raise_if_schema_unsupported(
|
208
|
+
_raise_if_schema_unsupported(client, schema)
|
201
209
|
return schema
|
202
210
|
if origin is list:
|
203
211
|
schema.type = 'ARRAY'
|
204
212
|
schema.items = _parse_schema_from_parameter(
|
205
|
-
|
213
|
+
client,
|
206
214
|
inspect.Parameter(
|
207
215
|
'item',
|
208
216
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -214,7 +222,7 @@ def _parse_schema_from_parameter(
|
|
214
222
|
if not _is_default_value_compatible(param.default, param.annotation):
|
215
223
|
raise ValueError(default_value_error_msg)
|
216
224
|
schema.default = param.default
|
217
|
-
_raise_if_schema_unsupported(
|
225
|
+
_raise_if_schema_unsupported(client, schema)
|
218
226
|
return schema
|
219
227
|
if origin is Union:
|
220
228
|
schema.any_of = []
|
@@ -225,7 +233,7 @@ def _parse_schema_from_parameter(
|
|
225
233
|
schema.nullable = True
|
226
234
|
continue
|
227
235
|
schema_in_any_of = _parse_schema_from_parameter(
|
228
|
-
|
236
|
+
client,
|
229
237
|
inspect.Parameter(
|
230
238
|
'item',
|
231
239
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -233,6 +241,17 @@ def _parse_schema_from_parameter(
|
|
233
241
|
),
|
234
242
|
func_name,
|
235
243
|
)
|
244
|
+
if (
|
245
|
+
len(param.annotation.__args__) == 2
|
246
|
+
and type(None) in param.annotation.__args__
|
247
|
+
): # Optional type
|
248
|
+
for optional_arg in param.annotation.__args__:
|
249
|
+
if (
|
250
|
+
hasattr(optional_arg, '__origin__')
|
251
|
+
and optional_arg.__origin__ is list
|
252
|
+
):
|
253
|
+
# Optional type with list, for example Optional[list[str]]
|
254
|
+
schema.items = schema_in_any_of.items
|
236
255
|
if (
|
237
256
|
schema_in_any_of.model_dump_json(exclude_none=True)
|
238
257
|
not in unique_types
|
@@ -249,7 +268,7 @@ def _parse_schema_from_parameter(
|
|
249
268
|
if not _is_default_value_compatible(param.default, param.annotation):
|
250
269
|
raise ValueError(default_value_error_msg)
|
251
270
|
schema.default = param.default
|
252
|
-
_raise_if_schema_unsupported(
|
271
|
+
_raise_if_schema_unsupported(client, schema)
|
253
272
|
return schema
|
254
273
|
# all other generic alias will be invoked in raise branch
|
255
274
|
if (
|
@@ -266,7 +285,7 @@ def _parse_schema_from_parameter(
|
|
266
285
|
schema.properties = {}
|
267
286
|
for field_name, field_info in param.annotation.model_fields.items():
|
268
287
|
schema.properties[field_name] = _parse_schema_from_parameter(
|
269
|
-
|
288
|
+
client,
|
270
289
|
inspect.Parameter(
|
271
290
|
field_name,
|
272
291
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -274,7 +293,9 @@ def _parse_schema_from_parameter(
|
|
274
293
|
),
|
275
294
|
func_name,
|
276
295
|
)
|
277
|
-
|
296
|
+
if client.vertexai:
|
297
|
+
schema.required = _get_required_fields(schema)
|
298
|
+
_raise_if_schema_unsupported(client, schema)
|
278
299
|
return schema
|
279
300
|
raise ValueError(
|
280
301
|
f'Failed to parse the parameter {param} of function {func_name} for'
|
google/genai/_common.py
CHANGED
@@ -17,6 +17,7 @@
|
|
17
17
|
|
18
18
|
import base64
|
19
19
|
import datetime
|
20
|
+
import enum
|
20
21
|
import typing
|
21
22
|
from typing import Union
|
22
23
|
import uuid
|
@@ -112,12 +113,6 @@ def get_value_by_path(data: object, keys: list[str]):
|
|
112
113
|
return data
|
113
114
|
|
114
115
|
|
115
|
-
class BaseModule:
|
116
|
-
|
117
|
-
def __init__(self, api_client_: _api_client.ApiClient):
|
118
|
-
self._api_client = api_client_
|
119
|
-
|
120
|
-
|
121
116
|
def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
|
122
117
|
"""Recursively converts a given object to a dictionary.
|
123
118
|
|
@@ -144,7 +139,7 @@ def _remove_extra_fields(
|
|
144
139
|
) -> None:
|
145
140
|
"""Removes extra fields from the response that are not in the model.
|
146
141
|
|
147
|
-
|
142
|
+
Mutates the response in place.
|
148
143
|
"""
|
149
144
|
|
150
145
|
key_values = list(response.items())
|
@@ -185,7 +180,7 @@ class BaseModel(pydantic.BaseModel):
|
|
185
180
|
alias_generator=alias_generators.to_camel,
|
186
181
|
populate_by_name=True,
|
187
182
|
from_attributes=True,
|
188
|
-
protected_namespaces=
|
183
|
+
protected_namespaces=(),
|
189
184
|
extra='forbid',
|
190
185
|
# This allows us to use arbitrary types in the model. E.g. PIL.Image.
|
191
186
|
arbitrary_types_allowed=True,
|
@@ -208,6 +203,20 @@ class BaseModel(pydantic.BaseModel):
|
|
208
203
|
return self.model_dump(exclude_none=True, mode='json')
|
209
204
|
|
210
205
|
|
206
|
+
class CaseInSensitiveEnum(str, enum.Enum):
|
207
|
+
"""Case insensitive enum."""
|
208
|
+
|
209
|
+
@classmethod
|
210
|
+
def _missing_(cls, value):
|
211
|
+
try:
|
212
|
+
return cls[value.upper()] # Try to access directly with uppercase
|
213
|
+
except KeyError:
|
214
|
+
try:
|
215
|
+
return cls[value.lower()] # Try to access directly with lowercase
|
216
|
+
except KeyError as e:
|
217
|
+
raise ValueError(f"{value} is not a valid {cls.__name__}") from e
|
218
|
+
|
219
|
+
|
211
220
|
def timestamped_unique_name() -> str:
|
212
221
|
"""Composes a timestamped unique name.
|
213
222
|
|
@@ -219,23 +228,39 @@ def timestamped_unique_name() -> str:
|
|
219
228
|
return f'{timestamp}_{unique_id}'
|
220
229
|
|
221
230
|
|
222
|
-
def
|
223
|
-
"""
|
231
|
+
def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
|
232
|
+
"""Converts unserializable types in dict to json.dumps() compatible types.
|
233
|
+
|
234
|
+
This function is called in models.py after calling convert_to_dict(). The
|
235
|
+
convert_to_dict() can convert pydantic object to dict. However, the input to
|
236
|
+
convert_to_dict() is dict mixed of pydantic object and nested dict(the output
|
237
|
+
of converters). So they may be bytes in the dict and they are out of
|
238
|
+
`ser_json_bytes` control in model_dump(mode='json') called in
|
239
|
+
`convert_to_dict`, as well as datetime deserialization in Pydantic json mode.
|
240
|
+
|
241
|
+
Returns:
|
242
|
+
A dictionary with json.dumps() incompatible type (e.g. bytes datetime)
|
243
|
+
to compatible type (e.g. base64 encoded string, isoformat date string).
|
244
|
+
"""
|
224
245
|
processed_data = {}
|
225
246
|
if not isinstance(data, dict):
|
226
247
|
return data
|
227
248
|
for key, value in data.items():
|
228
249
|
if isinstance(value, bytes):
|
229
250
|
processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
|
251
|
+
elif isinstance(value, datetime.datetime):
|
252
|
+
processed_data[key] = value.isoformat()
|
230
253
|
elif isinstance(value, dict):
|
231
|
-
processed_data[key] =
|
254
|
+
processed_data[key] = encode_unserializable_types(value)
|
232
255
|
elif isinstance(value, list):
|
233
256
|
if all(isinstance(v, bytes) for v in value):
|
234
257
|
processed_data[key] = [
|
235
258
|
base64.urlsafe_b64encode(v).decode('ascii') for v in value
|
236
259
|
]
|
260
|
+
if all(isinstance(v, datetime.datetime) for v in value):
|
261
|
+
processed_data[key] = [v.isoformat() for v in value]
|
237
262
|
else:
|
238
|
-
processed_data[key] = [
|
263
|
+
processed_data[key] = [encode_unserializable_types(v) for v in value]
|
239
264
|
else:
|
240
265
|
processed_data[key] = value
|
241
266
|
return processed_data
|
google/genai/_extra_utils.py
CHANGED
@@ -13,12 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
#
|
15
15
|
|
16
|
-
"""Extra utils depending on types that are shared between sync and async modules.
|
17
|
-
"""
|
16
|
+
"""Extra utils depending on types that are shared between sync and async modules."""
|
18
17
|
|
19
18
|
import inspect
|
20
19
|
import logging
|
21
|
-
|
20
|
+
import typing
|
21
|
+
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
|
22
|
+
import sys
|
22
23
|
|
23
24
|
import pydantic
|
24
25
|
|
@@ -26,6 +27,10 @@ from . import _common
|
|
26
27
|
from . import errors
|
27
28
|
from . import types
|
28
29
|
|
30
|
+
if sys.version_info >= (3, 10):
|
31
|
+
from types import UnionType
|
32
|
+
else:
|
33
|
+
UnionType = typing._UnionGenericAlias
|
29
34
|
|
30
35
|
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10
|
31
36
|
|
@@ -78,8 +83,8 @@ def get_function_map(
|
|
78
83
|
if inspect.iscoroutinefunction(tool):
|
79
84
|
raise errors.UnsupportedFunctionError(
|
80
85
|
f'Function {tool.__name__} is a coroutine function, which is not'
|
81
|
-
' supported for automatic function calling. Please manually
|
82
|
-
f' {tool.__name__} to get the function response.'
|
86
|
+
' supported for automatic function calling. Please manually'
|
87
|
+
f' invoke {tool.__name__} to get the function response.'
|
83
88
|
)
|
84
89
|
function_map[tool.__name__] = tool
|
85
90
|
return function_map
|
@@ -116,7 +121,7 @@ def convert_if_exist_pydantic_model(
|
|
116
121
|
try:
|
117
122
|
return annotation(**value)
|
118
123
|
except pydantic.ValidationError as e:
|
119
|
-
raise errors.
|
124
|
+
raise errors.UnknownFunctionCallArgumentError(
|
120
125
|
f'Failed to parse parameter {param_name} for function'
|
121
126
|
f' {func_name} from function call part because function call argument'
|
122
127
|
f' value {value} is not compatible with parameter annotation'
|
@@ -135,11 +140,13 @@ def convert_if_exist_pydantic_model(
|
|
135
140
|
for k, v in value.items()
|
136
141
|
}
|
137
142
|
# example 1: typing.Union[int, float]
|
138
|
-
# example 2: int | float equivalent to
|
139
|
-
if get_origin(annotation) in (Union,
|
143
|
+
# example 2: int | float equivalent to UnionType[int, float]
|
144
|
+
if get_origin(annotation) in (Union, UnionType):
|
140
145
|
for arg in get_args(annotation):
|
141
|
-
if
|
142
|
-
|
146
|
+
if (
|
147
|
+
(get_args(arg) and get_origin(arg) is list)
|
148
|
+
or isinstance(value, arg)
|
149
|
+
or (isinstance(value, dict) and _is_annotation_pydantic_model(arg))
|
143
150
|
):
|
144
151
|
try:
|
145
152
|
return convert_if_exist_pydantic_model(
|
@@ -150,7 +157,7 @@ def convert_if_exist_pydantic_model(
|
|
150
157
|
except pydantic.ValidationError:
|
151
158
|
continue
|
152
159
|
# if none of the union type is matched, raise error
|
153
|
-
raise errors.
|
160
|
+
raise errors.UnknownFunctionCallArgumentError(
|
154
161
|
f'Failed to parse parameter {param_name} for function'
|
155
162
|
f' {func_name} from function call part because function call argument'
|
156
163
|
f' value {value} cannot be converted to parameter annotation'
|
@@ -161,7 +168,7 @@ def convert_if_exist_pydantic_model(
|
|
161
168
|
if isinstance(value, int) and annotation is float:
|
162
169
|
return value
|
163
170
|
if not isinstance(value, annotation):
|
164
|
-
raise errors.
|
171
|
+
raise errors.UnknownFunctionCallArgumentError(
|
165
172
|
f'Failed to parse parameter {param_name} for function {func_name} from'
|
166
173
|
f' function call part because function call argument value {value} is'
|
167
174
|
f' not compatible with parameter annotation {annotation}.'
|
@@ -209,7 +216,9 @@ def get_function_response_parts(
|
|
209
216
|
response = {'result': invoke_function_from_dict_args(args, func)}
|
210
217
|
except Exception as e: # pylint: disable=broad-except
|
211
218
|
response = {'error': str(e)}
|
212
|
-
func_response = types.Part.from_function_response(
|
219
|
+
func_response = types.Part.from_function_response(
|
220
|
+
name=func_name, response=response
|
221
|
+
)
|
213
222
|
|
214
223
|
func_response_parts.append(func_response)
|
215
224
|
return func_response_parts
|
@@ -231,8 +240,7 @@ def should_disable_afc(
|
|
231
240
|
and config_model.automatic_function_calling
|
232
241
|
and config_model.automatic_function_calling.maximum_remote_calls
|
233
242
|
is not None
|
234
|
-
and int(config_model.automatic_function_calling.maximum_remote_calls)
|
235
|
-
<= 0
|
243
|
+
and int(config_model.automatic_function_calling.maximum_remote_calls) <= 0
|
236
244
|
):
|
237
245
|
logging.warning(
|
238
246
|
'max_remote_calls in automatic_function_calling_config'
|
@@ -294,6 +302,7 @@ def get_max_remote_calls_afc(
|
|
294
302
|
return _DEFAULT_MAX_REMOTE_CALLS_AFC
|
295
303
|
return int(config_model.automatic_function_calling.maximum_remote_calls)
|
296
304
|
|
305
|
+
|
297
306
|
def should_append_afc_history(
|
298
307
|
config: Optional[types.GenerateContentConfigOrDict] = None,
|
299
308
|
) -> bool:
|
@@ -302,9 +311,6 @@ def should_append_afc_history(
|
|
302
311
|
if config and isinstance(config, dict)
|
303
312
|
else config
|
304
313
|
)
|
305
|
-
if
|
306
|
-
not config_model
|
307
|
-
or not config_model.automatic_function_calling
|
308
|
-
):
|
314
|
+
if not config_model or not config_model.automatic_function_calling:
|
309
315
|
return True
|
310
316
|
return not config_model.automatic_function_calling.ignore_call_history
|
@@ -17,11 +17,12 @@
|
|
17
17
|
|
18
18
|
import base64
|
19
19
|
import copy
|
20
|
+
import datetime
|
20
21
|
import inspect
|
22
|
+
import io
|
21
23
|
import json
|
22
24
|
import os
|
23
25
|
import re
|
24
|
-
import datetime
|
25
26
|
from typing import Any, Literal, Optional, Union
|
26
27
|
|
27
28
|
import google.auth
|
@@ -32,9 +33,9 @@ from ._api_client import ApiClient
|
|
32
33
|
from ._api_client import HttpOptions
|
33
34
|
from ._api_client import HttpRequest
|
34
35
|
from ._api_client import HttpResponse
|
35
|
-
from ._api_client import RequestJsonEncoder
|
36
36
|
from ._common import BaseModel
|
37
37
|
|
38
|
+
|
38
39
|
def _redact_version_numbers(version_string: str) -> str:
|
39
40
|
"""Redacts version numbers in the form x.y.z from a string."""
|
40
41
|
return re.sub(r'\d+\.\d+\.\d+', '{VERSION_NUMBER}', version_string)
|
@@ -145,6 +146,7 @@ class ReplayResponse(BaseModel):
|
|
145
146
|
status_code: int = 200
|
146
147
|
headers: dict[str, str]
|
147
148
|
body_segments: list[dict[str, object]]
|
149
|
+
byte_segments: Optional[list[bytes]] = None
|
148
150
|
sdk_response_segments: list[dict[str, object]]
|
149
151
|
|
150
152
|
def model_post_init(self, __context: Any) -> None:
|
@@ -264,17 +266,13 @@ class ReplayApiClient(ApiClient):
|
|
264
266
|
replay_file_path = self._get_replay_file_path()
|
265
267
|
os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
|
266
268
|
with open(replay_file_path, 'w') as f:
|
267
|
-
f.write(
|
268
|
-
json.dumps(
|
269
|
-
self.replay_session.model_dump(mode='json'), indent=2, cls=ResponseJsonEncoder
|
270
|
-
)
|
271
|
-
)
|
269
|
+
f.write(self.replay_session.model_dump_json(exclude_unset=True, indent=2))
|
272
270
|
self.replay_session = None
|
273
271
|
|
274
272
|
def _record_interaction(
|
275
273
|
self,
|
276
274
|
http_request: HttpRequest,
|
277
|
-
http_response: Union[HttpResponse, errors.APIError],
|
275
|
+
http_response: Union[HttpResponse, errors.APIError, bytes],
|
278
276
|
):
|
279
277
|
if not self._should_update_replay():
|
280
278
|
return
|
@@ -289,6 +287,9 @@ class ReplayApiClient(ApiClient):
|
|
289
287
|
response = ReplayResponse(
|
290
288
|
headers=dict(http_response.headers),
|
291
289
|
body_segments=list(http_response.segments()),
|
290
|
+
byte_segments=[
|
291
|
+
seg[:100] + b'...' for seg in http_response.byte_segments()
|
292
|
+
],
|
292
293
|
status_code=http_response.status_code,
|
293
294
|
sdk_response_segments=[],
|
294
295
|
)
|
@@ -322,11 +323,7 @@ class ReplayApiClient(ApiClient):
|
|
322
323
|
# so that the comparison is fair.
|
323
324
|
_redact_request_body(request_data_copy)
|
324
325
|
|
325
|
-
|
326
|
-
# Because the expected_request_body dict never contains bytes values.
|
327
|
-
actual_request_body = [
|
328
|
-
json.loads(json.dumps(request_data_copy, cls=RequestJsonEncoder))
|
329
|
-
]
|
326
|
+
actual_request_body = [request_data_copy]
|
330
327
|
expected_request_body = interaction.request.body_segments
|
331
328
|
assert actual_request_body == expected_request_body, (
|
332
329
|
'Request body mismatch:\n'
|
@@ -349,6 +346,7 @@ class ReplayApiClient(ApiClient):
|
|
349
346
|
json.dumps(segment)
|
350
347
|
for segment in interaction.response.body_segments
|
351
348
|
],
|
349
|
+
byte_stream=interaction.response.byte_segments,
|
352
350
|
)
|
353
351
|
|
354
352
|
def _verify_response(self, response_model: BaseModel):
|
@@ -368,7 +366,9 @@ class ReplayApiClient(ApiClient):
|
|
368
366
|
response_model = response_model[0]
|
369
367
|
print('response_model: ', response_model.model_dump(exclude_none=True))
|
370
368
|
actual = response_model.model_dump(exclude_none=True, mode='json')
|
371
|
-
expected = interaction.response.sdk_response_segments[
|
369
|
+
expected = interaction.response.sdk_response_segments[
|
370
|
+
self._sdk_response_index
|
371
|
+
]
|
372
372
|
assert (
|
373
373
|
actual == expected
|
374
374
|
), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
|
@@ -397,15 +397,26 @@ class ReplayApiClient(ApiClient):
|
|
397
397
|
# segments since the stream has been consumed.
|
398
398
|
else:
|
399
399
|
self._record_interaction(http_request, result)
|
400
|
-
_debug_print('api mode result: %s' % result.
|
400
|
+
_debug_print('api mode result: %s' % result.json)
|
401
401
|
return result
|
402
402
|
else:
|
403
403
|
return self._build_response_from_replay(http_request)
|
404
404
|
|
405
|
-
def upload_file(self, file_path: str, upload_url: str, upload_size: int):
|
406
|
-
|
407
|
-
|
408
|
-
|
405
|
+
def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int):
|
406
|
+
if isinstance(file_path, io.IOBase):
|
407
|
+
offset = file_path.tell()
|
408
|
+
content = file_path.read()
|
409
|
+
file_path.seek(offset, os.SEEK_SET)
|
410
|
+
request = HttpRequest(
|
411
|
+
method='POST',
|
412
|
+
url='',
|
413
|
+
data={'bytes': base64.b64encode(content).decode('utf-8')},
|
414
|
+
headers={}
|
415
|
+
)
|
416
|
+
else:
|
417
|
+
request = HttpRequest(
|
418
|
+
method='POST', url='', data={'file_path': file_path}, headers={}
|
419
|
+
)
|
409
420
|
if self._should_call_api():
|
410
421
|
try:
|
411
422
|
result = super().upload_file(file_path, upload_url, upload_size)
|
@@ -418,20 +429,21 @@ class ReplayApiClient(ApiClient):
|
|
418
429
|
self._record_interaction(request, HttpResponse({}, [json.dumps(result)]))
|
419
430
|
return result
|
420
431
|
else:
|
421
|
-
return self._build_response_from_replay(request).
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
432
|
+
return self._build_response_from_replay(request).json
|
433
|
+
|
434
|
+
def _download_file_request(self, request):
|
435
|
+
self._initialize_replay_session_if_not_loaded()
|
436
|
+
if self._should_call_api():
|
437
|
+
try:
|
438
|
+
result = super()._download_file_request(request)
|
439
|
+
except HTTPError as e:
|
440
|
+
result = HttpResponse(
|
441
|
+
e.response.headers, [json.dumps({'reason': e.response.reason})]
|
442
|
+
)
|
443
|
+
result.status_code = e.response.status_code
|
444
|
+
raise e
|
445
|
+
self._record_interaction(request, result)
|
446
|
+
return result
|
436
447
|
else:
|
437
|
-
return
|
448
|
+
return self._build_response_from_replay(request)
|
449
|
+
|
google/genai/_test_api_client.py
CHANGED
@@ -132,7 +132,7 @@ async def test_async_request_streamed_non_blocking(
|
|
132
132
|
|
133
133
|
chunks = []
|
134
134
|
start_time = time.time()
|
135
|
-
async for chunk in api_client.async_request_streamed(
|
135
|
+
async for chunk in await api_client.async_request_streamed(
|
136
136
|
http_method, path, request_dict
|
137
137
|
):
|
138
138
|
chunks.append(chunk)
|