google-genai 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,6 +21,7 @@ import copy
21
21
  from dataclasses import dataclass
22
22
  import datetime
23
23
  import json
24
+ import logging
24
25
  import os
25
26
  import sys
26
27
  from typing import Any, Optional, Tuple, TypedDict, Union
@@ -60,6 +61,10 @@ class HttpOptions(BaseModel):
60
61
  default=None,
61
62
  description="""Timeout for the request in seconds.""",
62
63
  )
64
+ skip_project_and_location_in_path: bool = Field(
65
+ default=False,
66
+ description="""If set to True, the project and location will not be appended to the path.""",
67
+ )
63
68
 
64
69
 
65
70
  class HttpOptionsDict(TypedDict):
@@ -69,13 +74,14 @@ class HttpOptionsDict(TypedDict):
69
74
  """The base URL for the AI platform service endpoint."""
70
75
  api_version: Optional[str] = None
71
76
  """Specifies the version of the API to use."""
72
- headers: Optional[dict[str, Union[str, list[str]]]] = None
77
+ headers: Optional[dict[str, str]] = None
73
78
  """Additional HTTP headers to be sent with the request."""
74
79
  response_payload: Optional[dict] = None
75
80
  """If set, the response payload will be returned int the supplied dict."""
76
81
  timeout: Optional[Union[float, Tuple[float, float]]] = None
77
82
  """Timeout for the request in seconds."""
78
-
83
+ skip_project_and_location_in_path: bool = False
84
+ """If set to True, the project and location will not be appended to the path."""
79
85
 
80
86
  HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
81
87
 
@@ -133,7 +139,7 @@ def _join_url_path(base_url: str, path: str) -> str:
133
139
 
134
140
  @dataclass
135
141
  class HttpRequest:
136
- headers: dict[str, Union[str, list[str]]]
142
+ headers: dict[str, str]
137
143
  url: str
138
144
  method: str
139
145
  data: Union[dict[str, object], bytes]
@@ -195,9 +201,15 @@ class ApiClient:
195
201
 
196
202
  # Validate explicitly set intializer values.
197
203
  if (project or location) and api_key:
204
+ # API cannot consume both project/location and api_key.
198
205
  raise ValueError(
199
206
  'Project/location and API key are mutually exclusive in the client initializer.'
200
207
  )
208
+ elif credentials and api_key:
209
+ # API cannot consume both credentials and api_key.
210
+ raise ValueError(
211
+ 'Credentials and API key are mutually exclusive in the client initializer.'
212
+ )
201
213
 
202
214
  # Validate http_options if a dict is provided.
203
215
  if isinstance(http_options, dict):
@@ -208,26 +220,65 @@ class ApiClient:
208
220
  elif(isinstance(http_options, HttpOptions)):
209
221
  http_options = http_options.model_dump()
210
222
 
211
- self.api_key: Optional[str] = None
212
- self.project = project or os.environ.get('GOOGLE_CLOUD_PROJECT', None)
213
- self.location = location or os.environ.get('GOOGLE_CLOUD_LOCATION', None)
223
+ # Retrieve implicitly set values from the environment.
224
+ env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
225
+ env_location = os.environ.get('GOOGLE_CLOUD_LOCATION', None)
226
+ env_api_key = os.environ.get('GOOGLE_API_KEY', None)
227
+ self.project = project or env_project
228
+ self.location = location or env_location
229
+ self.api_key = api_key or env_api_key
230
+
214
231
  self._credentials = credentials
215
232
  self._http_options = HttpOptionsDict()
216
233
 
234
+ # Handle when to use Vertex AI in express mode (api key).
235
+ # Explicit initializer arguments are already validated above.
217
236
  if self.vertexai:
