google-genai 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +152 -48
- google/genai/_api_module.py +5 -0
- google/genai/_common.py +12 -0
- google/genai/_extra_utils.py +7 -2
- google/genai/_replay_api_client.py +32 -1
- google/genai/_transformers.py +40 -10
- google/genai/batches.py +6 -3
- google/genai/caches.py +13 -10
- google/genai/client.py +22 -11
- google/genai/errors.py +18 -3
- google/genai/files.py +8 -5
- google/genai/live.py +64 -41
- google/genai/models.py +661 -87
- google/genai/{_operations.py → operations.py} +260 -20
- google/genai/tunings.py +3 -0
- google/genai/types.py +439 -7
- google/genai/version.py +1 -1
- {google_genai-1.1.0.dist-info → google_genai-1.3.0.dist-info}/METADATA +126 -15
- google_genai-1.3.0.dist-info/RECORD +27 -0
- google_genai-1.1.0.dist-info/RECORD +0 -27
- {google_genai-1.1.0.dist-info → google_genai-1.3.0.dist-info}/LICENSE +0 -0
- {google_genai-1.1.0.dist-info → google_genai-1.3.0.dist-info}/WHEEL +0 -0
- {google_genai-1.1.0.dist-info → google_genai-1.3.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -25,19 +25,23 @@ import json
|
|
25
25
|
import logging
|
26
26
|
import os
|
27
27
|
import sys
|
28
|
-
from typing import Any, Optional, Tuple, TypedDict, Union
|
28
|
+
from typing import Any, AsyncIterator, Optional, Tuple, TypedDict, Union
|
29
29
|
from urllib.parse import urlparse, urlunparse
|
30
|
-
|
31
30
|
import google.auth
|
32
31
|
import google.auth.credentials
|
32
|
+
from google.auth.credentials import Credentials
|
33
33
|
from google.auth.transport.requests import AuthorizedSession
|
34
|
+
from google.auth.transport.requests import Request
|
35
|
+
import httpx
|
34
36
|
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
35
37
|
import requests
|
36
|
-
|
38
|
+
from . import _common
|
37
39
|
from . import errors
|
38
40
|
from . import version
|
39
41
|
from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
|
40
42
|
|
43
|
+
logger = logging.getLogger('google_genai._api_client')
|
44
|
+
|
41
45
|
|
42
46
|
def _append_library_version_headers(headers: dict[str, str]) -> None:
|
43
47
|
"""Appends the telemetry header to the headers dict."""
|
@@ -85,11 +89,36 @@ def _patch_http_options(
|
|
85
89
|
|
86
90
|
def _join_url_path(base_url: str, path: str) -> str:
|
87
91
|
parsed_base = urlparse(base_url)
|
88
|
-
base_path =
|
92
|
+
base_path = (
|
93
|
+
parsed_base.path[:-1]
|
94
|
+
if parsed_base.path.endswith('/')
|
95
|
+
else parsed_base.path
|
96
|
+
)
|
89
97
|
path = path[1:] if path.startswith('/') else path
|
90
98
|
return urlunparse(parsed_base._replace(path=base_path + '/' + path))
|
91
99
|
|
92
100
|
|
101
|
+
def _load_auth(*, project: Union[str, None]) -> tuple[Credentials, str]:
|
102
|
+
"""Loads google auth credentials and project id."""
|
103
|
+
credentials, loaded_project_id = google.auth.default(
|
104
|
+
scopes=['https://www.googleapis.com/auth/cloud-platform'],
|
105
|
+
)
|
106
|
+
|
107
|
+
if not project:
|
108
|
+
project = loaded_project_id
|
109
|
+
|
110
|
+
if not project:
|
111
|
+
raise ValueError(
|
112
|
+
'Could not resolve project using application default credentials.'
|
113
|
+
)
|
114
|
+
|
115
|
+
return credentials, project
|
116
|
+
|
117
|
+
|
118
|
+
def _refresh_auth(credentials: Credentials) -> None:
|
119
|
+
credentials.refresh(Request())
|
120
|
+
|
121
|
+
|
93
122
|
@dataclass
|
94
123
|
class HttpRequest:
|
95
124
|
headers: dict[str, str]
|
@@ -101,15 +130,14 @@ class HttpRequest:
|
|
101
130
|
|
102
131
|
# TODO(b/394358912): Update this class to use a SDKResponse class that can be
|
103
132
|
# generated and used for all languages.
|
104
|
-
|
105
|
-
|
106
|
-
|
133
|
+
class BaseResponse(_common.BaseModel):
|
134
|
+
http_headers: dict[str, str] = Field(
|
135
|
+
default=None, description='The http headers of the response.'
|
136
|
+
)
|
107
137
|
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
return self
|
112
|
-
return {'httpHeaders': self.http_headers}
|
138
|
+
json_payload: Optional[Any] = Field(
|
139
|
+
default=None, description='The json payload of the response.'
|
140
|
+
)
|
113
141
|
|
114
142
|
|
115
143
|
class HttpResponse:
|
@@ -124,15 +152,15 @@ class HttpResponse:
|
|
124
152
|
self.headers = headers
|
125
153
|
self.response_stream = response_stream
|
126
154
|
self.byte_stream = byte_stream
|
127
|
-
self.segment_iterator = self.segments()
|
128
155
|
|
129
156
|
# Async iterator for async streaming.
|
130
157
|
def __aiter__(self):
|
158
|
+
self.segment_iterator = self.async_segments()
|
131
159
|
return self
|
132
160
|
|
133
161
|
async def __anext__(self):
|
134
162
|
try:
|
135
|
-
return
|
163
|
+
return await self.segment_iterator.__anext__()
|
136
164
|
except StopIteration:
|
137
165
|
raise StopAsyncIteration
|
138
166
|
|
@@ -159,6 +187,25 @@ class HttpResponse:
|
|
159
187
|
chunk = chunk[len(b'data: ') :]
|
160
188
|
yield json.loads(str(chunk, 'utf-8'))
|
161
189
|
|
190
|
+
async def async_segments(self) -> AsyncIterator[Any]:
|
191
|
+
if isinstance(self.response_stream, list):
|
192
|
+
# list of objects retrieved from replay or from non-streaming API.
|
193
|
+
for chunk in self.response_stream:
|
194
|
+
yield json.loads(chunk) if chunk else {}
|
195
|
+
elif self.response_stream is None:
|
196
|
+
async for c in []:
|
197
|
+
yield c
|
198
|
+
else:
|
199
|
+
# Iterator of objects retrieved from the API.
|
200
|
+
async for chunk in self.response_stream.aiter_lines():
|
201
|
+
# This is httpx.Response.
|
202
|
+
if chunk:
|
203
|
+
# In async streaming mode, the chunk of JSON is prefixed with "data:"
|
204
|
+
# which we must strip before parsing.
|
205
|
+
if chunk.startswith('data: '):
|
206
|
+
chunk = chunk[len('data: ') :]
|
207
|
+
yield json.loads(chunk)
|
208
|
+
|
162
209
|
def byte_segments(self):
|
163
210
|
if isinstance(self.byte_stream, list):
|
164
211
|
# list of objects retrieved from replay or from non-streaming API.
|
@@ -201,12 +248,14 @@ class ApiClient:
|
|
201
248
|
if (project or location) and api_key:
|
202
249
|
# API cannot consume both project/location and api_key.
|
203
250
|
raise ValueError(
|
204
|
-
'Project/location and API key are mutually exclusive in the client
|
251
|
+
'Project/location and API key are mutually exclusive in the client'
|
252
|
+
' initializer.'
|
205
253
|
)
|
206
254
|
elif credentials and api_key:
|
207
255
|
# API cannot consume both credentials and api_key.
|
208
256
|
raise ValueError(
|
209
|
-
'Credentials and API key are mutually exclusive in the client
|
257
|
+
'Credentials and API key are mutually exclusive in the client'
|
258
|
+
' initializer.'
|
210
259
|
)
|
211
260
|
|
212
261
|
# Validate http_options if a dict is provided.
|
@@ -215,7 +264,7 @@ class ApiClient:
|
|
215
264
|
HttpOptions.model_validate(http_options)
|
216
265
|
except ValidationError as e:
|
217
266
|
raise ValueError(f'Invalid http_options: {e}')
|
218
|
-
elif
|
267
|
+
elif isinstance(http_options, HttpOptions):
|
219
268
|
http_options = http_options.model_dump()
|
220
269
|
|
221
270
|
# Retrieve implicitly set values from the environment.
|
@@ -228,20 +277,24 @@ class ApiClient:
|
|
228
277
|
|
229
278
|
self._credentials = credentials
|
230
279
|
self._http_options = HttpOptionsDict()
|
280
|
+
# Initialize the lock. This lock will be used to protect access to the
|
281
|
+
# credentials. This is crucial for thread safety when multiple coroutines
|
282
|
+
# might be accessing the credentials at the same time.
|
283
|
+
self._auth_lock = asyncio.Lock()
|
231
284
|
|
232
285
|
# Handle when to use Vertex AI in express mode (api key).
|
233
286
|
# Explicit initializer arguments are already validated above.
|
234
287
|
if self.vertexai:
|
235
288
|
if credentials:
|
236
289
|
# Explicit credentials take precedence over implicit api_key.
|
237
|
-
|
290
|
+
logger.info(
|
238
291
|
'The user provided Google Cloud credentials will take precedence'
|
239
292
|
+ ' over the API key from the environment variable.'
|
240
293
|
)
|
241
294
|
self.api_key = None
|
242
295
|
elif (env_location or env_project) and api_key:
|
243
296
|
# Explicit api_key takes precedence over implicit project/location.
|
244
|
-
|
297
|
+
logger.info(
|
245
298
|
'The user provided Vertex AI API key will take precedence over the'
|
246
299
|
+ ' project/location from the environment variables.'
|
247
300
|
)
|
@@ -249,37 +302,41 @@ class ApiClient:
|
|
249
302
|
self.location = None
|
250
303
|
elif (project or location) and env_api_key:
|
251
304
|
# Explicit project/location takes precedence over implicit api_key.
|
252
|
-
|
305
|
+
logger.info(
|
253
306
|
'The user provided project/location will take precedence over the'
|
254
307
|
+ ' Vertex AI API key from the environment variable.'
|
255
308
|
)
|
256
309
|
self.api_key = None
|
257
310
|
elif (env_location or env_project) and env_api_key:
|
258
311
|
# Implicit project/location takes precedence over implicit api_key.
|
259
|
-
|
312
|
+
logger.info(
|
260
313
|
'The project/location from the environment variables will take'
|
261
314
|
+ ' precedence over the API key from the environment variables.'
|
262
315
|
)
|
263
316
|
self.api_key = None
|
264
317
|
if not self.project and not self.api_key:
|
265
|
-
self.project =
|
318
|
+
credentials, self.project = _load_auth(project=None)
|
319
|
+
if not self._credentials:
|
320
|
+
self._credentials = credentials
|
266
321
|
if not ((self.project and self.location) or self.api_key):
|
267
322
|
raise ValueError(
|
268
323
|
'Project and location or API key must be set when using the Vertex '
|
269
324
|
'AI API.'
|
270
325
|
)
|
271
326
|
if self.api_key or self.location == 'global':
|
272
|
-
self._http_options['base_url'] =
|
273
|
-
f'https://aiplatform.googleapis.com/'
|
274
|
-
)
|
327
|
+
self._http_options['base_url'] = f'https://aiplatform.googleapis.com/'
|
275
328
|
else:
|
276
329
|
self._http_options['base_url'] = (
|
277
330
|
f'https://{self.location}-aiplatform.googleapis.com/'
|
278
331
|
)
|
279
332
|
self._http_options['api_version'] = 'v1beta1'
|
280
|
-
else: #
|
333
|
+
else: # Implicit initialization or missing arguments.
|
281
334
|
if not self.api_key:
|
282
|
-
raise ValueError(
|
335
|
+
raise ValueError(
|
336
|
+
'Missing key inputs argument! To use the Google AI API,'
|
337
|
+
'provide (`api_key`) arguments. To use the Google Cloud API,'
|
338
|
+
' provide (`vertexai`, `project` & `location`) arguments.'
|
339
|
+
)
|
283
340
|
self._http_options['base_url'] = (
|
284
341
|
'https://generativelanguage.googleapis.com/'
|
285
342
|
)
|
@@ -298,6 +355,32 @@ class ApiClient:
|
|
298
355
|
url_parts = urlparse(self._http_options['base_url'])
|
299
356
|
return url_parts._replace(scheme='wss').geturl()
|
300
357
|
|
358
|
+
async def _async_access_token(self) -> str:
|
359
|
+
"""Retrieves the access token for the credentials."""
|
360
|
+
if not self._credentials:
|
361
|
+
async with self._auth_lock:
|
362
|
+
# This ensures that only one coroutine can execute the auth logic at a
|
363
|
+
# time for thread safety.
|
364
|
+
if not self._credentials:
|
365
|
+
# Double check that the credentials are not set before loading them.
|
366
|
+
self._credentials, project = await asyncio.to_thread(
|
367
|
+
_load_auth, project=self.project
|
368
|
+
)
|
369
|
+
if not self.project:
|
370
|
+
self.project = project
|
371
|
+
|
372
|
+
if self._credentials.expired or not self._credentials.token:
|
373
|
+
# Only refresh when it needs to. Default expiration is 3600 seconds.
|
374
|
+
async with self._auth_lock:
|
375
|
+
if self._credentials.expired or not self._credentials.token:
|
376
|
+
# Double check that the credentials expired before refreshing.
|
377
|
+
await asyncio.to_thread(_refresh_auth, self._credentials)
|
378
|
+
|
379
|
+
if not self._credentials.token:
|
380
|
+
raise RuntimeError('Could not resolve API token from the environment')
|
381
|
+
|
382
|
+
return self._credentials.token
|
383
|
+
|
301
384
|
def _build_request(
|
302
385
|
self,
|
303
386
|
http_method: str,
|
@@ -362,8 +445,10 @@ class ApiClient:
|
|
362
445
|
) -> HttpResponse:
|
363
446
|
if self.vertexai and not self.api_key:
|
364
447
|
if not self._credentials:
|
365
|
-
self._credentials, _ =
|
366
|
-
|
448
|
+
self._credentials, _ = _load_auth(project=self.project)
|
449
|
+
if self._credentials.quota_project_id:
|
450
|
+
http_request.headers['x-goog-user-project'] = (
|
451
|
+
self._credentials.quota_project_id
|
367
452
|
)
|
368
453
|
authed_session = AuthorizedSession(self._credentials)
|
369
454
|
authed_session.stream = stream
|
@@ -371,9 +456,7 @@ class ApiClient:
|
|
371
456
|
http_request.method.upper(),
|
372
457
|
http_request.url,
|
373
458
|
headers=http_request.headers,
|
374
|
-
data=json.dumps(http_request.data)
|
375
|
-
if http_request.data
|
376
|
-
else None,
|
459
|
+
data=json.dumps(http_request.data) if http_request.data else None,
|
377
460
|
timeout=http_request.timeout,
|
378
461
|
)
|
379
462
|
errors.APIError.raise_for_response(response)
|
@@ -413,21 +496,42 @@ class ApiClient:
|
|
413
496
|
self, http_request: HttpRequest, stream: bool = False
|
414
497
|
):
|
415
498
|
if self.vertexai:
|
416
|
-
|
417
|
-
|
418
|
-
|
499
|
+
http_request.headers['Authorization'] = (
|
500
|
+
f'Bearer {await self._async_access_token()}'
|
501
|
+
)
|
502
|
+
if self._credentials.quota_project_id:
|
503
|
+
http_request.headers['x-goog-user-project'] = (
|
504
|
+
self._credentials.quota_project_id
|
419
505
|
)
|
420
|
-
|
421
|
-
|
422
|
-
http_request,
|
423
|
-
|
506
|
+
if stream:
|
507
|
+
httpx_request = httpx.Request(
|
508
|
+
method=http_request.method,
|
509
|
+
url=http_request.url,
|
510
|
+
data=json.dumps(http_request.data),
|
511
|
+
headers=http_request.headers,
|
424
512
|
)
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
http_request,
|
513
|
+
aclient = httpx.AsyncClient()
|
514
|
+
response = await aclient.send(
|
515
|
+
httpx_request,
|
429
516
|
stream=stream,
|
430
517
|
)
|
518
|
+
errors.APIError.raise_for_response(response)
|
519
|
+
return HttpResponse(
|
520
|
+
response.headers, response if stream else [response.text]
|
521
|
+
)
|
522
|
+
else:
|
523
|
+
async with httpx.AsyncClient() as aclient:
|
524
|
+
response = await aclient.request(
|
525
|
+
method=http_request.method,
|
526
|
+
url=http_request.url,
|
527
|
+
headers=http_request.headers,
|
528
|
+
data=json.dumps(http_request.data) if http_request.data else None,
|
529
|
+
timeout=http_request.timeout,
|
530
|
+
)
|
531
|
+
errors.APIError.raise_for_response(response)
|
532
|
+
return HttpResponse(
|
533
|
+
response.headers, response if stream else [response.text]
|
534
|
+
)
|
431
535
|
|
432
536
|
def get_read_only_http_options(self) -> HttpOptionsDict:
|
433
537
|
copied = HttpOptionsDict()
|
@@ -449,9 +553,9 @@ class ApiClient:
|
|
449
553
|
response = self._request(http_request, stream=False)
|
450
554
|
json_response = response.json
|
451
555
|
if not json_response:
|
452
|
-
|
453
|
-
|
454
|
-
|
556
|
+
return BaseResponse(http_headers=response.headers).model_dump(
|
557
|
+
by_alias=True
|
558
|
+
)
|
455
559
|
return json_response
|
456
560
|
|
457
561
|
def request_streamed(
|
@@ -483,8 +587,7 @@ class ApiClient:
|
|
483
587
|
result = await self._async_request(http_request=http_request, stream=False)
|
484
588
|
json_response = result.json
|
485
589
|
if not json_response:
|
486
|
-
|
487
|
-
return base_response
|
590
|
+
return BaseResponse(http_headers=result.headers).model_dump(by_alias=True)
|
488
591
|
return json_response
|
489
592
|
|
490
593
|
async def async_request_streamed(
|
@@ -503,6 +606,7 @@ class ApiClient:
|
|
503
606
|
async def async_generator():
|
504
607
|
async for chunk in response:
|
505
608
|
yield chunk
|
609
|
+
|
506
610
|
return async_generator()
|
507
611
|
|
508
612
|
def upload_file(
|
google/genai/_api_module.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Utilities for the API Modules of the Google Gen AI SDK."""
|
17
17
|
|
18
|
+
from typing import Optional
|
18
19
|
from . import _api_client
|
19
20
|
|
20
21
|
|
@@ -22,3 +23,7 @@ class BaseModule:
|
|
22
23
|
|
23
24
|
def __init__(self, api_client_: _api_client.ApiClient):
|
24
25
|
self._api_client = api_client_
|
26
|
+
|
27
|
+
@property
|
28
|
+
def vertexai(self) -> Optional[bool]:
|
29
|
+
return self._api_client.vertexai
|
google/genai/_common.py
CHANGED
@@ -60,6 +60,12 @@ def set_value_by_path(data, keys, value):
|
|
60
60
|
for d in data[key_name]:
|
61
61
|
set_value_by_path(d, keys[i + 1 :], value)
|
62
62
|
return
|
63
|
+
elif key.endswith('[0]'):
|
64
|
+
key_name = key[:-3]
|
65
|
+
if key_name not in data:
|
66
|
+
data[key_name] = [{}]
|
67
|
+
set_value_by_path(data[key_name][0], keys[i + 1 :], value)
|
68
|
+
return
|
63
69
|
|
64
70
|
data = data.setdefault(key, {})
|
65
71
|
|
@@ -106,6 +112,12 @@ def get_value_by_path(data: object, keys: list[str]):
|
|
106
112
|
return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
|
107
113
|
else:
|
108
114
|
return None
|
115
|
+
elif key.endswith('[0]'):
|
116
|
+
key_name = key[:-3]
|
117
|
+
if key_name in data and data[key_name]:
|
118
|
+
return get_value_by_path(data[key_name][0], keys[i + 1 :])
|
119
|
+
else:
|
120
|
+
return None
|
109
121
|
else:
|
110
122
|
if key in data:
|
111
123
|
data = data[key]
|
google/genai/_extra_utils.py
CHANGED
@@ -34,6 +34,8 @@ else:
|
|
34
34
|
|
35
35
|
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10
|
36
36
|
|
37
|
+
logger = logging.getLogger('google_genai.models')
|
38
|
+
|
37
39
|
|
38
40
|
def format_destination(
|
39
41
|
src: str,
|
@@ -248,7 +250,7 @@ def should_disable_afc(
|
|
248
250
|
is not None
|
249
251
|
and int(config_model.automatic_function_calling.maximum_remote_calls) <= 0
|
250
252
|
):
|
251
|
-
|
253
|
+
logger.warning(
|
252
254
|
'max_remote_calls in automatic_function_calling_config'
|
253
255
|
f' {config_model.automatic_function_calling.maximum_remote_calls} is'
|
254
256
|
' less than or equal to 0. Disabling automatic function calling.'
|
@@ -268,9 +270,12 @@ def should_disable_afc(
|
|
268
270
|
config_model.automatic_function_calling.disable
|
269
271
|
and config_model.automatic_function_calling.maximum_remote_calls
|
270
272
|
is not None
|
273
|
+
# exclude the case where max_remote_calls is set to 10 by default.
|
274
|
+
and 'maximum_remote_calls'
|
275
|
+
in config_model.automatic_function_calling.model_fields_set
|
271
276
|
and int(config_model.automatic_function_calling.maximum_remote_calls) > 0
|
272
277
|
):
|
273
|
-
|
278
|
+
logger.warning(
|
274
279
|
'`automatic_function_calling.disable` is set to `True`. And'
|
275
280
|
' `automatic_function_calling.maximum_remote_calls` is a'
|
276
281
|
' positive number'
|
@@ -60,6 +60,10 @@ def _redact_request_headers(headers):
|
|
60
60
|
redacted_headers[header_name] = _redact_language_label(
|
61
61
|
_redact_version_numbers(header_value)
|
62
62
|
)
|
63
|
+
elif header_name.lower() == 'x-goog-user-project':
|
64
|
+
continue
|
65
|
+
elif header_name.lower() == 'authorization':
|
66
|
+
continue
|
63
67
|
else:
|
64
68
|
redacted_headers[header_name] = header_value
|
65
69
|
return redacted_headers
|
@@ -409,6 +413,34 @@ class ReplayApiClient(ApiClient):
|
|
409
413
|
else:
|
410
414
|
return self._build_response_from_replay(http_request)
|
411
415
|
|
416
|
+
async def _async_request(
|
417
|
+
self,
|
418
|
+
http_request: HttpRequest,
|
419
|
+
stream: bool = False,
|
420
|
+
) -> HttpResponse:
|
421
|
+
self._initialize_replay_session_if_not_loaded()
|
422
|
+
if self._should_call_api():
|
423
|
+
_debug_print('api mode request: %s' % http_request)
|
424
|
+
try:
|
425
|
+
result = await super()._async_request(http_request, stream)
|
426
|
+
except errors.APIError as e:
|
427
|
+
self._record_interaction(http_request, e)
|
428
|
+
raise e
|
429
|
+
if stream:
|
430
|
+
result_segments = []
|
431
|
+
async for segment in result.async_segments():
|
432
|
+
result_segments.append(json.dumps(segment))
|
433
|
+
result = HttpResponse(result.headers, result_segments)
|
434
|
+
self._record_interaction(http_request, result)
|
435
|
+
# Need to return a RecordedResponse that rebuilds the response
|
436
|
+
# segments since the stream has been consumed.
|
437
|
+
else:
|
438
|
+
self._record_interaction(http_request, result)
|
439
|
+
_debug_print('api mode result: %s' % result.json)
|
440
|
+
return result
|
441
|
+
else:
|
442
|
+
return self._build_response_from_replay(http_request)
|
443
|
+
|
412
444
|
def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int):
|
413
445
|
if isinstance(file_path, io.IOBase):
|
414
446
|
offset = file_path.tell()
|
@@ -453,4 +485,3 @@ class ReplayApiClient(ApiClient):
|
|
453
485
|
return result
|
454
486
|
else:
|
455
487
|
return self._build_response_from_replay(request)
|
456
|
-
|
google/genai/_transformers.py
CHANGED
@@ -20,11 +20,12 @@ from collections.abc import Iterable, Mapping
|
|
20
20
|
from enum import Enum, EnumMeta
|
21
21
|
import inspect
|
22
22
|
import io
|
23
|
+
import logging
|
23
24
|
import re
|
25
|
+
import sys
|
24
26
|
import time
|
25
27
|
import typing
|
26
28
|
from typing import Any, GenericAlias, Optional, Union
|
27
|
-
import sys
|
28
29
|
|
29
30
|
if typing.TYPE_CHECKING:
|
30
31
|
import PIL.Image
|
@@ -34,6 +35,8 @@ import pydantic
|
|
34
35
|
from . import _api_client
|
35
36
|
from . import types
|
36
37
|
|
38
|
+
logger = logging.getLogger('google_genai._transformers')
|
39
|
+
|
37
40
|
if sys.version_info >= (3, 10):
|
38
41
|
VersionedUnionType = typing.types.UnionType
|
39
42
|
_UNION_TYPES = (typing.Union, typing.types.UnionType)
|
@@ -183,8 +186,15 @@ def t_extract_models(
|
|
183
186
|
return response.get('tunedModels')
|
184
187
|
elif response.get('publisherModels') is not None:
|
185
188
|
return response.get('publisherModels')
|
189
|
+
elif (
|
190
|
+
response.get('httpHeaders') is not None
|
191
|
+
and response.get('jsonPayload') is None
|
192
|
+
):
|
193
|
+
return []
|
186
194
|
else:
|
187
|
-
|
195
|
+
logger.warning('Cannot determine the models type.')
|
196
|
+
logger.debug('Cannot determine the models type for response: %s', response)
|
197
|
+
return []
|
188
198
|
|
189
199
|
|
190
200
|
def t_caches_model(api_client: _api_client.ApiClient, model: str):
|
@@ -205,12 +215,17 @@ def t_caches_model(api_client: _api_client.ApiClient, model: str):
|
|
205
215
|
def pil_to_blob(img) -> types.Blob:
|
206
216
|
try:
|
207
217
|
import PIL.PngImagePlugin
|
218
|
+
|
208
219
|
PngImagePlugin = PIL.PngImagePlugin
|
209
220
|
except ImportError:
|
210
221
|
PngImagePlugin = None
|
211
222
|
|
212
223
|
bytesio = io.BytesIO()
|
213
|
-
if
|
224
|
+
if (
|
225
|
+
PngImagePlugin is not None
|
226
|
+
and isinstance(img, PngImagePlugin.PngImageFile)
|
227
|
+
or img.mode == 'RGBA'
|
228
|
+
):
|
214
229
|
img.save(bytesio, format='PNG')
|
215
230
|
mime_type = 'image/png'
|
216
231
|
else:
|
@@ -249,7 +264,7 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
|
|
249
264
|
def t_parts(
|
250
265
|
client: _api_client.ApiClient, parts: Union[list, PartType]
|
251
266
|
) -> list[types.Part]:
|
252
|
-
if parts
|
267
|
+
if not parts:
|
253
268
|
raise ValueError('content parts are required.')
|
254
269
|
if isinstance(parts, list):
|
255
270
|
return [t_part(client, part) for part in parts]
|
@@ -382,7 +397,10 @@ def handle_null_fields(schema: dict[str, Any]):
|
|
382
397
|
def process_schema(
|
383
398
|
schema: dict[str, Any],
|
384
399
|
client: Optional[_api_client.ApiClient] = None,
|
385
|
-
defs: Optional[dict[str, Any]]=None
|
400
|
+
defs: Optional[dict[str, Any]] = None,
|
401
|
+
*,
|
402
|
+
order_properties: bool = True,
|
403
|
+
):
|
386
404
|
"""Updates the schema and each sub-schema inplace to be API-compatible.
|
387
405
|
|
388
406
|
- Removes the `title` field from the schema if the client is not vertexai.
|
@@ -517,6 +535,16 @@ def process_schema(
|
|
517
535
|
ref = defs[ref_key.split('defs/')[-1]]
|
518
536
|
process_schema(ref, client, defs)
|
519
537
|
properties[name] = ref
|
538
|
+
if (
|
539
|
+
len(properties.items()) > 1
|
540
|
+
and order_properties
|
541
|
+
and all(
|
542
|
+
ordering_key not in schema
|
543
|
+
for ordering_key in ['property_ordering', 'propertyOrdering']
|
544
|
+
)
|
545
|
+
):
|
546
|
+
property_names = list(properties.keys())
|
547
|
+
schema['property_ordering'] = property_names
|
520
548
|
elif schema_type == 'ARRAY':
|
521
549
|
sub_schema = schema.get('items', None)
|
522
550
|
if sub_schema is None:
|
@@ -539,6 +567,7 @@ def _process_enum(
|
|
539
567
|
f'Enum member {member.name} value must be a string, got'
|
540
568
|
f' {type(member.value)}'
|
541
569
|
)
|
570
|
+
|
542
571
|
class Placeholder(pydantic.BaseModel):
|
543
572
|
placeholder: enum
|
544
573
|
|
@@ -554,7 +583,7 @@ def t_schema(
|
|
554
583
|
if not origin:
|
555
584
|
return None
|
556
585
|
if isinstance(origin, dict):
|
557
|
-
process_schema(origin, client)
|
586
|
+
process_schema(origin, client, order_properties=False)
|
558
587
|
return types.Schema.model_validate(origin)
|
559
588
|
if isinstance(origin, EnumMeta):
|
560
589
|
return _process_enum(origin, client)
|
@@ -563,15 +592,15 @@ def t_schema(
|
|
563
592
|
# response_schema value was coerced to an empty Schema instance because it did not adhere to the Schema field annotation
|
564
593
|
raise ValueError(f'Unsupported schema type.')
|
565
594
|
schema = origin.model_dump(exclude_unset=True)
|
566
|
-
process_schema(schema, client)
|
595
|
+
process_schema(schema, client, order_properties=False)
|
567
596
|
return types.Schema.model_validate(schema)
|
568
597
|
|
569
598
|
if (
|
570
599
|
# in Python 3.9 Generic alias list[int] counts as a type,
|
571
600
|
# and breaks issubclass because it's not a class.
|
572
|
-
not isinstance(origin, GenericAlias)
|
573
|
-
isinstance(origin, type)
|
574
|
-
issubclass(origin, pydantic.BaseModel)
|
601
|
+
not isinstance(origin, GenericAlias)
|
602
|
+
and isinstance(origin, type)
|
603
|
+
and issubclass(origin, pydantic.BaseModel)
|
575
604
|
):
|
576
605
|
schema = origin.model_json_schema()
|
577
606
|
process_schema(schema, client)
|
@@ -582,6 +611,7 @@ def t_schema(
|
|
582
611
|
or isinstance(origin, VersionedUnionType)
|
583
612
|
or typing.get_origin(origin) in _UNION_TYPES
|
584
613
|
):
|
614
|
+
|
585
615
|
class Placeholder(pydantic.BaseModel):
|
586
616
|
placeholder: origin
|
587
617
|
|
google/genai/batches.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
# Code generated by the Google Gen AI SDK generator DO NOT EDIT.
|
17
17
|
|
18
|
+
import logging
|
18
19
|
from typing import Optional, Union
|
19
20
|
from urllib.parse import urlencode
|
20
21
|
from . import _api_module
|
@@ -27,6 +28,8 @@ from ._common import get_value_by_path as getv
|
|
27
28
|
from ._common import set_value_by_path as setv
|
28
29
|
from .pagers import AsyncPager, Pager
|
29
30
|
|
31
|
+
logger = logging.getLogger('google_genai.batches')
|
32
|
+
|
30
33
|
|
31
34
|
def _BatchJobSource_to_mldev(
|
32
35
|
api_client: ApiClient,
|
@@ -1050,7 +1053,7 @@ class AsyncBatches(_api_module.BaseModule):
|
|
1050
1053
|
|
1051
1054
|
.. code-block:: python
|
1052
1055
|
|
1053
|
-
batch_job = client.batches.get(name='123456789')
|
1056
|
+
batch_job = await client.aio.batches.get(name='123456789')
|
1054
1057
|
print(f"Batch job: {batch_job.name}, state {batch_job.state}")
|
1055
1058
|
"""
|
1056
1059
|
|
@@ -1116,7 +1119,7 @@ class AsyncBatches(_api_module.BaseModule):
|
|
1116
1119
|
|
1117
1120
|
.. code-block:: python
|
1118
1121
|
|
1119
|
-
client.batches.cancel(name='123456789')
|
1122
|
+
await client.aio.batches.cancel(name='123456789')
|
1120
1123
|
"""
|
1121
1124
|
|
1122
1125
|
parameter_model = types._CancelBatchJobParameters(
|
@@ -1222,7 +1225,7 @@ class AsyncBatches(_api_module.BaseModule):
|
|
1222
1225
|
|
1223
1226
|
.. code-block:: python
|
1224
1227
|
|
1225
|
-
client.batches.delete(name='123456789')
|
1228
|
+
await client.aio.batches.delete(name='123456789')
|
1226
1229
|
"""
|
1227
1230
|
|
1228
1231
|
parameter_model = types._DeleteBatchJobParameters(
|