google-genai 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
@@ -20,6 +20,7 @@ import asyncio
 import copy
 from dataclasses import dataclass
 import datetime
+import io
 import json
 import logging
 import os
@@ -35,55 +36,7 @@ import requests
 
 from . import errors
 from . import version
-
-
-class HttpOptions(BaseModel):
-  """HTTP options for the api client."""
-  model_config = ConfigDict(extra='forbid')
-
-  base_url: Optional[str] = Field(
-      default=None,
-      description="""The base URL for the AI platform service endpoint.""",
-  )
-  api_version: Optional[str] = Field(
-      default=None,
-      description="""Specifies the version of the API to use.""",
-  )
-  headers: Optional[dict[str, str]] = Field(
-      default=None,
-      description="""Additional HTTP headers to be sent with the request.""",
-  )
-  response_payload: Optional[dict] = Field(
-      default=None,
-      description="""If set, the response payload will be returned int the supplied dict.""",
-  )
-  timeout: Optional[Union[float, Tuple[float, float]]] = Field(
-      default=None,
-      description="""Timeout for the request in seconds.""",
-  )
-  skip_project_and_location_in_path: bool = Field(
-      default=False,
-      description="""If set to True, the project and location will not be appended to the path.""",
-  )
-
-
-class HttpOptionsDict(TypedDict):
-  """HTTP options for the api client."""
-
-  base_url: Optional[str] = None
-  """The base URL for the AI platform service endpoint."""
-  api_version: Optional[str] = None
-  """Specifies the version of the API to use."""
-  headers: Optional[dict[str, str]] = None
-  """Additional HTTP headers to be sent with the request."""
-  response_payload: Optional[dict] = None
-  """If set, the response payload will be returned int the supplied dict."""
-  timeout: Optional[Union[float, Tuple[float, float]]] = None
-  """Timeout for the request in seconds."""
-  skip_project_and_location_in_path: bool = False
-  """If set to True, the project and location will not be appended to the path."""
-
-HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
+from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
 
 
 def _append_library_version_headers(headers: dict[str, str]) -> None:
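The HttpOptions, HttpOptionsDict, and HttpOptionsOrDict definitions above are removed from the client module and now come from the package's types module. A minimal usage sketch, assuming they are re-exported as google.genai.types and that callers still pass them through an http_options argument as in 0.5.0 (the field values are illustrative):

    from google.genai import types

    # Field names are taken from the removed definitions above.
    http_options = types.HttpOptions(
        api_version='v1beta1',
        headers={'x-custom-header': 'value'},
    )
    # A plain dict is still accepted wherever HttpOptionsOrDict is expected.
    http_options_dict = {'api_version': 'v1beta1', 'timeout': 10_000}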
@@ -143,18 +96,35 @@ class HttpRequest:
   url: str
   method: str
   data: Union[dict[str, object], bytes]
-  timeout: Optional[Union[float, Tuple[float, float]]] = None
+  timeout: Optional[float] = None
 
 
 class HttpResponse:
 
-  def __init__(self, headers: dict[str, str], response_stream: Union[Any, str]):
+  def __init__(
+      self,
+      headers: dict[str, str],
+      response_stream: Union[Any, str] = None,
+      byte_stream: Union[Any, bytes] = None,
+  ):
     self.status_code = 200
     self.headers = headers
     self.response_stream = response_stream
+    self.byte_stream = byte_stream
+    self.segment_iterator = self.segments()
+
+  # Async iterator for async streaming.
+  def __aiter__(self):
+    return self
+
+  async def __anext__(self):
+    try:
+      return next(self.segment_iterator)
+    except StopIteration:
+      raise StopAsyncIteration
 
   @property
-  def text(self) -> str:
+  def json(self) -> Any:
     if not self.response_stream[0]:  # Empty response
       return ''
     return json.loads(self.response_stream[0])
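HttpResponse now eagerly creates a synchronous segment generator and exposes it through __aiter__/__anext__, so the same object can be consumed with either `for` or `async for`. A self-contained sketch of that wrapping pattern, separate from the SDK code:

    import asyncio


    class SyncBackedAsyncIterator:
      """Wraps a synchronous iterator so it can be consumed with `async for`."""

      def __init__(self, items):
        self._iterator = iter(items)

      def __aiter__(self):
        return self

      async def __anext__(self):
        try:
          # Each step is synchronous; StopIteration is translated into the
          # async-iteration stop signal.
          return next(self._iterator)
        except StopIteration:
          raise StopAsyncIteration


    async def main():
      async for segment in SyncBackedAsyncIterator(['{"a": 1}', '{"b": 2}']):
        print(segment)


    asyncio.run(main())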
@@ -164,6 +134,8 @@ class HttpResponse:
       # list of objects retrieved from replay or from non-streaming API.
       for chunk in self.response_stream:
         yield json.loads(chunk) if chunk else {}
+    elif self.response_stream is None:
+      yield from []
     else:
       # Iterator of objects retrieved from the API.
       for chunk in self.response_stream.iter_lines():
@@ -174,7 +146,20 @@ class HttpResponse:
             chunk = chunk[len(b'data: ') :]
           yield json.loads(str(chunk, 'utf-8'))
 
-  def copy_to_dict(self, response_payload: dict[str, object]):
+  def byte_segments(self):
+    if isinstance(self.byte_stream, list):
+      # list of objects retrieved from replay or from non-streaming API.
+      yield from self.byte_stream
+    elif self.byte_stream is None:
+      yield from []
+    else:
+      raise ValueError(
+          'Byte segments are not supported for streaming responses.'
+      )
+
+  def _copy_to_dict(self, response_payload: dict[str, object]):
+    # Cannot pickle 'generator' object.
+    delattr(self, 'segment_iterator')
     for attribute in dir(self):
       response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
 
@@ -199,7 +184,7 @@ class ApiClient:
       ]:
         self.vertexai = True
 
-    # Validate explicitly set intializer values.
+    # Validate explicitly set initializer values.
     if (project or location) and api_key:
       # API cannot consume both project/location and api_key.
       raise ValueError(
@@ -265,9 +250,10 @@ class ApiClient:
        self.api_key = None
      if not self.project and not self.api_key:
        self.project = google.auth.default()[1]
-     if not (self.project or self.location) and not self.api_key:
+     if not ((self.project and self.location) or self.api_key):
        raise ValueError(
-           'Project/location or API key must be set when using the Vertex AI API.'
+           'Project and location or API key must be set when using the Vertex '
+           'AI API.'
        )
      if self.api_key:
        self._http_options['base_url'] = (
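The Vertex AI validation is now stricter: a project is only sufficient together with a location. A hedged sketch of the accepted combinations, assuming the public genai.Client constructor forwards these arguments to ApiClient as it did in 0.5.0 (project and location values are placeholders):

    from google import genai

    # Project and location together: accepted.
    client = genai.Client(
        vertexai=True, project='my-project', location='us-central1'
    )

    # An API key alone: accepted.
    client = genai.Client(vertexai=True, api_key='...')

    # A project without a location now raises:
    #   ValueError: Project and location or API key must be set when using the
    #   Vertex AI API.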
@@ -304,7 +290,7 @@ class ApiClient:
       http_method: str,
       path: str,
       request_dict: dict[str, object],
-      http_options: HttpOptionsDict = None,
+      http_options: HttpOptionsOrDict = None,
   ) -> HttpRequest:
     # Remove all special dict keys such as _url and _query.
     keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
@@ -312,18 +298,28 @@ class ApiClient:
       del request_dict[key]
     # patch the http options with the user provided settings.
     if http_options:
-      patched_http_options = _patch_http_options(
-          self._http_options, http_options
-      )
+      if isinstance(http_options, HttpOptions):
+        patched_http_options = _patch_http_options(
+            self._http_options, http_options.model_dump()
+        )
+      else:
+        patched_http_options = _patch_http_options(
+            self._http_options, http_options
+        )
     else:
       patched_http_options = self._http_options
-    skip_project_and_location_in_path_val = patched_http_options.get(
-        'skip_project_and_location_in_path', False
-    )
+    # Skip adding project and locations when getting Vertex AI base models.
+    query_vertex_base_models = False
+    if (
+        self.vertexai
+        and http_method == 'get'
+        and path.startswith('publishers/google/models')
+    ):
+      query_vertex_base_models = True
     if (
         self.vertexai
         and not path.startswith('projects/')
-        and not skip_project_and_location_in_path_val
+        and not query_vertex_base_models
         and not self.api_key
     ):
       path = f'projects/{self.project}/locations/{self.location}/' + path
@@ -333,12 +329,19 @@ class ApiClient:
         patched_http_options['base_url'],
         patched_http_options['api_version'] + '/' + path,
     )
+
+    timeout_in_seconds = patched_http_options.get('timeout', None)
+    if timeout_in_seconds:
+      timeout_in_seconds = timeout_in_seconds / 1000.0
+    else:
+      timeout_in_seconds = None
+
     return HttpRequest(
         method=http_method,
         url=url,
         headers=patched_http_options['headers'],
         data=request_dict,
-        timeout=patched_http_options.get('timeout', None),
+        timeout=timeout_in_seconds,
    )
 
   def _request(
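_build_request now treats the timeout carried in the HTTP options as milliseconds and converts it to seconds before handing it to requests; in 0.5.0 the value was passed through unchanged. A short sketch of the conversion and its caller-facing effect (the 10_000 value is illustrative):

    timeout_ms = 10_000  # value supplied via http_options['timeout']
    timeout_in_seconds = timeout_ms / 1000.0 if timeout_ms else None
    assert timeout_in_seconds == 10.0  # what requests.Session.request() receives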
@@ -357,7 +360,7 @@ class ApiClient:
           http_request.method.upper(),
           http_request.url,
           headers=http_request.headers,
-          data=json.dumps(http_request.data, cls=RequestJsonEncoder)
+          data=json.dumps(http_request.data)
           if http_request.data
           else None,
           timeout=http_request.timeout,
@@ -377,7 +380,7 @@ class ApiClient:
       data = None
       if http_request.data:
         if not isinstance(http_request.data, bytes):
-          data = json.dumps(http_request.data, cls=RequestJsonEncoder)
+          data = json.dumps(http_request.data)
         else:
           data = http_request.data
 
@@ -427,15 +430,24 @@ class ApiClient:
       http_method: str,
       path: str,
       request_dict: dict[str, object],
-      http_options: HttpOptionsDict = None,
+      http_options: HttpOptionsOrDict = None,
   ):
     http_request = self._build_request(
         http_method, path, request_dict, http_options
     )
     response = self._request(http_request, stream=False)
-    if http_options and 'response_payload' in http_options:
-      response.copy_to_dict(http_options['response_payload'])
-    return response.text
+    if http_options:
+      if (
+          isinstance(http_options, HttpOptions)
+          and http_options.deprecated_response_payload is not None
+      ):
+        response._copy_to_dict(http_options.deprecated_response_payload)
+      elif (
+          isinstance(http_options, dict)
+          and 'deprecated_response_payload' in http_options
+      ):
+        response._copy_to_dict(http_options['deprecated_response_payload'])
+    return response.json
 
   def request_streamed(
       self,
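The response_payload capture hook is renamed to deprecated_response_payload, and request() now returns the parsed JSON (response.json) rather than the raw text. A hedged sketch of the call pattern after the rename, given an existing ApiClient instance named api_client; the path and request body are illustrative, only the option key comes from the diff:

    captured = {}
    response_json = api_client.request(
        'post',
        'models/gemini-pro:generateContent',  # hypothetical path
        {'contents': [{'parts': [{'text': 'hello'}]}]},
        http_options={'deprecated_response_payload': captured},
    )
    # `captured` now holds a deep copy of the HttpResponse attributes and
    # `response_json` is already-parsed JSON rather than a string.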
@@ -449,8 +461,10 @@ class ApiClient:
     )
 
     session_response = self._request(http_request, stream=True)
-    if http_options and 'response_payload' in http_options:
-      session_response.copy_to_dict(http_options['response_payload'])
+    if http_options and 'deprecated_response_payload' in http_options:
+      session_response._copy_to_dict(
+          http_options['deprecated_response_payload']
+      )
     for chunk in session_response.segments():
       yield chunk
 
@@ -466,9 +480,9 @@ class ApiClient:
     )
 
     result = await self._async_request(http_request=http_request, stream=False)
-    if http_options and 'response_payload' in http_options:
-      result.copy_to_dict(http_options['response_payload'])
-    return result.text
+    if http_options and 'deprecated_response_payload' in http_options:
+      result._copy_to_dict(http_options['deprecated_response_payload'])
+    return result.json
 
   async def async_request_streamed(
       self,
@@ -483,17 +497,42 @@ class ApiClient:
 
     response = await self._async_request(http_request=http_request, stream=True)
 
-    for chunk in response.segments():
-      yield chunk
-    if http_options and 'response_payload' in http_options:
-      response.copy_to_dict(http_options['response_payload'])
+    if http_options and 'deprecated_response_payload' in http_options:
+      response._copy_to_dict(http_options['deprecated_response_payload'])
+    async def async_generator():
+      async for chunk in response:
+        yield chunk
+    return async_generator()
 
-  def upload_file(self, file_path: str, upload_url: str, upload_size: int):
+  def upload_file(
+      self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
+  ) -> str:
     """Transfers a file to the given URL.
 
     Args:
-      file_path: The full path to the file. If the local file path is not found,
-        an error will be raised.
+      file_path: The full path to the file or a file like object inherited from
+        io.BytesIO. If the local file path is not found, an error will be
+        raised.
+      upload_url: The URL to upload the file to.
+      upload_size: The size of file content to be uploaded, this will have to
+        match the size requested in the resumable upload request.
+
+    returns:
+      The response json object from the finalize request.
+    """
+    if isinstance(file_path, io.IOBase):
+      return self._upload_fd(file_path, upload_url, upload_size)
+    else:
+      with open(file_path, 'rb') as file:
+        return self._upload_fd(file, upload_url, upload_size)
+
+  def _upload_fd(
+      self, file: io.IOBase, upload_url: str, upload_size: int
+  ) -> str:
+    """Transfers a file to the given URL.
+
+    Args:
+      file: A file like object inherited from io.BytesIO.
       upload_url: The URL to upload the file to.
       upload_size: The size of file content to be uploaded, this will have to
         match the size requested in the resumable upload request.
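upload_file now accepts an open binary file object (anything inheriting io.IOBase) in addition to a path, delegating either way to the new _upload_fd helper. A rough usage sketch, assuming an ApiClient instance named api_client and a resumable upload_url already obtained from a prior create-file request:

    import io

    payload = b'example file contents'
    buffer = io.BytesIO(payload)

    # upload_size must match the size announced in the resumable upload request.
    result_json = api_client.upload_file(buffer, upload_url, upload_size=len(payload))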
@@ -503,51 +542,89 @@ class ApiClient:
     """
     offset = 0
     # Upload the file in chunks
-    with open(file_path, 'rb') as file:
-      while True:
-        file_chunk = file.read(1024 * 1024 * 8)  # 8 MB chunk size
-        chunk_size = 0
-        if file_chunk:
-          chunk_size = len(file_chunk)
-        upload_command = 'upload'
-        # If last chunk, finalize the upload.
-        if chunk_size + offset >= upload_size:
-          upload_command += ', finalize'
-
-        request = HttpRequest(
-            method='POST',
-            url=upload_url,
-            headers={
-                'X-Goog-Upload-Command': upload_command,
-                'X-Goog-Upload-Offset': str(offset),
-                'Content-Length': str(chunk_size),
-            },
-            data=file_chunk,
-        )
-        response = self._request_unauthorized(request, stream=False)
-        offset += chunk_size
-        if response.headers['X-Goog-Upload-Status'] != 'active':
-          break  # upload is complete or it has been interrupted.
+    while True:
+      file_chunk = file.read(1024 * 1024 * 8)  # 8 MB chunk size
+      chunk_size = 0
+      if file_chunk:
+        chunk_size = len(file_chunk)
+      upload_command = 'upload'
+      # If last chunk, finalize the upload.
+      if chunk_size + offset >= upload_size:
+        upload_command += ', finalize'
+      request = HttpRequest(
+          method='POST',
+          url=upload_url,
+          headers={
+              'X-Goog-Upload-Command': upload_command,
+              'X-Goog-Upload-Offset': str(offset),
+              'Content-Length': str(chunk_size),
+          },
+          data=file_chunk,
+      )
 
-      if upload_size <= offset:  # Status is not finalized.
-        raise ValueError(
-            'All content has been uploaded, but the upload status is not'
-            f' finalized. {response.headers}, body: {response.text}'
-        )
+      response = self._request_unauthorized(request, stream=False)
+      offset += chunk_size
+      if response.headers['X-Goog-Upload-Status'] != 'active':
+        break  # upload is complete or it has been interrupted.
+
+    if upload_size <= offset:  # Status is not finalized.
+      raise ValueError(
+          'All content has been uploaded, but the upload status is not'
+          f' finalized. {response.headers}, body: {response.json}'
+      )
 
     if response.headers['X-Goog-Upload-Status'] != 'final':
       raise ValueError(
           'Failed to upload file: Upload status is not finalized. headers:'
-          f' {response.headers}, body: {response.text}'
+          f' {response.headers}, body: {response.json}'
       )
-    return response.text
+    return response.json
+
+  def download_file(self, path: str, http_options):
+    """Downloads the file data.
+
+    Args:
+      path: The request path with query params.
+      http_options: The http options to use for the request.
+
+    returns:
+      The file bytes
+    """
+    http_request = self._build_request(
+        'get', path=path, request_dict={}, http_options=http_options
+    )
+    return self._download_file_request(http_request).byte_stream[0]
+
+  def _download_file_request(
+      self,
+      http_request: HttpRequest,
+  ) -> HttpResponse:
+    data = None
+    if http_request.data:
+      if not isinstance(http_request.data, bytes):
+        data = json.dumps(http_request.data, cls=RequestJsonEncoder)
+      else:
+        data = http_request.data
+
+    http_session = requests.Session()
+    response = http_session.request(
+        method=http_request.method,
+        url=http_request.url,
+        headers=http_request.headers,
+        data=data,
+        timeout=http_request.timeout,
+        stream=False,
+    )
+
+    errors.APIError.raise_for_response(response)
+    return HttpResponse(response.headers, byte_stream=[response.content])
 
   async def async_upload_file(
       self,
-      file_path: str,
+      file_path: Union[str, io.IOBase],
       upload_url: str,
       upload_size: int,
-  ):
+  ) -> str:
     """Transfers a file asynchronously to the given URL.
 
     Args:
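download_file is new in this release: it builds a GET request, issues it without streaming, wraps the body in HttpResponse(byte_stream=[response.content]), and returns that single byte segment. A hedged sketch of saving the result, given an ApiClient instance named api_client and a purely illustrative download path:

    file_bytes = api_client.download_file(
        'files/abc123:download?alt=media',  # hypothetical path with query params
        http_options={},
    )
    with open('downloaded.bin', 'wb') as f:
      f.write(file_bytes)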
@@ -567,22 +644,48 @@ class ApiClient:
         upload_size,
     )
 
+  async def _async_upload_fd(
+      self,
+      file: io.IOBase,
+      upload_url: str,
+      upload_size: int,
+  ) -> str:
+    """Transfers a file asynchronously to the given URL.
+
+    Args:
+      file: A file like object inherited from io.BytesIO.
+      upload_url: The URL to upload the file to.
+      upload_size: The size of file content to be uploaded, this will have to
+        match the size requested in the resumable upload request.
+
+    returns:
+      The response json object from the finalize request.
+    """
+    return await asyncio.to_thread(
+        self._upload_fd,
+        file,
+        upload_url,
+        upload_size,
+    )
+
+  async def async_download_file(self, path: str, http_options):
+    """Downloads the file data.
+
+    Args:
+      path: The request path with query params.
+      http_options: The http options to use for the request.
+
+    returns:
+      The file bytes
+    """
+    return await asyncio.to_thread(
+        self.download_file,
+        path,
+        http_options,
+    )
+
   # This method does nothing in the real api client. It is used in the
   # replay_api_client to verify the response from the SDK method matches the
   # recorded response.
   def _verify_response(self, response_model: BaseModel):
     pass
-
-
-# TODO(b/389693448): Cleanup datetime hacks.
-class RequestJsonEncoder(json.JSONEncoder):
-  """Encode bytes as strings without modify its content."""
-
-  def default(self, o):
-    if isinstance(o, datetime.datetime):
-      # This Zulu time format is used by the Vertex AI API and the test recorder
-      # Using strftime works well, but we want to align with the replay encoder.
-      # o.astimezone(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%fZ')
-      return o.isoformat().replace('+00:00', 'Z')
-    else:
-      return super().default(o)
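The new async helpers do not reimplement the transfer logic; _async_upload_fd and async_download_file simply run their blocking counterparts on a worker thread via asyncio.to_thread. A self-contained sketch of that pattern:

    import asyncio
    import time


    def blocking_download(path: str) -> bytes:
      # Stand-in for a synchronous helper such as download_file.
      time.sleep(0.1)
      return f'contents of {path}'.encode()


    async def async_download(path: str) -> bytes:
      # The blocking call runs in a thread, keeping the event loop responsive.
      return await asyncio.to_thread(blocking_download, path)


    print(asyncio.run(async_download('files/example')))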
@@ -0,0 +1,24 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Utilities for the API Modules of the Google Gen AI SDK."""
+
+from . import _api_client
+
+
+class BaseModule:
+
+  def __init__(self, api_client_: _api_client.ApiClient):
+    self._api_client = api_client_
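The file added at the end of the diff gives every API module a common base: each module keeps a reference to the single ApiClient owned by the client object. A hedged sketch of a subclass built on the BaseModule shown above (the Files name and its method are hypothetical):

    class Files(BaseModule):

      def download(self, path: str, http_options=None) -> bytes:
        # All HTTP traffic goes through the shared ApiClient instance.
        return self._api_client.download_file(path, http_options)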