218
- if not self.project:
237
+ if credentials:
238
+ # Explicit credentials take precedence over implicit api_key.
239
+ logging.info(
240
+ 'The user provided Google Cloud credentials will take precedence'
241
+ + ' over the API key from the environment variable.'
242
+ )
243
+ self.api_key = None
244
+ elif (env_location or env_project) and api_key:
245
+ # Explicit api_key takes precedence over implicit project/location.
246
+ logging.info(
247
+ 'The user provided Vertex AI API key will take precedence over the'
248
+ + ' project/location from the environment variables.'
249
+ )
250
+ self.project = None
251
+ self.location = None
252
+ elif (project or location) and env_api_key:
253
+ # Explicit project/location takes precedence over implicit api_key.
254
+ logging.info(
255
+ 'The user provided project/location will take precedence over the'
256
+ + ' Vertex AI API key from the environment variable.'
257
+ )
258
+ self.api_key = None
259
+ elif (env_location or env_project) and env_api_key:
260
+ # Implicit project/location takes precedence over implicit api_key.
261
+ logging.info(
262
+ 'The project/location from the environment variables will take'
263
+ + ' precedence over the API key from the environment variables.'
264
+ )
265
+ self.api_key = None
266
+ if not self.project and not self.api_key:
219
267
  self.project = google.auth.default()[1]
220
- # Will change this to support EasyGCP in the future.
221
- if not self.project or not self.location:
268
+ if not (self.project or self.location) and not self.api_key:
222
269
  raise ValueError(
223
- 'Project and location must be set when using the Vertex AI API.'
270
+ 'Project/location or API key must be set when using the Vertex AI API.'
271
+ )
272
+ if self.api_key:
273
+ self._http_options['base_url'] = (
274
+ f'https://aiplatform.googleapis.com/'
275
+ )
276
+ else:
277
+ self._http_options['base_url'] = (
278
+ f'https://{self.location}-aiplatform.googleapis.com/'
224
279
  )
225
- self._http_options['base_url'] = (
226
- f'https://{self.location}-aiplatform.googleapis.com/'
227
- )
228
280
  self._http_options['api_version'] = 'v1beta1'
229
281
  else: # ML Dev API
230
- self.api_key = api_key or os.environ.get('GOOGLE_API_KEY', None)
231
282
  if not self.api_key:
232
283
  raise ValueError('API key must be set when using the Google AI API.')
233
284
  self._http_options['base_url'] = (
@@ -236,7 +287,7 @@ class ApiClient:
236
287
  self._http_options['api_version'] = 'v1beta'
237
288
  # Default options for both clients.
238
289
  self._http_options['headers'] = {'Content-Type': 'application/json'}
239
- if self.api_key:
290
+ if self.api_key and not self.vertexai:
240
291
  self._http_options['headers']['x-goog-api-key'] = self.api_key
241
292
  # Update the http options with the user provided http options.
242
293
  if http_options:
@@ -266,8 +317,18 @@ class ApiClient:
266
317
  )
267
318
  else:
268
319
  patched_http_options = self._http_options
269
- if self.vertexai and not path.startswith('projects/'):
320
+ skip_project_and_location_in_path_val = patched_http_options.get(
321
+ 'skip_project_and_location_in_path', False
322
+ )
323
+ if (
324
+ self.vertexai
325
+ and not path.startswith('projects/')
326
+ and not skip_project_and_location_in_path_val
327
+ and not self.api_key
328
+ ):
270
329
  path = f'projects/{self.project}/locations/{self.location}/' + path
330
+ elif self.vertexai and self.api_key:
331
+ path = f'{path}?key={self.api_key}'
271
332
  url = _join_url_path(
272
333
  patched_http_options['base_url'],
273
334
  patched_http_options['api_version'] + '/' + path,
@@ -285,7 +346,7 @@ class ApiClient:
285
346
  http_request: HttpRequest,
286
347
  stream: bool = False,
287
348
  ) -> HttpResponse:
288
- if self.vertexai:
349
+ if self.vertexai and not self.api_key:
289
350
  if not self._credentials:
