google-genai 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,19 +25,23 @@ import json
  import logging
  import os
  import sys
- from typing import Any, Optional, Tuple, TypedDict, Union
+ from typing import Any, AsyncIterator, Optional, Tuple, TypedDict, Union
  from urllib.parse import urlparse, urlunparse
-
  import google.auth
  import google.auth.credentials
+ from google.auth.credentials import Credentials
  from google.auth.transport.requests import AuthorizedSession
+ from google.auth.transport.requests import Request
+ import httpx
  from pydantic import BaseModel, ConfigDict, Field, ValidationError
  import requests
-
+ from . import _common
  from . import errors
  from . import version
  from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict

+ logger = logging.getLogger('google_genai._api_client')
+

  def _append_library_version_headers(headers: dict[str, str]) -> None:
    """Appends the telemetry header to the headers dict."""
@@ -94,6 +98,27 @@ def _join_url_path(base_url: str, path: str) -> str:
    return urlunparse(parsed_base._replace(path=base_path + '/' + path))


+ def _load_auth(*, project: Union[str, None]) -> tuple[Credentials, str]:
+   """Loads google auth credentials and project id."""
+   credentials, loaded_project_id = google.auth.default(
+       scopes=['https://www.googleapis.com/auth/cloud-platform'],
+   )
+
+   if not project:
+     project = loaded_project_id
+
+   if not project:
+     raise ValueError(
+         'Could not resolve project using application default credentials.'
+     )
+
+   return credentials, project
+
+
+ def _refresh_auth(credentials: Credentials) -> None:
+   credentials.refresh(Request())
+
+
  @dataclass
  class HttpRequest:
    headers: dict[str, str]
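
The `_load_auth` / `_refresh_auth` helpers added above wrap google-auth's application default credentials (ADC) flow: resolve credentials plus a project id once, then refresh the token on demand. A minimal standalone sketch of the same pattern (not the SDK's internal code), assuming ADC is configured in the environment:

```python
import google.auth
from google.auth.transport.requests import Request

# Resolve credentials and a project id from the environment (ADC), with the
# same cloud-platform scope the SDK requests.
credentials, project = google.auth.default(
    scopes=['https://www.googleapis.com/auth/cloud-platform']
)
if not project:
    raise ValueError('Could not resolve project using application default credentials.')

# Refresh lazily, as _refresh_auth does, so a bearer token is available.
if credentials.expired or not credentials.token:
    credentials.refresh(Request())
print(project, bool(credentials.token))
```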
@@ -105,15 +130,14 @@ class HttpRequest:

  # TODO(b/394358912): Update this class to use a SDKResponse class that can be
  # generated and used for all languages.
- @dataclass
- class BaseResponse:
-   http_headers: dict[str, str]
+ class BaseResponse(_common.BaseModel):
+   http_headers: dict[str, str] = Field(
+       default=None, description='The http headers of the response.'
+   )

-   @property
-   def dict(self) -> dict[str, Any]:
-     if isinstance(self, dict):
-       return self
-     return {'httpHeaders': self.http_headers}
+   json_payload: Optional[Any] = Field(
+       default=None, description='The json payload of the response.'
+   )


  class HttpResponse:
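
`BaseResponse` is now a pydantic model built on the SDK's `_common.BaseModel`, and later hunks dump it with `model_dump(by_alias=True)`, which yields camelCase keys (the `_transformers` hunk below checks for exactly `'httpHeaders'`). A standalone sketch of that aliasing behavior, assuming a camel-case alias generator equivalent to what `_common.BaseModel` configures:

```python
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field
from pydantic.alias_generators import to_camel


class BaseResponseSketch(BaseModel):
    """Stand-in for BaseResponse; not the SDK's class."""

    model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

    http_headers: Optional[dict[str, str]] = Field(default=None)
    json_payload: Optional[Any] = Field(default=None)


resp = BaseResponseSketch(http_headers={'content-type': 'application/json'})
print(resp.model_dump(by_alias=True))
# {'httpHeaders': {'content-type': 'application/json'}, 'jsonPayload': None}
```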
@@ -128,15 +152,15 @@ class HttpResponse:
      self.headers = headers
      self.response_stream = response_stream
      self.byte_stream = byte_stream
-     self.segment_iterator = self.segments()

    # Async iterator for async streaming.
    def __aiter__(self):
+     self.segment_iterator = self.async_segments()
      return self

    async def __anext__(self):
      try:
-       return next(self.segment_iterator)
+       return await self.segment_iterator.__anext__()
      except StopIteration:
        raise StopAsyncIteration

@@ -163,6 +187,25 @@
            chunk = chunk[len(b'data: ') :]
          yield json.loads(str(chunk, 'utf-8'))

+   async def async_segments(self) -> AsyncIterator[Any]:
+     if isinstance(self.response_stream, list):
+       # list of objects retrieved from replay or from non-streaming API.
+       for chunk in self.response_stream:
+         yield json.loads(chunk) if chunk else {}
+     elif self.response_stream is None:
+       async for c in []:
+         yield c
+     else:
+       # Iterator of objects retrieved from the API.
+       async for chunk in self.response_stream.aiter_lines():
+         # This is httpx.Response.
+         if chunk:
+           # In async streaming mode, the chunk of JSON is prefixed with "data:"
+           # which we must strip before parsing.
+           if chunk.startswith('data: '):
+             chunk = chunk[len('data: ') :]
+           yield json.loads(chunk)
+
    def byte_segments(self):
      if isinstance(self.byte_stream, list):
        # list of objects retrieved from replay or from non-streaming API.
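
The new `async_segments` generator is what `__aiter__`/`__anext__` now drive: one parsed JSON object per `data:` line when wrapping an httpx streaming response, or one object per buffered chunk when wrapping a list. A minimal consumer sketch against the private `_api_client` module (internal API, subject to change), using a buffered list so no network call is needed:

```python
import asyncio

from google.genai import _api_client


async def main() -> None:
    # Buffered chunks, as the replay client or a non-streaming call would produce.
    chunks = ['{"n": 1}', '{"n": 2}']
    response = _api_client.HttpResponse({'content-type': 'application/json'}, chunks)
    async for segment in response:  # __aiter__ -> async_segments()
        print(segment)              # {'n': 1}, then {'n': 2}


asyncio.run(main())
```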
@@ -234,20 +277,24 @@ class ApiClient:

      self._credentials = credentials
      self._http_options = HttpOptionsDict()
+     # Initialize the lock. This lock will be used to protect access to the
+     # credentials. This is crucial for thread safety when multiple coroutines
+     # might be accessing the credentials at the same time.
+     self._auth_lock = asyncio.Lock()

      # Handle when to use Vertex AI in express mode (api key).
      # Explicit initializer arguments are already validated above.
      if self.vertexai:
        if credentials:
          # Explicit credentials take precedence over implicit api_key.
-         logging.info(
+         logger.info(
              'The user provided Google Cloud credentials will take precedence'
              + ' over the API key from the environment variable.'
          )
          self.api_key = None
        elif (env_location or env_project) and api_key:
          # Explicit api_key takes precedence over implicit project/location.
-         logging.info(
+         logger.info(
              'The user provided Vertex AI API key will take precedence over the'
              + ' project/location from the environment variables.'
          )
@@ -255,20 +302,22 @@ class ApiClient:
          self.location = None
        elif (project or location) and env_api_key:
          # Explicit project/location takes precedence over implicit api_key.
-         logging.info(
+         logger.info(
              'The user provided project/location will take precedence over the'
              + ' Vertex AI API key from the environment variable.'
          )
          self.api_key = None
        elif (env_location or env_project) and env_api_key:
          # Implicit project/location takes precedence over implicit api_key.
-         logging.info(
+         logger.info(
              'The project/location from the environment variables will take'
              + ' precedence over the API key from the environment variables.'
          )
          self.api_key = None
        if not self.project and not self.api_key:
-         self.project = google.auth.default()[1]
+         credentials, self.project = _load_auth(project=None)
+         if not self._credentials:
+           self._credentials = credentials
        if not ((self.project and self.location) or self.api_key):
          raise ValueError(
              'Project and location or API key must be set when using the Vertex '
@@ -306,6 +355,32 @@ class ApiClient:
      url_parts = urlparse(self._http_options['base_url'])
      return url_parts._replace(scheme='wss').geturl()

+   async def _async_access_token(self) -> str:
+     """Retrieves the access token for the credentials."""
+     if not self._credentials:
+       async with self._auth_lock:
+         # This ensures that only one coroutine can execute the auth logic at a
+         # time for thread safety.
+         if not self._credentials:
+           # Double check that the credentials are not set before loading them.
+           self._credentials, project = await asyncio.to_thread(
+               _load_auth, project=self.project
+           )
+           if not self.project:
+             self.project = project
+
+     if self._credentials.expired or not self._credentials.token:
+       # Only refresh when it needs to. Default expiration is 3600 seconds.
+       async with self._auth_lock:
+         if self._credentials.expired or not self._credentials.token:
+           # Double check that the credentials expired before refreshing.
+           await asyncio.to_thread(_refresh_auth, self._credentials)
+
+     if not self._credentials.token:
+       raise RuntimeError('Could not resolve API token from the environment')
+
+     return self._credentials.token
+
    def _build_request(
        self,
        http_method: str,
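
`_async_access_token` uses double-checked locking around an `asyncio.Lock`, pushing the blocking google-auth calls onto a worker thread with `asyncio.to_thread` so concurrent coroutines trigger at most one load/refresh. A standalone sketch of that pattern (generic, not the SDK's class):

```python
import asyncio
import time


class TokenCache:
    """Double-checked async initialization, mirroring the shape of _async_access_token."""

    def __init__(self) -> None:
        self._token: str | None = None
        self._lock = asyncio.Lock()

    def _blocking_fetch(self) -> str:
        time.sleep(0.1)  # stands in for the blocking google-auth load/refresh
        return f'token-{time.time()}'

    async def get(self) -> str:
        if self._token is None:            # fast path, no lock
            async with self._lock:
                if self._token is None:    # re-check under the lock
                    self._token = await asyncio.to_thread(self._blocking_fetch)
        return self._token


async def main() -> None:
    cache = TokenCache()
    tokens = await asyncio.gather(*(cache.get() for _ in range(5)))
    assert len(set(tokens)) == 1  # fetched once, shared by all coroutines


asyncio.run(main())
```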
@@ -370,8 +445,10 @@ class ApiClient:
    ) -> HttpResponse:
      if self.vertexai and not self.api_key:
        if not self._credentials:
-         self._credentials, _ = google.auth.default(
-             scopes=['https://www.googleapis.com/auth/cloud-platform'],
+         self._credentials, _ = _load_auth(project=self.project)
+       if self._credentials.quota_project_id:
+         http_request.headers['x-goog-user-project'] = (
+             self._credentials.quota_project_id
          )
        authed_session = AuthorizedSession(self._credentials)
        authed_session.stream = stream
@@ -419,21 +496,42 @@ class ApiClient:
        self, http_request: HttpRequest, stream: bool = False
    ):
      if self.vertexai:
-       if not self._credentials:
-         self._credentials, _ = google.auth.default(
-             scopes=['https://www.googleapis.com/auth/cloud-platform'],
+       http_request.headers['Authorization'] = (
+           f'Bearer {await self._async_access_token()}'
+       )
+       if self._credentials.quota_project_id:
+         http_request.headers['x-goog-user-project'] = (
+             self._credentials.quota_project_id
          )
-       return await asyncio.to_thread(
-           self._request,
-           http_request,
-           stream=stream,
+     if stream:
+       httpx_request = httpx.Request(
+           method=http_request.method,
+           url=http_request.url,
+           data=json.dumps(http_request.data),
+           headers=http_request.headers,
        )
-     else:
-       return await asyncio.to_thread(
-           self._request,
-           http_request,
+       aclient = httpx.AsyncClient()
+       response = await aclient.send(
+           httpx_request,
            stream=stream,
        )
+       errors.APIError.raise_for_response(response)
+       return HttpResponse(
+           response.headers, response if stream else [response.text]
+       )
+     else:
+       async with httpx.AsyncClient() as aclient:
+         response = await aclient.request(
+             method=http_request.method,
+             url=http_request.url,
+             headers=http_request.headers,
+             data=json.dumps(http_request.data) if http_request.data else None,
+             timeout=http_request.timeout,
+         )
+         errors.APIError.raise_for_response(response)
+         return HttpResponse(
+             response.headers, response if stream else [response.text]
+         )

    def get_read_only_http_options(self) -> HttpOptionsDict:
      copied = HttpOptionsDict()
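
The async path now talks to httpx directly instead of running the synchronous requests-based `_request` in a thread: streamed calls go through `client.send(request, stream=True)` and are later parsed line by line by `async_segments`, while non-streamed calls use a one-shot `client.request(...)`. A minimal sketch of the same two httpx call shapes, with a placeholder URL and payload (not a real Google endpoint):

```python
import asyncio
import json

import httpx


async def main() -> None:
    url = 'https://example.com/v1/echo'  # placeholder endpoint
    headers = {'Content-Type': 'application/json'}
    body = json.dumps({'hello': 'world'})

    # Non-streaming: one-shot request, body fully buffered.
    async with httpx.AsyncClient() as client:
        response = await client.request('POST', url, headers=headers, content=body)
        print(response.status_code)

    # Streaming: send() with stream=True, then iterate the body line by line.
    client = httpx.AsyncClient()
    request = httpx.Request('POST', url, headers=headers, content=body)
    response = await client.send(request, stream=True)
    async for line in response.aiter_lines():
        print(line)
    await response.aclose()
    await client.aclose()


asyncio.run(main())
```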
@@ -455,9 +553,9 @@ class ApiClient:
      response = self._request(http_request, stream=False)
      json_response = response.json
      if not json_response:
-       base_response = BaseResponse(response.headers).dict
-       return base_response
-
+       return BaseResponse(http_headers=response.headers).model_dump(
+           by_alias=True
+       )
      return json_response

    def request_streamed(
@@ -489,8 +587,7 @@ class ApiClient:
      result = await self._async_request(http_request=http_request, stream=False)
      json_response = result.json
      if not json_response:
-       base_response = BaseResponse(result.headers).dict
-       return base_response
+       return BaseResponse(http_headers=result.headers).model_dump(by_alias=True)
      return json_response

    async def async_request_streamed(
@@ -15,6 +15,7 @@

  """Utilities for the API Modules of the Google Gen AI SDK."""

+ from typing import Optional
  from . import _api_client


@@ -22,3 +23,7 @@ class BaseModule:

    def __init__(self, api_client_: _api_client.ApiClient):
      self._api_client = api_client_
+
+   @property
+   def vertexai(self) -> Optional[bool]:
+     return self._api_client.vertexai
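
Each API module (models, caches, batches, ...) subclasses `BaseModule`, so the new `vertexai` property lets callers check which backend a module instance is routed to. A minimal sketch, assuming a hypothetical project id:

```python
from google import genai

client = genai.Client(vertexai=True, project='my-project', location='us-central1')
print(client.vertexai)         # True
print(client.models.vertexai)  # True, via the new BaseModule.vertexai property
```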
google/genai/_common.py CHANGED
@@ -60,6 +60,12 @@ def set_value_by_path(data, keys, value):
          for d in data[key_name]:
            set_value_by_path(d, keys[i + 1 :], value)
        return
+     elif key.endswith('[0]'):
+       key_name = key[:-3]
+       if key_name not in data:
+         data[key_name] = [{}]
+       set_value_by_path(data[key_name][0], keys[i + 1 :], value)
+       return

      data = data.setdefault(key, {})

@@ -106,6 +112,12 @@ def get_value_by_path(data: object, keys: list[str]):
        return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
      else:
        return None
+     elif key.endswith('[0]'):
+       key_name = key[:-3]
+       if key_name in data and data[key_name]:
+         return get_value_by_path(data[key_name][0], keys[i + 1 :])
+       else:
+         return None
      else:
        if key in data:
          data = data[key]
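
`set_value_by_path` and `get_value_by_path` now understand a `name[0]` path segment that addresses the first element of a list, creating a one-element `[{}]` on write. A short sketch of the behavior these hunks add, using the SDK's private helpers with made-up key names:

```python
from google.genai._common import get_value_by_path, set_value_by_path

data: dict = {}
# 'requests[0]' creates a single-element list and descends into its first dict.
set_value_by_path(data, ['requests[0]', 'model'], 'gemini-2.0-flash')
print(data)  # {'requests': [{'model': 'gemini-2.0-flash'}]}

print(get_value_by_path(data, ['requests[0]', 'model']))  # 'gemini-2.0-flash'
print(get_value_by_path(data, ['missing[0]', 'model']))   # None
```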
@@ -34,6 +34,8 @@ else:

  _DEFAULT_MAX_REMOTE_CALLS_AFC = 10

+ logger = logging.getLogger('google_genai.models')
+

  def format_destination(
      src: str,
@@ -248,7 +250,7 @@ def should_disable_afc(
      is not None
      and int(config_model.automatic_function_calling.maximum_remote_calls) <= 0
  ):
-     logging.warning(
+     logger.warning(
        'max_remote_calls in automatic_function_calling_config'
        f' {config_model.automatic_function_calling.maximum_remote_calls} is'
        ' less than or equal to 0. Disabling automatic function calling.'
@@ -268,9 +270,12 @@
      config_model.automatic_function_calling.disable
      and config_model.automatic_function_calling.maximum_remote_calls
      is not None
+     # exclude the case where max_remote_calls is set to 10 by default.
+     and 'maximum_remote_calls'
+     in config_model.automatic_function_calling.model_fields_set
      and int(config_model.automatic_function_calling.maximum_remote_calls) > 0
  ):
-     logging.warning(
+     logger.warning(
        '`automatic_function_calling.disable` is set to `True`. And'
        ' `automatic_function_calling.maximum_remote_calls` is a'
        ' positive number'
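
The extra `model_fields_set` check means the warning now fires only when the caller explicitly set `maximum_remote_calls` alongside `disable=True`; the implicit default of 10 no longer triggers it. A small sketch of the pydantic mechanism the check relies on, assuming the public `types.AutomaticFunctionCallingConfig` model:

```python
from google.genai import types

explicit = types.AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=5)
implicit = types.AutomaticFunctionCallingConfig(disable=True)

# model_fields_set only records fields passed by the caller, not defaults.
print('maximum_remote_calls' in explicit.model_fields_set)  # True  -> warning fires
print('maximum_remote_calls' in implicit.model_fields_set)  # False -> no warning
```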
@@ -60,6 +60,10 @@ def _redact_request_headers(headers):
        redacted_headers[header_name] = _redact_language_label(
            _redact_version_numbers(header_value)
        )
+     elif header_name.lower() == 'x-goog-user-project':
+       continue
+     elif header_name.lower() == 'authorization':
+       continue
      else:
        redacted_headers[header_name] = header_value
    return redacted_headers
@@ -409,6 +413,34 @@ class ReplayApiClient(ApiClient):
      else:
        return self._build_response_from_replay(http_request)

+   async def _async_request(
+       self,
+       http_request: HttpRequest,
+       stream: bool = False,
+   ) -> HttpResponse:
+     self._initialize_replay_session_if_not_loaded()
+     if self._should_call_api():
+       _debug_print('api mode request: %s' % http_request)
+       try:
+         result = await super()._async_request(http_request, stream)
+       except errors.APIError as e:
+         self._record_interaction(http_request, e)
+         raise e
+       if stream:
+         result_segments = []
+         async for segment in result.async_segments():
+           result_segments.append(json.dumps(segment))
+         result = HttpResponse(result.headers, result_segments)
+         self._record_interaction(http_request, result)
+         # Need to return a RecordedResponse that rebuilds the response
+         # segments since the stream has been consumed.
+       else:
+         self._record_interaction(http_request, result)
+         _debug_print('api mode result: %s' % result.json)
+       return result
+     else:
+       return self._build_response_from_replay(http_request)
+
    def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int):
      if isinstance(file_path, io.IOBase):
        offset = file_path.tell()
@@ -453,4 +485,3 @@ class ReplayApiClient(ApiClient):
        return result
      else:
        return self._build_response_from_replay(request)
-
@@ -20,6 +20,7 @@ from collections.abc import Iterable, Mapping
  from enum import Enum, EnumMeta
  import inspect
  import io
+ import logging
  import re
  import sys
  import time
@@ -34,6 +35,8 @@ import pydantic
  from . import _api_client
  from . import types

+ logger = logging.getLogger('google_genai._transformers')
+
  if sys.version_info >= (3, 10):
    VersionedUnionType = typing.types.UnionType
    _UNION_TYPES = (typing.Union, typing.types.UnionType)
@@ -183,8 +186,15 @@ def t_extract_models(
    return response.get('tunedModels')
  elif response.get('publisherModels') is not None:
    return response.get('publisherModels')
+ elif (
+     response.get('httpHeaders') is not None
+     and response.get('jsonPayload') is None
+ ):
+   return []
  else:
-   raise ValueError('Cannot determine the models type.')
+   logger.warning('Cannot determine the models type.')
+   logger.debug('Cannot determine the models type for response: %s', response)
+   return []


  def t_caches_model(api_client: _api_client.ApiClient, model: str):
@@ -254,7 +264,7 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
  def t_parts(
      client: _api_client.ApiClient, parts: Union[list, PartType]
  ) -> list[types.Part]:
-   if parts is None:
+   if not parts:
      raise ValueError('content parts are required.')
    if isinstance(parts, list):
      return [t_part(client, part) for part in parts]
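
`t_parts` now rejects empty containers, not just `None`: `if parts is None` became `if not parts`, so an empty list (or empty string) raises the same `ValueError`. The changed guard in isolation, as a plain sketch rather than an import of the private transformer:

```python
def require_parts(parts):
    # New behavior: empty containers are rejected, not just None.
    if not parts:
        raise ValueError('content parts are required.')
    return parts


require_parts(['hello'])  # ok
try:
    require_parts([])     # previously slipped past the None-only check; now raises
except ValueError as e:
    print(e)
```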
google/genai/batches.py CHANGED
@@ -15,6 +15,7 @@

  # Code generated by the Google Gen AI SDK generator DO NOT EDIT.

+ import logging
  from typing import Optional, Union
  from urllib.parse import urlencode
  from . import _api_module
@@ -27,6 +28,8 @@ from ._common import get_value_by_path as getv
  from ._common import set_value_by_path as setv
  from .pagers import AsyncPager, Pager

+ logger = logging.getLogger('google_genai.batches')
+

  def _BatchJobSource_to_mldev(
      api_client: ApiClient,
@@ -1050,7 +1053,7 @@ class AsyncBatches(_api_module.BaseModule):

      .. code-block:: python

-       batch_job = client.batches.get(name='123456789')
+       batch_job = await client.aio.batches.get(name='123456789')
        print(f"Batch job: {batch_job.name}, state {batch_job.state}")
      """

@@ -1116,7 +1119,7 @@ class AsyncBatches(_api_module.BaseModule):

      .. code-block:: python

-       client.batches.cancel(name='123456789')
+       await client.aio.batches.cancel(name='123456789')
      """

      parameter_model = types._CancelBatchJobParameters(
@@ -1222,7 +1225,7 @@ class AsyncBatches(_api_module.BaseModule):

      .. code-block:: python

-       client.batches.delete(name='123456789')
+       await client.aio.batches.delete(name='123456789')
      """

      parameter_model = types._DeleteBatchJobParameters(
google/genai/caches.py CHANGED
@@ -15,6 +15,7 @@

  # Code generated by the Google Gen AI SDK generator DO NOT EDIT.

+ import logging
  from typing import Optional, Union
  from urllib.parse import urlencode
  from . import _api_module
@@ -26,6 +27,8 @@ from ._common import get_value_by_path as getv
  from ._common import set_value_by_path as setv
  from .pagers import AsyncPager, Pager

+ logger = logging.getLogger('google_genai.caches')
+

  def _Part_to_mldev(
      api_client: ApiClient,
@@ -1183,7 +1186,7 @@ class Caches(_api_module.BaseModule):
      .. code-block:: python

        contents = ... // Initialize the content to cache.
-       response = await client.aio.caches.create(
+       response = client.caches.create(
            model= ... // The publisher model id
            contents=contents,
            config={
@@ -1251,8 +1254,7 @@

      .. code-block:: python

-       await client.aio.caches.get(name= ... ) // The server-generated resource
-       name.
+       client.caches.get(name= ... ) // The server-generated resource name.
      """

      parameter_model = types._GetCachedContentParameters(
@@ -1314,8 +1316,7 @@

      .. code-block:: python

-       await client.aio.caches.delete(name= ... ) // The server-generated
-       resource name.
+       client.caches.delete(name= ... ) // The server-generated resource name.
      """

      parameter_model = types._DeleteCachedContentParameters(
@@ -1377,7 +1378,7 @@

      .. code-block:: python

-       response = await client.aio.caches.update(
+       response = client.caches.update(
            name= ... // The server-generated resource name.
            config={
                'ttl': '7600s',
@@ -1439,8 +1440,8 @@

      .. code-block:: python

-       cached_contents = await client.aio.caches.list(config={'page_size': 2})
-       async for cached_content in cached_contents:
+       cached_contents = client.caches.list(config={'page_size': 2})
+       for cached_content in cached_contents:
          print(cached_content)
      """

google/genai/client.py CHANGED
@@ -27,6 +27,7 @@ from .chats import AsyncChats, Chats
  from .files import AsyncFiles, Files
  from .live import AsyncLive
  from .models import AsyncModels, Models
+ from .operations import AsyncOperations, Operations
  from .tunings import AsyncTunings, Tunings


@@ -42,6 +43,7 @@ class AsyncClient:
      self._batches = AsyncBatches(self._api_client)
      self._files = AsyncFiles(self._api_client)
      self._live = AsyncLive(self._api_client)
+     self._operations = AsyncOperations(self._api_client)

    @property
    def models(self) -> AsyncModels:
@@ -71,6 +73,9 @@ class AsyncClient:
    def live(self) -> AsyncLive:
      return self._live

+   @property
+   def operations(self) -> AsyncOperations:
+     return self._operations

  class DebugConfig(pydantic.BaseModel):
    """Configuration options that change client network behavior when testing."""
@@ -100,9 +105,9 @@ class Client:
    `api_key="your-api-key"` or by defining `GOOGLE_API_KEY="your-api-key"` as an
    environment variable

-   Vertex AI API users can provide inputs argument as `vertexai=false,
+   Vertex AI API users can provide inputs argument as `vertexai=True,
    project="your-project-id", location="us-central1"` or by defining
-   `GOOGLE_GENAI_USE_VERTEXAI=false`, `GOOGLE_CLOUD_PROJECT` and
+   `GOOGLE_GENAI_USE_VERTEXAI=true`, `GOOGLE_CLOUD_PROJECT` and
    `GOOGLE_CLOUD_LOCATION` environment variables.

    Attributes:
@@ -205,6 +210,7 @@ class Client:
      self._caches = Caches(self._api_client)
      self._batches = Batches(self._api_client)
      self._files = Files(self._api_client)
+     self._operations = Operations(self._api_client)

    @staticmethod
    def _get_api_client(
@@ -270,7 +276,11 @@ class Client:
    def files(self) -> Files:
      return self._files

+   @property
+   def operations(self) -> Operations:
+     return self._operations
+
    @property
    def vertexai(self) -> bool:
      """Returns whether the client is using the Vertex AI API."""
-     return self._api_client.vertexai or False
+     return self._api_client.vertexai or False
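
Both clients now expose an operations module. A minimal access sketch, assuming an API key in the environment; the methods on `Operations`/`AsyncOperations` live in google/genai/operations.py, which is not shown in this diff:

```python
import os

from google import genai

client = genai.Client(api_key=os.environ['GOOGLE_API_KEY'])
print(type(client.operations).__name__)      # Operations
print(type(client.aio.operations).__name__)  # AsyncOperations
```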
google/genai/errors.py CHANGED
@@ -16,7 +16,8 @@
  """Error classes for the GenAI SDK."""

  from typing import Any, Optional, TYPE_CHECKING, Union
-
+ import httpx
+ import json
  import requests


@@ -34,7 +35,9 @@ class APIError(Exception):
    response: Optional[Any] = None

    def __init__(
-       self, code: int, response: Union[requests.Response, 'ReplayResponse']
+       self,
+       code: int,
+       response: Union[requests.Response, 'ReplayResponse', httpx.Response],
    ):
      self.response = response

@@ -48,6 +51,18 @@
            'message': response.text,
            'status': response.reason,
        }
+     elif isinstance(response, httpx.Response):
+       try:
+         response_json = response.json()
+       except (json.decoder.JSONDecodeError, httpx.ResponseNotRead):
+         try:
+           message = response.text
+         except httpx.ResponseNotRead:
+           message = None
+         response_json = {
+             'message': message,
+             'status': response.reason_phrase,
+         }
      else:
        response_json = response.body_segments[0].get('error', {})
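
`APIError` can now be constructed from an `httpx.Response`, and `raise_for_response` accepts one as well. A small sketch that builds an httpx response in memory (no network) and lets the SDK translate it; the error payload shown is a made-up example:

```python
import httpx

from google.genai import errors

response = httpx.Response(
    status_code=404,
    json={'error': {'code': 404, 'message': 'Requested entity was not found.',
                    'status': 'NOT_FOUND'}},
)
try:
    errors.APIError.raise_for_response(response)
except errors.APIError as e:
    # 404 maps to a client-side APIError subclass carrying the status code.
    print(type(e).__name__, e.code)
```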
 
@@ -89,7 +104,7 @@

    @classmethod
    def raise_for_response(
-       cls, response: Union[requests.Response, 'ReplayResponse']
+       cls, response: Union[requests.Response, 'ReplayResponse', httpx.Response]
    ):
      """Raises an error with detailed error message if the response has an error status."""
      if response.status_code == 200: