google-genai 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,6 +20,7 @@ import asyncio
20
20
  import copy
21
21
  from dataclasses import dataclass
22
22
  import datetime
23
+ import io
23
24
  import json
24
25
  import logging
25
26
  import os
@@ -148,10 +149,16 @@ class HttpRequest:
148
149
 
149
150
  class HttpResponse:
150
151
 
151
- def __init__(self, headers: dict[str, str], response_stream: Union[Any, str]):
152
def __init__(
    self,
    headers: dict[str, str],
    response_stream: Union[Any, str] = None,
    byte_stream: Union[Any, bytes] = None,
):
  """Stores the headers and payloads of a response.

  Args:
    headers: The response headers.
    response_stream: The text payload, stored exactly as given. Defaults to
      None.
    byte_stream: The binary payload, stored exactly as given. Defaults to
      None.
  """
  # A freshly built response always starts out with status 200.
  self.status_code, self.headers = 200, headers
  self.response_stream, self.byte_stream = response_stream, byte_stream
155
162
 
156
163
  @property
157
164
  def text(self) -> str:
@@ -164,6 +171,8 @@ class HttpResponse:
164
171
  # list of objects retrieved from replay or from non-streaming API.
165
172
  for chunk in self.response_stream:
166
173
  yield json.loads(chunk) if chunk else {}
174
+ elif self.response_stream is None:
175
+ yield from []
167
176
  else:
168
177
  # Iterator of objects retrieved from the API.
169
178
  for chunk in self.response_stream.iter_lines():
@@ -174,6 +183,17 @@ class HttpResponse:
174
183
  chunk = chunk[len(b'data: ') :]
175
184
  yield json.loads(str(chunk, 'utf-8'))
176
185
 
186
def byte_segments(self):
  """Yields the binary segments held in byte_stream.

  Yields:
    Each element of byte_stream when it is a list; nothing when it is None.

  Raises:
    ValueError: If byte_stream is neither a list nor None. Because this is a
      generator, the error surfaces on first iteration.
  """
  stream = self.byte_stream
  if stream is None:
    # No binary payload: behave as an empty generator.
    return
  if not isinstance(stream, list):
    raise ValueError(
        'Byte segments are not supported for streaming responses.'
    )
  # List of objects retrieved from replay or from the non-streaming API.
  yield from stream
196
+
177
197
  def copy_to_dict(self, response_payload: dict[str, object]):
178
198
  for attribute in dir(self):
179
199
  response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
@@ -199,7 +219,7 @@ class ApiClient:
199
219
  ]:
200
220
  self.vertexai = True
201
221
 
202
- # Validate explicitly set intializer values.
222
+ # Validate explicitly set initializer values.
203
223
  if (project or location) and api_key:
204
224
  # API cannot consume both project/location and api_key.