290
351
  self._credentials, _ = google.auth.default(
291
352
  scopes=["https://www.googleapis.com/auth/cloud-platform"],
@@ -513,13 +574,12 @@ class ApiClient:
513
574
  pass
514
575
 
515
576
 
577
+ # TODO(b/389693448): Cleanup datetime hacks.
516
578
  class RequestJsonEncoder(json.JSONEncoder):
517
579
  """Encode bytes as strings without modify its content."""
518
580
 
519
581
  def default(self, o):
520
- if isinstance(o, bytes):
521
- return o.decode()
522
- elif isinstance(o, datetime.datetime):
582
+ if isinstance(o, datetime.datetime):
523
583
  # This Zulu time format is used by the Vertex AI API and the test recorder
524
584
  # Using strftime works well, but we want to align with the replay encoder.
525
585
  # o.astimezone(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%fZ')
google/genai/_common.py CHANGED
@@ -17,7 +17,6 @@
17
17
 
18
18
  import base64
19
19
  import datetime
20
- import json
21
20
  import typing
22
21
  from typing import Union
23
22
  import uuid
@@ -116,7 +115,7 @@ def get_value_by_path(data: object, keys: list[str]):
116
115
  class BaseModule:
117
116
 
118
117
  def __init__(self, api_client_: _api_client.ApiClient):
119
- self.api_client = api_client_
118
+ self._api_client = api_client_
120
119
 
121
120
 
122
121
  def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
@@ -190,6 +189,8 @@ class BaseModel(pydantic.BaseModel):
190
189
  extra='forbid',
191
190
  # This allows us to use arbitrary types in the model. E.g. PIL.Image.
192
191
  arbitrary_types_allowed=True,
192
+ ser_json_bytes='base64',
193
+ val_json_bytes='base64',
193
194
  )
194
195
 
195
196
  @classmethod
@@ -201,7 +202,10 @@ class BaseModel(pydantic.BaseModel):
201
202
  # We will provide another mechanism to allow users to access these fields.
202
203
  _remove_extra_fields(cls, response)
203
204
  validated_response = cls.model_validate(response)
204
- return apply_base64_decoding_for_model(validated_response)
205
+ return validated_response
206
+
207
+ def to_json_dict(self) -> dict[str, object]:
208
+ return self.model_dump(exclude_none=True, mode='json')
205
209
 
206
210
 
207
211
  def timestamped_unique_name() -> str:
@@ -217,40 +221,21 @@ def timestamped_unique_name() -> str:
217
221
 
218
222
  def apply_base64_encoding(data: dict[str, object]) -> dict[str, object]:
219
223
  """Applies base64 encoding to bytes values in the given data."""
220
- return process_bytes_fields(data, encode=True)
221
-
222
-
223
- def apply_base64_decoding(data: dict[str, object]) -> dict[str, object]:
224
- """Applies base64 decoding to bytes values in the given data."""
225
- return process_bytes_fields(data, encode=False)
226
-
227
-
228
- def apply_base64_decoding_for_model(data: BaseModel) -> BaseModel:
229
- d = data.model_dump(exclude_none=True)
230
- d = apply_base64_decoding(d)
231
- return data.model_validate(d)
232
-
233
-
234
- def process_bytes_fields(data: dict[str, object], encode=True) -> dict[str, object]:
235
224
  processed_data = {}
236
225
  if not isinstance(data, dict):
237
226
  return data
238
227
  for key, value in data.items():
239
228
  if isinstance(value, bytes):
240
- if encode:
241
- processed_data[key] = base64.b64encode(value)
242
- else:
243
- processed_data[key] = base64.b64decode(value)
229
+ processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
244
230
  elif isinstance(value, dict):
245
- processed_data[key] = process_bytes_fields(value, encode)
231
+ processed_data[key] = apply_base64_encoding(value)
246
232
  elif isinstance(value, list):
