google-genai 0.2.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +2 -1
- google/genai/_api_client.py +91 -38
- google/genai/_automatic_function_calling_util.py +19 -22
- google/genai/_replay_api_client.py +22 -28
- google/genai/_transformers.py +15 -0
- google/genai/batches.py +16 -16
- google/genai/caches.py +48 -46
- google/genai/chats.py +88 -15
- google/genai/client.py +6 -3
- google/genai/files.py +22 -22
- google/genai/live.py +28 -5
- google/genai/models.py +109 -77
- google/genai/tunings.py +17 -17
- google/genai/types.py +173 -90
- google/genai/version.py +16 -0
- {google_genai-0.2.2.dist-info → google_genai-0.4.0.dist-info}/METADATA +66 -18
- google_genai-0.4.0.dist-info/RECORD +25 -0
- {google_genai-0.2.2.dist-info → google_genai-0.4.0.dist-info}/WHEEL +1 -1
- google_genai-0.2.2.dist-info/RECORD +0 -24
- {google_genai-0.2.2.dist-info → google_genai-0.4.0.dist-info}/LICENSE +0 -0
- {google_genai-0.2.2.dist-info → google_genai-0.4.0.dist-info}/top_level.txt +0 -0
google/genai/__init__.py
CHANGED
google/genai/_api_client.py
CHANGED
@@ -23,35 +23,66 @@ import datetime
|
|
23
23
|
import json
|
24
24
|
import os
|
25
25
|
import sys
|
26
|
-
from typing import Any, Optional, TypedDict, Union
|
26
|
+
from typing import Any, Optional, Tuple, TypedDict, Union
|
27
27
|
from urllib.parse import urlparse, urlunparse
|
28
28
|
|
29
29
|
import google.auth
|
30
30
|
import google.auth.credentials
|
31
31
|
from google.auth.transport.requests import AuthorizedSession
|
32
|
-
from pydantic import BaseModel
|
32
|
+
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
33
33
|
import requests
|
34
34
|
|
35
35
|
from . import errors
|
36
|
+
from . import version
|
36
37
|
|
37
38
|
|
38
|
-
class HttpOptions(
|
39
|
+
class HttpOptions(BaseModel):
|
40
|
+
"""HTTP options for the api client."""
|
41
|
+
model_config = ConfigDict(extra='forbid')
|
42
|
+
|
43
|
+
base_url: Optional[str] = Field(
|
44
|
+
default=None,
|
45
|
+
description="""The base URL for the AI platform service endpoint.""",
|
46
|
+
)
|
47
|
+
api_version: Optional[str] = Field(
|
48
|
+
default=None,
|
49
|
+
description="""Specifies the version of the API to use.""",
|
50
|
+
)
|
51
|
+
headers: Optional[dict[str, str]] = Field(
|
52
|
+
default=None,
|
53
|
+
description="""Additional HTTP headers to be sent with the request.""",
|
54
|
+
)
|
55
|
+
response_payload: Optional[dict] = Field(
|
56
|
+
default=None,
|
57
|
+
description="""If set, the response payload will be returned int the supplied dict.""",
|
58
|
+
)
|
59
|
+
timeout: Optional[Union[float, Tuple[float, float]]] = Field(
|
60
|
+
default=None,
|
61
|
+
description="""Timeout for the request in seconds.""",
|
62
|
+
)
|
63
|
+
|
64
|
+
|
65
|
+
class HttpOptionsDict(TypedDict):
|
39
66
|
"""HTTP options for the api client."""
|
40
67
|
|
41
|
-
base_url: str = None
|
68
|
+
base_url: Optional[str] = None
|
42
69
|
"""The base URL for the AI platform service endpoint."""
|
43
|
-
api_version: str = None
|
70
|
+
api_version: Optional[str] = None
|
44
71
|
"""Specifies the version of the API to use."""
|
45
|
-
headers: dict[str,
|
72
|
+
headers: Optional[dict[str, Union[str, list[str]]]] = None
|
46
73
|
"""Additional HTTP headers to be sent with the request."""
|
47
|
-
response_payload: dict = None
|
74
|
+
response_payload: Optional[dict] = None
|
48
75
|
"""If set, the response payload will be returned in the supplied dict."""
|
76
|
+
timeout: Optional[Union[float, Tuple[float, float]]] = None
|
77
|
+
"""Timeout for the request in seconds."""
|
78
|
+
|
79
|
+
|
80
|
+
HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
|
49
81
|
|
50
82
|
|
51
83
|
def _append_library_version_headers(headers: dict[str, str]) -> None:
|
52
84
|
"""Appends the telemetry header to the headers dict."""
|
53
|
-
|
54
|
-
library_label = f'google-genai-sdk/0.2.2'
|
85
|
+
library_label = f'google-genai-sdk/{version.__version__}'
|
55
86
|
language_label = 'gl-python/' + sys.version.split()[0]
|
56
87
|
version_header_value = f'{library_label} {language_label}'
|
57
88
|
if (
|
@@ -71,20 +102,24 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
|
|
71
102
|
|
72
103
|
|
73
104
|
def _patch_http_options(
|
74
|
-
options:
|
75
|
-
) ->
|
105
|
+
options: HttpOptionsDict, patch_options: HttpOptionsDict
|
106
|
+
) -> HttpOptionsDict:
|
76
107
|
# use shallow copy so we don't override the original objects.
|
77
|
-
copy_option =
|
108
|
+
copy_option = HttpOptionsDict()
|
78
109
|
copy_option.update(options)
|
79
|
-
for
|
110
|
+
for patch_key, patch_value in patch_options.items():
|
80
111
|
# if both are dicts, update the copy.
|
81
112
|
# This is to handle cases like merging headers.
|
82
|
-
if isinstance(
|
83
|
-
|
84
|
-
|
85
|
-
copy_option[
|
86
|
-
|
87
|
-
|
113
|
+
if isinstance(patch_value, dict) and isinstance(
|
114
|
+
copy_option.get(patch_key, None), dict
|
115
|
+
):
|
116
|
+
copy_option[patch_key] = {}
|
117
|
+
copy_option[patch_key].update(
|
118
|
+
options[patch_key]
|
119
|
+
) # shallow copy from original options.
|
120
|
+
copy_option[patch_key].update(patch_value)
|
121
|
+
elif patch_value is not None: # Accept empty values.
|
122
|
+
copy_option[patch_key] = patch_value
|
88
123
|
_append_library_version_headers(copy_option['headers'])
|
89
124
|
return copy_option
|
90
125
|
|
@@ -98,10 +133,11 @@ def _join_url_path(base_url: str, path: str) -> str:
|
|
98
133
|
|
99
134
|
@dataclass
|
100
135
|
class HttpRequest:
|
101
|
-
headers: dict[str, str]
|
136
|
+
headers: dict[str, Union[str, list[str]]]
|
102
137
|
url: str
|
103
138
|
method: str
|
104
139
|
data: Union[dict[str, object], bytes]
|
140
|
+
timeout: Optional[Union[float, Tuple[float, float]]] = None
|
105
141
|
|
106
142
|
|
107
143
|
class HttpResponse:
|
@@ -147,7 +183,7 @@ class ApiClient:
|
|
147
183
|
credentials: google.auth.credentials.Credentials = None,
|
148
184
|
project: Union[str, None] = None,
|
149
185
|
location: Union[str, None] = None,
|
150
|
-
http_options:
|
186
|
+
http_options: HttpOptionsOrDict = None,
|
151
187
|
):
|
152
188
|
self.vertexai = vertexai
|
153
189
|
if self.vertexai is None:
|
@@ -163,11 +199,20 @@ class ApiClient:
|
|
163
199
|
'Project/location and API key are mutually exclusive in the client initializer.'
|
164
200
|
)
|
165
201
|
|
202
|
+
# Validate http_options if a dict is provided.
|
203
|
+
if isinstance(http_options, dict):
|
204
|
+
try:
|
205
|
+
HttpOptions.model_validate(http_options)
|
206
|
+
except ValidationError as e:
|
207
|
+
raise ValueError(f'Invalid http_options: {e}')
|
208
|
+
elif(isinstance(http_options, HttpOptions)):
|
209
|
+
http_options = http_options.model_dump()
|
210
|
+
|
166
211
|
self.api_key: Optional[str] = None
|
167
212
|
self.project = project or os.environ.get('GOOGLE_CLOUD_PROJECT', None)
|
168
213
|
self.location = location or os.environ.get('GOOGLE_CLOUD_LOCATION', None)
|
169
214
|
self._credentials = credentials
|
170
|
-
self._http_options =
|
215
|
+
self._http_options = HttpOptionsDict()
|
171
216
|
|
172
217
|
if self.vertexai:
|
173
218
|
if not self.project:
|
@@ -208,7 +253,7 @@ class ApiClient:
|
|
208
253
|
http_method: str,
|
209
254
|
path: str,
|
210
255
|
request_dict: dict[str, object],
|
211
|
-
http_options:
|
256
|
+
http_options: HttpOptionsDict = None,
|
212
257
|
) -> HttpRequest:
|
213
258
|
# Remove all special dict keys such as _url and _query.
|
214
259
|
keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
|
@@ -232,6 +277,7 @@ class ApiClient:
|
|
232
277
|
url=url,
|
233
278
|
headers=patched_http_options['headers'],
|
234
279
|
data=request_dict,
|
280
|
+
timeout=patched_http_options.get('timeout', None),
|
235
281
|
)
|
236
282
|
|
237
283
|
def _request(
|
@@ -241,17 +287,19 @@ class ApiClient:
|
|
241
287
|
) -> HttpResponse:
|
242
288
|
if self.vertexai:
|
243
289
|
if not self._credentials:
|
244
|
-
self._credentials, _ = google.auth.default(
|
290
|
+
self._credentials, _ = google.auth.default(
|
291
|
+
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
292
|
+
)
|
245
293
|
authed_session = AuthorizedSession(self._credentials)
|
246
294
|
authed_session.stream = stream
|
247
295
|
response = authed_session.request(
|
248
296
|
http_request.method.upper(),
|
249
297
|
http_request.url,
|
250
298
|
headers=http_request.headers,
|
251
|
-
data=json.dumps(http_request.data, cls=RequestJsonEncoder)
|
252
|
-
|
253
|
-
|
254
|
-
timeout=
|
299
|
+
data=json.dumps(http_request.data, cls=RequestJsonEncoder)
|
300
|
+
if http_request.data
|
301
|
+
else None,
|
302
|
+
timeout=http_request.timeout,
|
255
303
|
)
|
256
304
|
errors.APIError.raise_for_response(response)
|
257
305
|
return HttpResponse(
|
@@ -273,13 +321,14 @@ class ApiClient:
|
|
273
321
|
data = http_request.data
|
274
322
|
|
275
323
|
http_session = requests.Session()
|
276
|
-
|
324
|
+
response = http_session.request(
|
277
325
|
method=http_request.method,
|
278
326
|
url=http_request.url,
|
279
327
|
headers=http_request.headers,
|
280
328
|
data=data,
|
281
|
-
|
282
|
-
|
329
|
+
timeout=http_request.timeout,
|
330
|
+
stream=stream,
|
331
|
+
)
|
283
332
|
errors.APIError.raise_for_response(response)
|
284
333
|
return HttpResponse(
|
285
334
|
response.headers, response if stream else [response.text]
|
@@ -290,7 +339,9 @@ class ApiClient:
|
|
290
339
|
):
|
291
340
|
if self.vertexai:
|
292
341
|
if not self._credentials:
|
293
|
-
self._credentials, _ = google.auth.default(
|
342
|
+
self._credentials, _ = google.auth.default(
|
343
|
+
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
344
|
+
)
|
294
345
|
return await asyncio.to_thread(
|
295
346
|
self._request,
|
296
347
|
http_request,
|
@@ -303,8 +354,10 @@ class ApiClient:
|
|
303
354
|
stream=stream,
|
304
355
|
)
|
305
356
|
|
306
|
-
def get_read_only_http_options(self) ->
|
307
|
-
copied =
|
357
|
+
def get_read_only_http_options(self) -> HttpOptionsDict:
|
358
|
+
copied = HttpOptionsDict()
|
359
|
+
if isinstance(self._http_options, BaseModel):
|
360
|
+
self._http_options = self._http_options.model_dump()
|
308
361
|
copied.update(self._http_options)
|
309
362
|
return copied
|
310
363
|
|
@@ -313,7 +366,7 @@ class ApiClient:
|
|
313
366
|
http_method: str,
|
314
367
|
path: str,
|
315
368
|
request_dict: dict[str, object],
|
316
|
-
http_options:
|
369
|
+
http_options: HttpOptionsDict = None,
|
317
370
|
):
|
318
371
|
http_request = self._build_request(
|
319
372
|
http_method, path, request_dict, http_options
|
@@ -328,7 +381,7 @@ class ApiClient:
|
|
328
381
|
http_method: str,
|
329
382
|
path: str,
|
330
383
|
request_dict: dict[str, object],
|
331
|
-
http_options:
|
384
|
+
http_options: HttpOptionsDict = None,
|
332
385
|
):
|
333
386
|
http_request = self._build_request(
|
334
387
|
http_method, path, request_dict, http_options
|
@@ -345,7 +398,7 @@ class ApiClient:
|
|
345
398
|
http_method: str,
|
346
399
|
path: str,
|
347
400
|
request_dict: dict[str, object],
|
348
|
-
http_options:
|
401
|
+
http_options: HttpOptionsDict = None,
|
349
402
|
) -> dict[str, object]:
|
350
403
|
http_request = self._build_request(
|
351
404
|
http_method, path, request_dict, http_options
|
@@ -361,7 +414,7 @@ class ApiClient:
|
|
361
414
|
http_method: str,
|
362
415
|
path: str,
|
363
416
|
request_dict: dict[str, object],
|
364
|
-
http_options:
|
417
|
+
http_options: HttpOptionsDict = None,
|
365
418
|
):
|
366
419
|
http_request = self._build_request(
|
367
420
|
http_method, path, request_dict, http_options
|
@@ -58,8 +58,8 @@ def _raise_for_nullable_if_mldev(schema: types.Schema):
|
|
58
58
|
)
|
59
59
|
|
60
60
|
|
61
|
-
def _raise_if_schema_unsupported(
|
62
|
-
if not
|
61
|
+
def _raise_if_schema_unsupported(variant: str, schema: types.Schema):
|
62
|
+
if not variant == 'VERTEX_AI':
|
63
63
|
_raise_for_any_of_if_mldev(schema)
|
64
64
|
_raise_for_default_if_mldev(schema)
|
65
65
|
_raise_for_nullable_if_mldev(schema)
|
@@ -112,7 +112,7 @@ def _is_default_value_compatible(
|
|
112
112
|
|
113
113
|
|
114
114
|
def _parse_schema_from_parameter(
|
115
|
-
|
115
|
+
variant: str, param: inspect.Parameter, func_name: str
|
116
116
|
) -> types.Schema:
|
117
117
|
"""parse schema from parameter.
|
118
118
|
|
@@ -130,7 +130,7 @@ def _parse_schema_from_parameter(
|
|
130
130
|
raise ValueError(default_value_error_msg)
|
131
131
|
schema.default = param.default
|
132
132
|
schema.type = _py_builtin_type_to_schema_type[param.annotation]
|
133
|
-
_raise_if_schema_unsupported(
|
133
|
+
_raise_if_schema_unsupported(variant, schema)
|
134
134
|
return schema
|
135
135
|
if (
|
136
136
|
isinstance(param.annotation, typing_types.UnionType)
|
@@ -149,7 +149,7 @@ def _parse_schema_from_parameter(
|
|
149
149
|
schema.nullable = True
|
150
150
|
continue
|
151
151
|
schema_in_any_of = _parse_schema_from_parameter(
|
152
|
-
|
152
|
+
variant,
|
153
153
|
inspect.Parameter(
|
154
154
|
'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
|
155
155
|
),
|
@@ -170,9 +170,8 @@ def _parse_schema_from_parameter(
|
|
170
170
|
):
|
171
171
|
if not _is_default_value_compatible(param.default, param.annotation):
|
172
172
|
raise ValueError(default_value_error_msg)
|
173
|
-
# TODO: b/379715133 - handle pydantic model default value
|
174
173
|
schema.default = param.default
|
175
|
-
_raise_if_schema_unsupported(
|
174
|
+
_raise_if_schema_unsupported(variant, schema)
|
176
175
|
return schema
|
177
176
|
if isinstance(param.annotation, _GenericAlias) or isinstance(
|
178
177
|
param.annotation, typing_types.GenericAlias
|
@@ -185,7 +184,7 @@ def _parse_schema_from_parameter(
|
|
185
184
|
if not _is_default_value_compatible(param.default, param.annotation):
|
186
185
|
raise ValueError(default_value_error_msg)
|
187
186
|
schema.default = param.default
|
188
|
-
_raise_if_schema_unsupported(
|
187
|
+
_raise_if_schema_unsupported(variant, schema)
|
189
188
|
return schema
|
190
189
|
if origin is Literal:
|
191
190
|
if not all(isinstance(arg, str) for arg in args):
|
@@ -198,12 +197,12 @@ def _parse_schema_from_parameter(
|
|
198
197
|
if not _is_default_value_compatible(param.default, param.annotation):
|
199
198
|
raise ValueError(default_value_error_msg)
|
200
199
|
schema.default = param.default
|
201
|
-
_raise_if_schema_unsupported(
|
200
|
+
_raise_if_schema_unsupported(variant, schema)
|
202
201
|
return schema
|
203
202
|
if origin is list:
|
204
203
|
schema.type = 'ARRAY'
|
205
204
|
schema.items = _parse_schema_from_parameter(
|
206
|
-
|
205
|
+
variant,
|
207
206
|
inspect.Parameter(
|
208
207
|
'item',
|
209
208
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -215,7 +214,7 @@ def _parse_schema_from_parameter(
|
|
215
214
|
if not _is_default_value_compatible(param.default, param.annotation):
|
216
215
|
raise ValueError(default_value_error_msg)
|
217
216
|
schema.default = param.default
|
218
|
-
_raise_if_schema_unsupported(
|
217
|
+
_raise_if_schema_unsupported(variant, schema)
|
219
218
|
return schema
|
220
219
|
if origin is Union:
|
221
220
|
schema.any_of = []
|
@@ -226,7 +225,7 @@ def _parse_schema_from_parameter(
|
|
226
225
|
schema.nullable = True
|
227
226
|
continue
|
228
227
|
schema_in_any_of = _parse_schema_from_parameter(
|
229
|
-
|
228
|
+
variant,
|
230
229
|
inspect.Parameter(
|
231
230
|
'item',
|
232
231
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -250,7 +249,7 @@ def _parse_schema_from_parameter(
|
|
250
249
|
if not _is_default_value_compatible(param.default, param.annotation):
|
251
250
|
raise ValueError(default_value_error_msg)
|
252
251
|
schema.default = param.default
|
253
|
-
_raise_if_schema_unsupported(
|
252
|
+
_raise_if_schema_unsupported(variant, schema)
|
254
253
|
return schema
|
255
254
|
# all other generic alias will be invoked in raise branch
|
256
255
|
if (
|
@@ -258,17 +257,16 @@ def _parse_schema_from_parameter(
|
|
258
257
|
# for user defined class, we only support pydantic model
|
259
258
|
and issubclass(param.annotation, pydantic.BaseModel)
|
260
259
|
):
|
261
|
-
if
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
)
|
260
|
+
if (
|
261
|
+
param.default is not inspect.Parameter.empty
|
262
|
+
and param.default is not None
|
263
|
+
):
|
264
|
+
schema.default = param.default
|
267
265
|
schema.type = 'OBJECT'
|
268
266
|
schema.properties = {}
|
269
267
|
for field_name, field_info in param.annotation.model_fields.items():
|
270
268
|
schema.properties[field_name] = _parse_schema_from_parameter(
|
271
|
-
|
269
|
+
variant,
|
272
270
|
inspect.Parameter(
|
273
271
|
field_name,
|
274
272
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -276,7 +274,7 @@ def _parse_schema_from_parameter(
|
|
276
274
|
),
|
277
275
|
func_name,
|
278
276
|
)
|
279
|
-
_raise_if_schema_unsupported(
|
277
|
+
_raise_if_schema_unsupported(variant, schema)
|
280
278
|
return schema
|
281
279
|
raise ValueError(
|
282
280
|
f'Failed to parse the parameter {param} of function {func_name} for'
|
@@ -294,4 +292,3 @@ def _get_required_fields(schema: types.Schema) -> list[str]:
|
|
294
292
|
for field_name, field_schema in schema.properties.items()
|
295
293
|
if not field_schema.nullable and field_schema.default is None
|
296
294
|
]
|
297
|
-
|
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Replay API client."""
|
17
17
|
|
18
|
+
import base64
|
18
19
|
import copy
|
19
20
|
import inspect
|
20
21
|
import json
|
@@ -105,28 +106,6 @@ def redact_http_request(http_request: HttpRequest):
|
|
105
106
|
_redact_request_body(http_request.data)
|
106
107
|
|
107
108
|
|
108
|
-
def process_bytes_fields(data: dict[str, object]):
|
109
|
-
"""Converts bytes fields to strings.
|
110
|
-
|
111
|
-
This function doesn't modify the content of data dict.
|
112
|
-
"""
|
113
|
-
if not isinstance(data, dict):
|
114
|
-
return data
|
115
|
-
for key, value in data.items():
|
116
|
-
if isinstance(value, bytes):
|
117
|
-
data[key] = value.decode()
|
118
|
-
elif isinstance(value, dict):
|
119
|
-
process_bytes_fields(value)
|
120
|
-
elif isinstance(value, list):
|
121
|
-
if all(isinstance(v, bytes) for v in value):
|
122
|
-
data[key] = [v.decode() for v in value]
|
123
|
-
else:
|
124
|
-
data[key] = [process_bytes_fields(v) for v in value]
|
125
|
-
else:
|
126
|
-
data[key] = value
|
127
|
-
return data
|
128
|
-
|
129
|
-
|
130
109
|
def _current_file_path_and_line():
|
131
110
|
"""Prints the current file path and line number."""
|
132
111
|
frame = inspect.currentframe().f_back.f_back
|
@@ -185,7 +164,7 @@ class ReplayFile(BaseModel):
|
|
185
164
|
|
186
165
|
|
187
166
|
class ReplayApiClient(ApiClient):
|
188
|
-
"""For integration testing, send recorded
|
167
|
+
"""For integration testing, sends a recorded response or records a response."""
|
189
168
|
|
190
169
|
def __init__(
|
191
170
|
self,
|
@@ -280,9 +259,18 @@ class ReplayApiClient(ApiClient):
|
|
280
259
|
replay_file_path = self._get_replay_file_path()
|
281
260
|
os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
|
282
261
|
with open(replay_file_path, 'w') as f:
|
262
|
+
replay_session_dict = self.replay_session.model_dump()
|
263
|
+
# Re-encode each response segment so non-UTF-8 bytes (e.g. image/video output) are JSON-serializable.
|
264
|
+
for interaction in replay_session_dict['interactions']:
|
265
|
+
segments = []
|
266
|
+
for response in interaction['response']['sdk_response_segments']:
|
267
|
+
segments.append(json.loads(json.dumps(
|
268
|
+
response, cls=ResponseJsonEncoder
|
269
|
+
)))
|
270
|
+
interaction['response']['sdk_response_segments'] = segments
|
283
271
|
f.write(
|
284
272
|
json.dumps(
|
285
|
-
|
273
|
+
replay_session_dict, indent=2, cls=RequestJsonEncoder
|
286
274
|
)
|
287
275
|
)
|
288
276
|
self.replay_session = None
|
@@ -463,10 +451,16 @@ class ResponseJsonEncoder(json.JSONEncoder):
|
|
463
451
|
"""
|
464
452
|
def default(self, o):
|
465
453
|
if isinstance(o, bytes):
|
466
|
-
#
|
467
|
-
#
|
468
|
-
#
|
469
|
-
|
454
|
+
# Use base64.b64encode() to encode bytes to string so that the media bytes
|
455
|
+
# fields are serializable.
|
456
|
+
# o.decode(encoding='utf-8', errors='replace') doesn't work because it
|
457
|
+
# uses a fixed error string `\ufffd` for all non-utf-8 characters,
|
458
|
+
# which cannot be converted back to original bytes. And other languages
|
459
|
+
# only have the original bytes to compare with.
|
460
|
+
# Since we use base64.b64encode() in replay test, a change that breaks
|
461
|
+
# native bytes can be captured by
|
462
|
+
# test_compute_tokens.py::test_token_bytes_deserialization.
|
463
|
+
return base64.b64encode(o).decode(encoding='utf-8')
|
470
464
|
elif isinstance(o, datetime.datetime):
|
471
465
|
# dt.isoformat() prints "2024-11-15T23:27:45.624657+00:00"
|
472
466
|
# but replay files want "2024-11-15T23:27:45.624657Z"
|
google/genai/_transformers.py
CHANGED
@@ -24,6 +24,7 @@ import time
|
|
24
24
|
from typing import Any, Optional, Union
|
25
25
|
|
26
26
|
import PIL.Image
|
27
|
+
import PIL.PngImagePlugin
|
27
28
|
|
28
29
|
from . import _api_client
|
29
30
|
from . import types
|
@@ -298,6 +299,20 @@ def t_speech_config(
|
|
298
299
|
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=origin)
|
299
300
|
)
|
300
301
|
)
|
302
|
+
if (
|
303
|
+
isinstance(origin, dict)
|
304
|
+
and 'voice_config' in origin
|
305
|
+
and 'prebuilt_voice_config' in origin['voice_config']
|
306
|
+
):
|
307
|
+
return types.SpeechConfig(
|
308
|
+
voice_config=types.VoiceConfig(
|
309
|
+
prebuilt_voice_config=types.PrebuiltVoiceConfig(
|
310
|
+
voice_name=origin['voice_config']['prebuilt_voice_config'].get(
|
311
|
+
'voice_name'
|
312
|
+
)
|
313
|
+
)
|
314
|
+
)
|
315
|
+
)
|
301
316
|
raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
|
302
317
|
|
303
318
|
|
google/genai/batches.py
CHANGED
@@ -31,13 +31,13 @@ def _BatchJobSource_to_mldev(
|
|
31
31
|
parent_object: dict = None,
|
32
32
|
) -> dict:
|
33
33
|
to_object = {}
|
34
|
-
if getv(from_object, ['format']):
|
34
|
+
if getv(from_object, ['format']) is not None:
|
35
35
|
raise ValueError('format parameter is not supported in Google AI.')
|
36
36
|
|
37
|
-
if getv(from_object, ['gcs_uri']):
|
37
|
+
if getv(from_object, ['gcs_uri']) is not None:
|
38
38
|
raise ValueError('gcs_uri parameter is not supported in Google AI.')
|
39
39
|
|
40
|
-
if getv(from_object, ['bigquery_uri']):
|
40
|
+
if getv(from_object, ['bigquery_uri']) is not None:
|
41
41
|
raise ValueError('bigquery_uri parameter is not supported in Google AI.')
|
42
42
|
|
43
43
|
return to_object
|
@@ -71,13 +71,13 @@ def _BatchJobDestination_to_mldev(
|
|
71
71
|
parent_object: dict = None,
|
72
72
|
) -> dict:
|
73
73
|
to_object = {}
|
74
|
-
if getv(from_object, ['format']):
|
74
|
+
if getv(from_object, ['format']) is not None:
|
75
75
|
raise ValueError('format parameter is not supported in Google AI.')
|
76
76
|
|
77
|
-
if getv(from_object, ['gcs_uri']):
|
77
|
+
if getv(from_object, ['gcs_uri']) is not None:
|
78
78
|
raise ValueError('gcs_uri parameter is not supported in Google AI.')
|
79
79
|
|
80
|
-
if getv(from_object, ['bigquery_uri']):
|
80
|
+
if getv(from_object, ['bigquery_uri']) is not None:
|
81
81
|
raise ValueError('bigquery_uri parameter is not supported in Google AI.')
|
82
82
|
|
83
83
|
return to_object
|
@@ -121,7 +121,7 @@ def _CreateBatchJobConfig_to_mldev(
|
|
121
121
|
if getv(from_object, ['display_name']) is not None:
|
122
122
|
setv(parent_object, ['displayName'], getv(from_object, ['display_name']))
|
123
123
|
|
124
|
-
if getv(from_object, ['dest']):
|
124
|
+
if getv(from_object, ['dest']) is not None:
|
125
125
|
raise ValueError('dest parameter is not supported in Google AI.')
|
126
126
|
|
127
127
|
return to_object
|
@@ -159,10 +159,10 @@ def _CreateBatchJobParameters_to_mldev(
|
|
159
159
|
parent_object: dict = None,
|
160
160
|
) -> dict:
|
161
161
|
to_object = {}
|
162
|
-
if getv(from_object, ['model']):
|
162
|
+
if getv(from_object, ['model']) is not None:
|
163
163
|
raise ValueError('model parameter is not supported in Google AI.')
|
164
164
|
|
165
|
-
if getv(from_object, ['src']):
|
165
|
+
if getv(from_object, ['src']) is not None:
|
166
166
|
raise ValueError('src parameter is not supported in Google AI.')
|
167
167
|
|
168
168
|
if getv(from_object, ['config']) is not None:
|
@@ -243,7 +243,7 @@ def _GetBatchJobParameters_to_mldev(
|
|
243
243
|
parent_object: dict = None,
|
244
244
|
) -> dict:
|
245
245
|
to_object = {}
|
246
|
-
if getv(from_object, ['name']):
|
246
|
+
if getv(from_object, ['name']) is not None:
|
247
247
|
raise ValueError('name parameter is not supported in Google AI.')
|
248
248
|
|
249
249
|
if getv(from_object, ['config']) is not None:
|
@@ -313,7 +313,7 @@ def _CancelBatchJobParameters_to_mldev(
|
|
313
313
|
parent_object: dict = None,
|
314
314
|
) -> dict:
|
315
315
|
to_object = {}
|
316
|
-
if getv(from_object, ['name']):
|
316
|
+
if getv(from_object, ['name']) is not None:
|
317
317
|
raise ValueError('name parameter is not supported in Google AI.')
|
318
318
|
|
319
319
|
if getv(from_object, ['config']) is not None:
|
@@ -374,7 +374,7 @@ def _ListBatchJobConfig_to_mldev(
|
|
374
374
|
getv(from_object, ['page_token']),
|
375
375
|
)
|
376
376
|
|
377
|
-
if getv(from_object, ['filter']):
|
377
|
+
if getv(from_object, ['filter']) is not None:
|
378
378
|
raise ValueError('filter parameter is not supported in Google AI.')
|
379
379
|
|
380
380
|
return to_object
|
@@ -413,7 +413,7 @@ def _ListBatchJobParameters_to_mldev(
|
|
413
413
|
parent_object: dict = None,
|
414
414
|
) -> dict:
|
415
415
|
to_object = {}
|
416
|
-
if getv(from_object, ['config']):
|
416
|
+
if getv(from_object, ['config']) is not None:
|
417
417
|
raise ValueError('config parameter is not supported in Google AI.')
|
418
418
|
|
419
419
|
return to_object
|
@@ -443,7 +443,7 @@ def _DeleteBatchJobParameters_to_mldev(
|
|
443
443
|
parent_object: dict = None,
|
444
444
|
) -> dict:
|
445
445
|
to_object = {}
|
446
|
-
if getv(from_object, ['name']):
|
446
|
+
if getv(from_object, ['name']) is not None:
|
447
447
|
raise ValueError('name parameter is not supported in Google AI.')
|
448
448
|
|
449
449
|
return to_object
|
@@ -947,7 +947,7 @@ class Batches(_common.BaseModule):
|
|
947
947
|
Args:
|
948
948
|
model (str): The model to use for the batch job.
|
949
949
|
src (str): The source of the batch job. Currently supports GCS URI(-s) or
|
950
|
-
|
950
|
+
BigQuery URI. Example: "gs://path/to/input/data" or
|
951
951
|
"bq://projectId.bqDatasetId.bqTableId".
|
952
952
|
config (CreateBatchJobConfig): Optional configuration for the batch job.
|
953
953
|
|
@@ -1243,7 +1243,7 @@ class AsyncBatches(_common.BaseModule):
|
|
1243
1243
|
Args:
|
1244
1244
|
model (str): The model to use for the batch job.
|
1245
1245
|
src (str): The source of the batch job. Currently supports GCS URI(-s) or
|
1246
|
-
|
1246
|
+
BigQuery URI. Example: "gs://path/to/input/data" or
|
1247
1247
|
"bq://projectId.bqDatasetId.bqTableId".
|
1248
1248
|
config (CreateBatchJobConfig): Optional configuration for the batch job.
|
1249
1249
|
|