205
225
  raise ValueError(
@@ -265,9 +285,10 @@ class ApiClient:
265
285
  self.api_key = None
266
286
  if not self.project and not self.api_key:
267
287
  self.project = google.auth.default()[1]
268
- if not (self.project or self.location) and not self.api_key:
288
+ if not ((self.project and self.location) or self.api_key):
269
289
  raise ValueError(
270
- 'Project/location or API key must be set when using the Vertex AI API.'
290
+ 'Project and location or API key must be set when using the Vertex '
291
+ 'AI API.'
271
292
  )
272
293
  if self.api_key:
273
294
  self._http_options['base_url'] = (
@@ -357,7 +378,7 @@ class ApiClient:
357
378
  http_request.method.upper(),
358
379
  http_request.url,
359
380
  headers=http_request.headers,
360
- data=json.dumps(http_request.data, cls=RequestJsonEncoder)
381
+ data=json.dumps(http_request.data)
361
382
  if http_request.data
362
383
  else None,
363
384
  timeout=http_request.timeout,
@@ -377,7 +398,7 @@ class ApiClient:
377
398
  data = None
378
399
  if http_request.data:
379
400
  if not isinstance(http_request.data, bytes):
380
- data = json.dumps(http_request.data, cls=RequestJsonEncoder)
401
+ data = json.dumps(http_request.data)
381
402
  else:
382
403
  data = http_request.data
383
404
 
@@ -488,12 +509,35 @@ class ApiClient:
488
509
  if http_options and 'response_payload' in http_options:
489
510
  response.copy_to_dict(http_options['response_payload'])
490
511
 
491
- def upload_file(self, file_path: str, upload_url: str, upload_size: int):
512
def upload_file(
    self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
) -> str:
  """Uploads a local file or an already open binary stream to a URL.

  Args:
    file_path: Either the path of a local file or a file-like object (an
      io.IOBase instance). A path is opened in binary mode, so a missing file
      raises the error produced by open().
    upload_url: The URL to upload the file to.
    upload_size: The size of the content to upload; it has to match the size
      requested in the resumable upload request.

  Returns:
    Whatever _upload_fd returns for the transfer.
  """
  if not isinstance(file_path, io.IOBase):
    # A path was given: open it here so it is closed once the upload ends.
    with open(file_path, 'rb') as stream:
      return self._upload_fd(stream, upload_url, upload_size)
  # Already a stream: hand it over as is, the caller keeps ownership.
  return self._upload_fd(file_path, upload_url, upload_size)
533
+
534
+ def _upload_fd(
535
+ self, file: io.IOBase, upload_url: str, upload_size: int
536
+ ) -> str:
537
+ """Transfers a file to the given URL.
538
+
539
+ Args:
540
+ file: A file like object inherited from io.BytesIO.
497
541
  upload_url: The URL to upload the file to.
498
542
  upload_size: The size of file content to be uploaded, this will have to
499
543
  match the size requested in the resumable upload request.
@@ -503,37 +547,36 @@ class ApiClient:
503
547
  """
504
548
  offset = 0
505
549
  # Upload the file in chunks
506
- with open(file_path, 'rb') as file:
507
- while True:
508
- file_chunk = file.read(1024 * 1024 * 8) # 8 MB chunk size
509
- chunk_size = 0
510
- if file_chunk:
511
- chunk_size = len(file_chunk)
512
- upload_command = 'upload'
513
- # If last chunk, finalize the upload.
514
- if chunk_size + offset >= upload_size:
515
- upload_command += ', finalize'
516
-
517
- request = HttpRequest(
518
- method='POST',
519
- url=upload_url,
520
- headers={
521
- 'X-Goog-Upload-Command': upload_command,
522
- 'X-Goog-Upload-Offset': str(offset),
523
- 'Content-Length': str(chunk_size),
524
- },
525
- data=file_chunk,
526
- )
527
- response = self._request_unauthorized(request, stream=False)
528
- offset += chunk_size
529
- if response.headers['X-Goog-Upload-Status'] != 'active':
530
- break # upload is complete or it has been interrupted.
550
+ while True:
551
+ file_chunk = file.read(1024 * 1024 * 8) # 8 MB chunk size
552
+ chunk_size = 0
553
+ if file_chunk:
554
+ chunk_size = len(file_chunk)
555
+ upload_command = 'upload'
556
+ # If last chunk, finalize the upload.
557
+ if chunk_size + offset >= upload_size:
558
+ upload_command += ', finalize'
559
+ request = HttpRequest(
560
+ method='POST',
561
+ url=upload_url,
562
+ headers={
563
+ 'X-Goog-Upload-Command': upload_command,
564
+ 'X-Goog-Upload-Offset': str(offset),
565
+ 'Content-Length': str(chunk_size),
566
+ },
567
+ data=file_chunk,
568
+ )
569
+
570
+ response = self._request_unauthorized(request, stream=False)
571
+ offset += chunk_size
572
+ if response.headers['X-Goog-Upload-Status'] != 'active':
573
+ break # upload is complete or it has been interrupted.
531
574
 
532
- if upload_size <= offset: # Status is not finalized.
533
- raise ValueError(
534
- 'All content has been uploaded, but the upload status is not'
535
- f' finalized. {response.headers}, body: {response.text}'
536
- )
575
+ if upload_size <= offset: # Status is not finalized.
576
+ raise ValueError(
577
+ 'All content has been uploaded, but the upload status is not'
578
+ f' finalized. {response.headers}, body: {response.text}'
579
+ )
537
580
 
538
581
  if response.headers['X-Goog-Upload-Status'] != 'final':
539
582
  raise ValueError(
@@ -542,12 +585,52 @@ class ApiClient:
542
585
  )
543
586
  return response.text
544
587
 
588
def download_file(self, path: str, http_options):
  """Downloads the file data.

  Args:
    path: The request path with query params.
    http_options: The http options to use for the request.

  Returns:
    The first element of the byte_stream of the download response.
  """
  # The download is a GET with an empty request body.
  request = self._build_request(
      'get', path=path, request_dict={}, http_options=http_options
  )
  response = self._download_file_request(request)
  return response.byte_stream[0]
602
+
603
def _download_file_request(
    self,
    http_request: HttpRequest,
) -> HttpResponse:
  """Sends a non-streaming download request and wraps the raw bytes.

  Args:
    http_request: The request to send. Non-bytes data is JSON encoded; bytes
      data is sent unchanged; empty data sends no body.

  Returns:
    An HttpResponse built from the response headers whose byte_stream is a
    one-element list holding the full response content.

  Raises:
    Whatever errors.APIError.raise_for_response raises for the response.
  """
  payload = None
  if http_request.data:
    if isinstance(http_request.data, bytes):
      payload = http_request.data
    else:
      # Plain json.dumps, like the other request paths in this module.
      # RequestJsonEncoder is no longer defined here, so passing it as `cls`
      # would raise NameError as soon as a request carried a JSON body.
      payload = json.dumps(http_request.data)

  # stream=False: the body is fully read by request(), so the session can be
  # closed as soon as the response has been wrapped.
  with requests.Session() as http_session:
    response = http_session.request(
        method=http_request.method,
        url=http_request.url,
        headers=http_request.headers,
        data=payload,
        timeout=http_request.timeout,
        stream=False,
    )
    errors.APIError.raise_for_response(response)
    return HttpResponse(response.headers, byte_stream=[response.content])
626
+
627
+
545
628
  async def async_upload_file(
546
629
  self,
547
- file_path: str,
630
+ file_path: Union[str, io.IOBase],
548
631
  upload_url: str,
549
632
  upload_size: int,
550
- ):
633
+ ) -> str:
551
634
  """Transfers a file asynchronously to the given URL.
552
635
 
553
636
  Args:
@@ -567,22 +650,48 @@ class ApiClient:
567
650
  upload_size,
568
651
  )
569
652
 
653
async def _async_upload_fd(
    self,
    file: io.IOBase,
    upload_url: str,
    upload_size: int,
) -> str:
  """Transfers a file-like object to the given URL without blocking the loop.

  The blocking _upload_fd call is run through asyncio.to_thread.

  Args:
    file: A file-like object (an io.IOBase instance).
    upload_url: The URL to upload the file to.
    upload_size: The size of the content to upload; it has to match the size
      requested in the resumable upload request.

  Returns:
    Whatever _upload_fd returns for the transfer.
  """
  pending_upload = asyncio.to_thread(
      self._upload_fd, file, upload_url, upload_size
  )
  return await pending_upload
676
+
677
async def async_download_file(self, path: str, http_options):
  """Downloads the file data without blocking the event loop.

  The blocking download_file call is run through asyncio.to_thread.

  Args:
    path: The request path with query params.
    http_options: The http options to use for the request.

  Returns:
    Whatever download_file returns for the same arguments.
  """
  pending_download = asyncio.to_thread(self.download_file, path, http_options)
  return await pending_download
692
+
570
693
  # This method does nothing in the real api client. It is used in the
571
694
  # replay_api_client to verify the response from the SDK method matches the
572
695
  # recorded response.
573
696
  def _verify_response(self, response_model: BaseModel):
574
697
  pass
575
-
576
-
577
- # TODO(b/389693448): Cleanup datetime hacks.
578
- class RequestJsonEncoder(json.JSONEncoder):
579
- """Encode bytes as strings without modify its content."""
580
-
581
- def default(self, o):
582
- if isinstance(o, datetime.datetime):
583
- # This Zulu time format is used by the Vertex AI API and the test recorder
584
- # Using strftime works well, but we want to align with the replay encoder.
585
- # o.astimezone(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%fZ')
586
- return o.isoformat().replace('+00:00', 'Z')
587
- else:
588
- return super().default(o)
google/genai/_common.py CHANGED
@@ -17,6 +17,7 @@
17
17
 
18
18
  import base64
19
19
  import datetime
20
+ import enum
20
21
  import typing
21
22
  from typing import Union
22
23
  import uuid
@@ -144,7 +145,7 @@ def _remove_extra_fields(
144
145
  ) -> None:
145
146
  """Removes extra fields from the response that are not in the model.
146
147
 
147
- Muates the response in place.
148
+ Mutates the response in place.
148
149
  """
149
150
 
150
151
  key_values = list(response.items())
@@ -185,7 +186,7 @@ class BaseModel(pydantic.BaseModel):
185
186
  alias_generator=alias_generators.to_camel,
186
187
  populate_by_name=True,
187
188
  from_attributes=True,
188
- protected_namespaces={},
189
+ protected_namespaces=(),
189
190
  extra='forbid',
190
191
  # This allows us to use arbitrary types in the model. E.g. PIL.Image.
191
192
  arbitrary_types_allowed=True,
@@ -208,6 +209,20 @@ class BaseModel(pydantic.BaseModel):
208
209
  return self.model_dump(exclude_none=True, mode='json')
209
210
 
210
211
 
212
class CaseInSensitiveEnum(str, enum.Enum):
  """String enum whose lookup falls back to a case-insensitive name match.

  When a value matches no member value, the member whose *name* equals the
  upper-cased value is returned; failing that, the member whose name equals
  the lower-cased value.
  """

  @classmethod
  def _missing_(cls, value):
    """Resolves a value that matched no member value.

    Args:
      value: The value that was looked up.

    Returns:
      The member named value.upper(), or else the member named value.lower().

    Raises:
      ValueError: If value is not a string or no member name matches.
    """
    if not isinstance(value, str):
      # Non-strings have no upper()/lower(); report them as an ordinary
      # invalid enum value instead of leaking an AttributeError.
      raise ValueError(f"{value} is not a valid {cls.__name__}")
    try:
      return cls[value.upper()]  # Try to access directly with uppercase
    except KeyError:
      try:
        return cls[value.lower()]  # Try to access directly with lowercase
      except KeyError as e:
        raise ValueError(f"{value} is not a valid {cls.__name__}") from e
224
+
225
+
211
226
  def timestamped_unique_name() -> str:
212
227
  """Composes a timestamped unique name.
213
228
 
@@ -219,23 +234,39 @@ def timestamped_unique_name() -> str:
219
234
  return f'{timestamp}_{unique_id}'
220
235
 
221
236
 
222
- def apply_base64_encoding(data: dict[str, object]) -> dict[str, object]:
223
- """Applies base64 encoding to bytes values in the given data."""
237
def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
  """Converts unserializable types in dict to json.dumps() compatible types.

  This function is called in models.py after calling convert_to_dict(). The
  input to convert_to_dict() is a dict mixing pydantic objects and nested dicts
  (the output of converters), so bytes and datetime values may remain that
  model_dump(mode='json') did not get to convert.

  bytes values become URL-safe base64 ASCII strings and datetime values become
  datetime.isoformat() strings. Nested dicts are converted recursively. A list
  made up only of bytes, or only of datetimes, is converted element by element;
  any other list has this function applied to each element.

  Args:
    data: The dictionary to convert. Anything that is not a dict is returned
      unchanged.

  Returns:
    A new dictionary with json.dumps() incompatible types (e.g. bytes,
    datetime) replaced by compatible ones (e.g. base64 encoded string,
    isoformat date string). The input dictionary is not modified.
  """
  if not isinstance(data, dict):
    return data
  processed_data = {}
  for key, value in data.items():
    if isinstance(value, bytes):
      processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
    elif isinstance(value, datetime.datetime):
      processed_data[key] = value.isoformat()
    elif isinstance(value, dict):
      processed_data[key] = encode_unserializable_types(value)
    elif isinstance(value, list):
      if all(isinstance(v, bytes) for v in value):
        processed_data[key] = [
            base64.urlsafe_b64encode(v).decode('ascii') for v in value
        ]
      # Must be `elif`: as a separate `if`, its `else` branch also ran for
      # all-bytes lists and replaced the encoded strings with the raw bytes.
      elif all(isinstance(v, datetime.datetime) for v in value):
        processed_data[key] = [v.isoformat() for v in value]
      else:
        processed_data[key] = [encode_unserializable_types(v) for v in value]
    else:
      processed_data[key] = value
  return processed_data
@@ -116,7 +116,7 @@ def convert_if_exist_pydantic_model(
116
116
  try:
117
117
  return annotation(**value)
118
118
  except pydantic.ValidationError as e:
119
- raise errors.UnkownFunctionCallArgumentError(
119
+ raise errors.UnknownFunctionCallArgumentError(
120
120
  f'Failed to parse parameter {param_name} for function'
121
121
  f' {func_name} from function call part because function call argument'
122
122
  f' value {value} is not compatible with parameter annotation'
@@ -150,7 +150,7 @@ def convert_if_exist_pydantic_model(
150
150
  except pydantic.ValidationError:
151
151
  continue
152
152
  # if none of the union type is matched, raise error
153
- raise errors.UnkownFunctionCallArgumentError(
153
+ raise errors.UnknownFunctionCallArgumentError(
154
154
  f'Failed to parse parameter {param_name} for function'
155
155
  f' {func_name} from function call part because function call argument'
156
156
  f' value {value} cannot be converted to parameter annotation'
@@ -161,7 +161,7 @@ def convert_if_exist_pydantic_model(
161
161
  if isinstance(value, int) and annotation is float:
162
162
  return value
163
163
  if not isinstance(value, annotation):
164
- raise errors.UnkownFunctionCallArgumentError(
164
+ raise errors.UnknownFunctionCallArgumentError(
165
165
  f'Failed to parse parameter {param_name} for function {func_name} from'
166
166
  f' function call part because function call argument value {value} is'
167
167
  f' not compatible with parameter annotation {annotation}.'
@@ -17,11 +17,12 @@
17
17
 
18
18
  import base64
19
19
  import copy
20
+ import datetime
20
21
  import inspect
22
+ import io
21
23
  import json
22
24
  import os
23
25
  import re
24
- import datetime
25
26
  from typing import Any, Literal, Optional, Union
26
27
 
27
28
  import google.auth
@@ -32,9 +33,9 @@ from ._api_client import ApiClient
32
33
  from ._api_client import HttpOptions
33
34
  from ._api_client import HttpRequest
34
35
  from ._api_client import HttpResponse
35
- from ._api_client import RequestJsonEncoder
36
36
  from ._common import BaseModel
37
37
 
38
+
38
39
  def _redact_version_numbers(version_string: str) -> str:
39
40
  """Redacts version numbers in the form x.y.z from a string."""
40
41
  return re.sub(r'\d+\.\d+\.\d+', '{VERSION_NUMBER}', version_string)
@@ -145,6 +146,7 @@ class ReplayResponse(BaseModel):
145
146
  status_code: int = 200
146
147
  headers: dict[str, str]
147
148
  body_segments: list[dict[str, object]]
149
+ byte_segments: Optional[list[bytes]] = None
148
150
  sdk_response_segments: list[dict[str, object]]
149
151
 
150
152
  def model_post_init(self, __context: Any) -> None:
@@ -264,17 +266,13 @@ class ReplayApiClient(ApiClient):
264
266
  replay_file_path = self._get_replay_file_path()
265
267
  os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
266
268
  with open(replay_file_path, 'w') as f:
267
- f.write(
268
- json.dumps(
269
- self.replay_session.model_dump(mode='json'), indent=2, cls=ResponseJsonEncoder
270
- )
271
- )
269
+ f.write(self.replay_session.model_dump_json(exclude_unset=True, indent=2))
272
270
  self.replay_session = None
273
271
 
274
272
  def _record_interaction(
275
273
  self,
276
274
  http_request: HttpRequest,
277
- http_response: Union[HttpResponse, errors.APIError],
275
+ http_response: Union[HttpResponse, errors.APIError, bytes],
278
276
  ):
279
277
  if not self._should_update_replay():
280
278
  return
@@ -289,6 +287,9 @@ class ReplayApiClient(ApiClient):
289
287
  response = ReplayResponse(
290
288
  headers=dict(http_response.headers),
291
289
  body_segments=list(http_response.segments()),
290
+ byte_segments=[
291
+ seg[:100] + b'...' for seg in http_response.byte_segments()
292
+ ],
292
293
  status_code=http_response.status_code,
293
294
  sdk_response_segments=[],
294
295
  )
@@ -322,11 +323,7 @@ class ReplayApiClient(ApiClient):
322
323
  # so that the comparison is fair.
323
324
  _redact_request_body(request_data_copy)
324
325
 
325
- # Need to call dumps() and loads() to convert dict bytes values to strings.
326
- # Because the expected_request_body dict never contains bytes values.
327
- actual_request_body = [
328
- json.loads(json.dumps(request_data_copy, cls=RequestJsonEncoder))
329
- ]
326
+ actual_request_body = [request_data_copy]
330
327
  expected_request_body = interaction.request.body_segments
331
328
  assert actual_request_body == expected_request_body, (
332
329
  'Request body mismatch:\n'
@@ -349,6 +346,7 @@ class ReplayApiClient(ApiClient):
349
346
  json.dumps(segment)
350
347
  for segment in interaction.response.body_segments
351
348
  ],
349
+ byte_stream=interaction.response.byte_segments,
352
350
  )
353
351
 
354
352
  def _verify_response(self, response_model: BaseModel):
@@ -368,7 +366,9 @@ class ReplayApiClient(ApiClient):
368
366
  response_model = response_model[0]
369
367
  print('response_model: ', response_model.model_dump(exclude_none=True))
370
368
  actual = response_model.model_dump(exclude_none=True, mode='json')
371
- expected = interaction.response.sdk_response_segments[self._sdk_response_index]
369
+ expected = interaction.response.sdk_response_segments[
370
+ self._sdk_response_index
371
+ ]
372
372
  assert (
373
373
  actual == expected
374
374
  ), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
@@ -402,10 +402,21 @@ class ReplayApiClient(ApiClient):
402
402
  else:
403
403
  return self._build_response_from_replay(http_request)
404
404
 
405
- def upload_file(self, file_path: str, upload_url: str, upload_size: int):
406
- request = HttpRequest(
407
- method='POST', url='', data={'file_path': file_path}, headers={}
408
- )
405
+ def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int):
406
+ if isinstance(file_path, io.IOBase):
407
+ offset = file_path.tell()
408
+ content = file_path.read()
409
+ file_path.seek(offset, os.SEEK_SET)
410
+ request = HttpRequest(
411
+ method='POST',
412
+ url='',
413
+ data={'bytes': base64.b64encode(content).decode('utf-8')},
414
+ headers={}
415
+ )
416
+ else:
417
+ request = HttpRequest(
418
+ method='POST', url='', data={'file_path': file_path}, headers={}
419
+ )
409
420
  if self._should_call_api():
410
421
  try:
411
422
  result = super().upload_file(file_path, upload_url, upload_size)
@@ -420,18 +431,19 @@ class ReplayApiClient(ApiClient):
420
431
  else:
421
432
  return self._build_response_from_replay(request).text
422
433
 
423
-
424
- # TODO(b/389693448): Cleanup datetime hacks.
425
- class ResponseJsonEncoder(json.JSONEncoder):
426
- """The replay test json encoder for response.
427
- """
428
- def default(self, o):
429
- if isinstance(o, datetime.datetime):
430
- # dt.isoformat() prints "2024-11-15T23:27:45.624657+00:00"
431
- # but replay files want "2024-11-15T23:27:45.624657Z"
432
- if o.isoformat().endswith('+00:00'):
433
- return o.isoformat().replace('+00:00', 'Z')
434
- else:
435
- return o.isoformat()
434
def _download_file_request(self, request):
  """Performs a download request against the API or serves it from replay.

  Args:
    request: The download request.

  Returns:
    The response of the parent implementation (after recording the
    interaction) when the API should be called, otherwise the response built
    from the replay for this request.

  Raises:
    HTTPError: Re-raised when the parent implementation raises it.
  """
  self._initialize_replay_session_if_not_loaded()
  if not self._should_call_api():
    return self._build_response_from_replay(request)

  try:
    result = super()._download_file_request(request)
  except HTTPError as e:
    # NOTE(review): this response is built and then discarded, because the
    # error is re-raised before _record_interaction runs - confirm whether
    # the failed interaction was meant to be recorded.
    result = HttpResponse(
        e.response.headers, [json.dumps({'reason': e.response.reason})]
    )
    result.status_code = e.response.status_code
    raise e
  self._record_interaction(request, result)
  return result
449
+