247
- if encode and all(isinstance(v, bytes) for v in value):
248
- processed_data[key] = [base64.b64encode(v) for v in value]
249
- elif all(isinstance(v, bytes) for v in value):
250
- processed_data[key] = [base64.b64decode(v) for v in value]
233
+ if all(isinstance(v, bytes) for v in value):
234
+ processed_data[key] = [
235
+ base64.urlsafe_b64encode(v).decode('ascii') for v in value
236
+ ]
251
237
  else:
252
- processed_data[key] = [process_bytes_fields(v, encode) for v in value]
238
+ processed_data[key] = [apply_base64_encoding(v) for v in value]
253
239
  else:
254
240
  processed_data[key] = value
255
241
  return processed_data
256
-
@@ -25,7 +25,6 @@ import datetime
25
25
  from typing import Any, Literal, Optional, Union
26
26
 
27
27
  import google.auth
28
- from pydantic import BaseModel
29
28
  from requests.exceptions import HTTPError
30
29
 
31
30
  from . import errors
@@ -34,6 +33,7 @@ from ._api_client import HttpOptions
34
33
  from ._api_client import HttpRequest
35
34
  from ._api_client import HttpResponse
36
35
  from ._api_client import RequestJsonEncoder
36
+ from ._common import BaseModel
37
37
 
38
38
  def _redact_version_numbers(version_string: str) -> str:
39
39
  """Redacts version numbers in the form x.y.z from a string."""
@@ -72,6 +72,11 @@ def _redact_request_url(url: str) -> str:
72
72
  '{VERTEX_URL_PREFIX}/',
73
73
  url,
74
74
  )
75
+ result = re.sub(
76
+ r'.*-aiplatform.googleapis.com/[^/]+/',
77
+ '{VERTEX_URL_PREFIX}/',
78
+ result,
79
+ )
75
80
  result = re.sub(
76
81
  r'https://generativelanguage.googleapis.com/[^/]+',
77
82
  '{MLDEV_URL_PREFIX}',
@@ -259,18 +264,9 @@ class ReplayApiClient(ApiClient):
259
264
  replay_file_path = self._get_replay_file_path()
260
265
  os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
261
266
  with open(replay_file_path, 'w') as f:
262
- replay_session_dict = self.replay_session.model_dump()
263
- # Use for non-utf-8 bytes in image/video... output.
264
- for interaction in replay_session_dict['interactions']:
265
- segments = []
266
- for response in interaction['response']['sdk_response_segments']:
267
- segments.append(json.loads(json.dumps(
268
- response, cls=ResponseJsonEncoder
269
- )))
270
- interaction['response']['sdk_response_segments'] = segments
271
267
  f.write(
272
268
  json.dumps(
273
- replay_session_dict, indent=2, cls=RequestJsonEncoder
269
+ self.replay_session.model_dump(mode='json'), indent=2, cls=ResponseJsonEncoder
274
270
  )
275
271
  )
276
272
  self.replay_session = None
@@ -371,15 +367,8 @@ class ReplayApiClient(ApiClient):
371
367
  if isinstance(response_model, list):
372
368
  response_model = response_model[0]
373
369
  print('response_model: ', response_model.model_dump(exclude_none=True))
374
- actual = json.dumps(
375
- response_model.model_dump(exclude_none=True),
376
- cls=ResponseJsonEncoder,
377
- sort_keys=True,
378
- )
379
- expected = json.dumps(
380
- interaction.response.sdk_response_segments[self._sdk_response_index],
381
- sort_keys=True,
382
- )
370
+ actual = response_model.model_dump(exclude_none=True, mode='json')
371
+ expected = interaction.response.sdk_response_segments[self._sdk_response_index]
383
372
  assert (
384
373
  actual == expected
385
374
  ), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
@@ -432,36 +421,12 @@ class ReplayApiClient(ApiClient):
432
421
  return self._build_response_from_replay(request).text
433
422
 
434
423
 
424
+ # TODO(b/389693448): Cleanup datetime hacks.
435
425
  class ResponseJsonEncoder(json.JSONEncoder):
436
426
  """The replay test json encoder for response.
437
-
438
- We need RequestJsonEncoder and ResponseJsonEncoder because:
439
- 1. In production, we only need RequestJsonEncoder to help json module
440
- to convert non-stringable and stringable types to json string. Especially
441
- for bytes type, the value of bytes field is encoded to base64 string so it
442
- is always stringable and the RequestJsonEncoder doesn't have to deal with
443
- utf-8 JSON broken issue.
444
- 2. In replay test, we also need ResponseJsonEncoder to help json module
445
- convert non-stringable and stringable types to json string. But response
446
- object returned from SDK method is different from the request api_client
447
- sent to server. For the bytes type, there is no base64 string in response
448
- anymore, because SDK handles it internally. So bytes type in Response is
449
- non-stringable. The ResponseJsonEncoder uses different encoding
450
- strategy than the RequestJsonEncoder to deal with utf-8 JSON broken issue.
451
427
  """
452
428
  def default(self, o):
453
- if isinstance(o, bytes):
454
- # Use base64.b64encode() to encode bytes to string so that the media bytes
455
- # fields are serializable.
456
- # o.decode(encoding='utf-8', errors='replace') doesn't work because it
457
- # uses a fixed error string `\ufffd` for all non-utf-8 characters,
458
- # which cannot be converted back to original bytes. And other languages
459
- # only have the original bytes to compare with.
460
- # Since we use base64.b64encoding() in replay test, a change that breaks
461
- # native bytes can be captured by
462
- # test_compute_tokens.py::test_token_bytes_deserialization.
463
- return base64.b64encode(o).decode(encoding='utf-8')
464
- elif isinstance(o, datetime.datetime):
429
+ if isinstance(o, datetime.datetime):
465
430
  # dt.isoformat() prints "2024-11-15T23:27:45.624657+00:00"
466
431
  # but replay files want "2024-11-15T23:27:45.624657Z"
467
432
  if o.isoformat().endswith('+00:00'):
@@ -142,6 +142,30 @@ def t_model(client: _api_client.ApiClient, model: str):
142
142
  else:
143
143
  return f'models/{model}'
144
144
 
145
+ def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
146
+ if api_client.vertexai:
147
+ if base_models:
148
+ return 'publishers/google/models'
149
+ else:
150
+ return 'models'
151
+ else:
152
+ if base_models:
153
+ return 'models'
154
+ else:
155
+ return 'tunedModels'
156
+
157
+
158
+ def t_extract_models(api_client: _api_client.ApiClient, response: dict) -> list[types.Model]:
159
+ if response.get('models') is not None:
160
+ return response.get('models')
161
+ elif response.get('tunedModels') is not None:
162
+ return response.get('tunedModels')
163
+ elif response.get('publisherModels') is not None:
164
+ return response.get('publisherModels')
165
+ else:
166
+ raise ValueError('Cannot determine the models type.')
167
+
168
+
145
169
  def t_caches_model(api_client: _api_client.ApiClient, model: str):
146
170
  model = t_model(api_client, model)
147
171
  if not model:
@@ -452,3 +476,17 @@ def t_tuning_job_status(
452
476
  return 'JOB_STATE_FAILED'
453
477
  else:
454
478
  return status
479
+
480
+
481
+ # Some fields don't accept url safe base64 encoding.
482
+ # We shouldn't use this transformer if the backend adhere to Cloud Type
483
+ # format https://cloud.google.com/docs/discovery/type-format.
484
+ # TODO(b/389133914): Remove the hack after Vertex backend fix the issue.
485
+ def t_bytes(api_client: _api_client.ApiClient, data: bytes) -> str:
486
+ if not isinstance(data, bytes):
487
+ return data
488
+ if api_client.vertexai:
489
+ return base64.b64encode(data).decode('ascii')
490
+ else:
491
+ return base64.urlsafe_encode(data).decode('ascii')
492
+