google-genai 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -188,9 +188,11 @@ class HttpResponse:
188
188
  if chunk:
189
189
  # In streaming mode, the chunk of JSON is prefixed with "data:" which
190
190
  # we must strip before parsing.
191
- if chunk.startswith(b'data: '):
192
- chunk = chunk[len(b'data: ') :]
193
- yield json.loads(str(chunk, 'utf-8'))
191
+ if not isinstance(chunk, str):
192
+ chunk = chunk.decode('utf-8')
193
+ if chunk.startswith('data: '):
194
+ chunk = chunk[len('data: ') :]
195
+ yield json.loads(chunk)
194
196
 
195
197
  async def async_segments(self) -> AsyncIterator[Any]:
196
198
  if isinstance(self.response_stream, list):
@@ -206,8 +208,10 @@ class HttpResponse:
206
208
  async for chunk in self.response_stream.aiter_lines():
207
209
  # This is httpx.Response.
208
210
  if chunk:
209
- # In async streaming mode, the chunk of JSON is prefixed with "data:"
210
- # which we must strip before parsing.
211
+ # In async streaming mode, the chunk of JSON is prefixed with
212
+ # "data:" which we must strip before parsing.
213
+ if not isinstance(chunk, str):
214
+ chunk = chunk.decode('utf-8')
211
215
  if chunk.startswith('data: '):
212
216
  chunk = chunk[len('data: ') :]
213
217
  yield json.loads(chunk)
@@ -234,6 +238,41 @@ class HttpResponse:
234
238
  response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
235
239
 
236
240
 
241
+ class SyncHttpxClient(httpx.Client):
242
+ """Sync httpx client."""
243
+
244
+ def __init__(self, **kwargs: Any) -> None:
245
+ """Initializes the httpx client."""
246
+ kwargs.setdefault('follow_redirects', True)
247
+ super().__init__(**kwargs)
248
+
249
+ def __del__(self) -> None:
250
+ """Closes the httpx client."""
251
+ if self.is_closed:
252
+ return
253
+ try:
254
+ self.close()
255
+ except Exception:
256
+ pass
257
+
258
+
259
+ class AsyncHttpxClient(httpx.AsyncClient):
260
+ """Async httpx client."""
261
+
262
+ def __init__(self, **kwargs: Any) -> None:
263
+ """Initializes the httpx client."""
264
+ kwargs.setdefault('follow_redirects', True)
265
+ super().__init__(**kwargs)
266
+
267
+ def __del__(self) -> None:
268
+ if self.is_closed:
269
+ return
270
+ try:
271
+ asyncio.get_running_loop().create_task(self.aclose())
272
+ except Exception:
273
+ pass
274
+
275
+
237
276
  class BaseApiClient:
238
277
  """Client for calling HTTP APIs sending and receiving JSON."""
239
278
 
@@ -365,6 +404,9 @@ class BaseApiClient:
365
404
  )
366
405
  else:
367
406
  _append_library_version_headers(self._http_options['headers'])
407
+ # Initialize the httpx client.
408
+ self._httpx_client = SyncHttpxClient()
409
+ self._async_httpx_client = AsyncHttpxClient()
368
410
 
369
411
  def _websocket_base_url(self):
370
412
  url_parts = urlparse(self._http_options['base_url'])
@@ -495,41 +537,39 @@ class BaseApiClient:
495
537
  http_request.headers['x-goog-user-project'] = (
496
538
  self._credentials.quota_project_id
497
539
  )
498
- data = json.dumps(http_request.data)
540
+ data = json.dumps(http_request.data) if http_request.data else None
499
541
  else:
500
542
  if http_request.data:
501
543
  if not isinstance(http_request.data, bytes):
502
- data = json.dumps(http_request.data)
544
+ data = json.dumps(http_request.data) if http_request.data else None
503
545
  else:
504
546
  data = http_request.data
505
547
 
506
548
  if stream:
507
- client = httpx.Client()
508
- httpx_request = client.build_request(
549
+ httpx_request = self._httpx_client.build_request(
509
550
  method=http_request.method,
510
551
  url=http_request.url,
511
552
  content=data,
512
553
  headers=http_request.headers,
513
554
  timeout=http_request.timeout,
514
555
  )
515
- response = client.send(httpx_request, stream=stream)
556
+ response = self._httpx_client.send(httpx_request, stream=stream)
516
557
  errors.APIError.raise_for_response(response)
517
558
  return HttpResponse(
518
559
  response.headers, response if stream else [response.text]
519
560
  )
520
561
  else:
521
- with httpx.Client() as client:
522
- response = client.request(
523
- method=http_request.method,
524
- url=http_request.url,
525
- headers=http_request.headers,
526
- content=data,
527
- timeout=http_request.timeout,
528
- )
529
- errors.APIError.raise_for_response(response)
530
- return HttpResponse(
531
- response.headers, response if stream else [response.text]
532
- )
562
+ response = self._httpx_client.request(
563
+ method=http_request.method,
564
+ url=http_request.url,
565
+ headers=http_request.headers,
566
+ content=data,
567
+ timeout=http_request.timeout,
568
+ )
569
+ errors.APIError.raise_for_response(response)
570
+ return HttpResponse(
571
+ response.headers, response if stream else [response.text]
572
+ )
533
573
 
534
574
  async def _async_request(
535
575
  self, http_request: HttpRequest, stream: bool = False
@@ -543,24 +583,23 @@ class BaseApiClient:
543
583
  http_request.headers['x-goog-user-project'] = (
544
584
  self._credentials.quota_project_id
545
585
  )
546
- data = json.dumps(http_request.data)
586
+ data = json.dumps(http_request.data) if http_request.data else None
547
587
  else:
548
588
  if http_request.data:
549
589
  if not isinstance(http_request.data, bytes):
550
- data = json.dumps(http_request.data)
590
+ data = json.dumps(http_request.data) if http_request.data else None
551
591
  else:
552
592
  data = http_request.data
553
593
 
554
594
  if stream:
555
- aclient = httpx.AsyncClient()
556
- httpx_request = aclient.build_request(
595
+ httpx_request = self._async_httpx_client.build_request(
557
596
  method=http_request.method,
558
597
  url=http_request.url,
559
598
  content=data,
560
599
  headers=http_request.headers,
561
600
  timeout=http_request.timeout,
562
601
  )
563
- response = await aclient.send(
602
+ response = await self._async_httpx_client.send(
564
603
  httpx_request,
565
604
  stream=stream,
566
605
  )
@@ -569,18 +608,17 @@ class BaseApiClient:
569
608
  response.headers, response if stream else [response.text]
570
609
  )
571
610
  else:
572
- async with httpx.AsyncClient() as aclient:
573
- response = await aclient.request(
574
- method=http_request.method,
575
- url=http_request.url,
576
- headers=http_request.headers,
577
- content=data,
578
- timeout=http_request.timeout,
579
- )
580
- errors.APIError.raise_for_response(response)
581
- return HttpResponse(
582
- response.headers, response if stream else [response.text]
583
- )
611
+ response = await self._async_httpx_client.request(
612
+ method=http_request.method,
613
+ url=http_request.url,
614
+ headers=http_request.headers,
615
+ content=data,
616
+ timeout=http_request.timeout,
617
+ )
618
+ errors.APIError.raise_for_response(response)
619
+ return HttpResponse(
620
+ response.headers, response if stream else [response.text]
621
+ )
584
622
 
585
623
  def get_read_only_http_options(self) -> HttpOptionsDict:
586
624
  copied = HttpOptionsDict()
@@ -705,7 +743,7 @@ class BaseApiClient:
705
743
  # If last chunk, finalize the upload.
706
744
  if chunk_size + offset >= upload_size:
707
745
  upload_command += ', finalize'
708
- request = HttpRequest(
746
+ response = self._httpx_client.request(
709
747
  method='POST',
710
748
  url=upload_url,
711
749
  headers={
@@ -713,25 +751,22 @@ class BaseApiClient:
713
751
  'X-Goog-Upload-Offset': str(offset),
714
752
  'Content-Length': str(chunk_size),
715
753
  },
716
- data=file_chunk,
754
+ content=file_chunk,
717
755
  )
718
-
719
- response = self._request(request, stream=False)
720
756
  offset += chunk_size
721
- if response.headers['X-Goog-Upload-Status'] != 'active':
757
+ if response.headers['x-goog-upload-status'] != 'active':
722
758
  break # upload is complete or it has been interrupted.
723
-
724
759
  if upload_size <= offset: # Status is not finalized.
725
760
  raise ValueError(
726
761
  'All content has been uploaded, but the upload status is not'
727
762
  f' finalized.'
728
763
  )
729
764
 
730
- if response.headers['X-Goog-Upload-Status'] != 'final':
765
+ if response.headers['x-goog-upload-status'] != 'final':
731
766
  raise ValueError(
732
767
  'Failed to upload file: Upload status is not finalized.'
733
768
  )
734
- return response.json
769
+ return response.json()
735
770
 
736
771
  def download_file(self, path: str, http_options):
737
772
  """Downloads the file data.
@@ -746,12 +781,7 @@ class BaseApiClient:
746
781
  http_request = self._build_request(
747
782
  'get', path=path, request_dict={}, http_options=http_options
748
783
  )
749
- return self._download_file_request(http_request).byte_stream[0]
750
784
 
751
- def _download_file_request(
752
- self,
753
- http_request: HttpRequest,
754
- ) -> HttpResponse:
755
785
  data: Optional[Union[str, bytes]] = None
756
786
  if http_request.data:
757
787
  if not isinstance(http_request.data, bytes):
@@ -759,17 +789,18 @@ class BaseApiClient:
759
789
  else:
760
790
  data = http_request.data
761
791
 
762
- with httpx.Client(follow_redirects=True) as client:
763
- response = client.request(
764
- method=http_request.method,
765
- url=http_request.url,
766
- headers=http_request.headers,
767
- content=data,
768
- timeout=http_request.timeout,
769
- )
792
+ response = self._httpx_client.request(
793
+ method=http_request.method,
794
+ url=http_request.url,
795
+ headers=http_request.headers,
796
+ content=data,
797
+ timeout=http_request.timeout,
798
+ )
770
799
 
771
- errors.APIError.raise_for_response(response)
772
- return HttpResponse(response.headers, byte_stream=[response.read()])
800
+ errors.APIError.raise_for_response(response)
801
+ return HttpResponse(
802
+ response.headers, byte_stream=[response.read()]
803
+ ).byte_stream[0]
773
804
 
774
805
  async def async_upload_file(
775
806
  self,
@@ -814,45 +845,44 @@ class BaseApiClient:
814
845
  returns:
815
846
  The response json object from the finalize request.
816
847
  """
817
- async with httpx.AsyncClient() as aclient:
818
- offset = 0
819
- # Upload the file in chunks
820
- while True:
821
- if isinstance(file, io.IOBase):
822
- file_chunk = file.read(CHUNK_SIZE)
823
- else:
824
- file_chunk = await file.read(CHUNK_SIZE)
825
- chunk_size = 0
826
- if file_chunk:
827
- chunk_size = len(file_chunk)
828
- upload_command = 'upload'
829
- # If last chunk, finalize the upload.
830
- if chunk_size + offset >= upload_size:
831
- upload_command += ', finalize'
832
- response = await aclient.request(
833
- method='POST',
834
- url=upload_url,
835
- content=file_chunk,
836
- headers={
837
- 'X-Goog-Upload-Command': upload_command,
838
- 'X-Goog-Upload-Offset': str(offset),
839
- 'Content-Length': str(chunk_size),
840
- },
841
- )
842
- offset += chunk_size
843
- if response.headers.get('x-goog-upload-status') != 'active':
844
- break # upload is complete or it has been interrupted.
845
-
846
- if upload_size <= offset: # Status is not finalized.
847
- raise ValueError(
848
- 'All content has been uploaded, but the upload status is not'
849
- f' finalized.'
850
- )
851
- if response.headers.get('x-goog-upload-status') != 'final':
848
+ offset = 0
849
+ # Upload the file in chunks
850
+ while True:
851
+ if isinstance(file, io.IOBase):
852
+ file_chunk = file.read(CHUNK_SIZE)
853
+ else:
854
+ file_chunk = await file.read(CHUNK_SIZE)
855
+ chunk_size = 0
856
+ if file_chunk:
857
+ chunk_size = len(file_chunk)
858
+ upload_command = 'upload'
859
+ # If last chunk, finalize the upload.
860
+ if chunk_size + offset >= upload_size:
861
+ upload_command += ', finalize'
862
+ response = await self._async_httpx_client.request(
863
+ method='POST',
864
+ url=upload_url,
865
+ content=file_chunk,
866
+ headers={
867
+ 'X-Goog-Upload-Command': upload_command,
868
+ 'X-Goog-Upload-Offset': str(offset),
869
+ 'Content-Length': str(chunk_size),
870
+ },
871
+ )
872
+ offset += chunk_size
873
+ if response.headers.get('x-goog-upload-status') != 'active':
874
+ break # upload is complete or it has been interrupted.
875
+
876
+ if upload_size <= offset: # Status is not finalized.
852
877
  raise ValueError(
853
- 'Failed to upload file: Upload status is not finalized.'
878
+ 'All content has been uploaded, but the upload status is not'
879
+ f' finalized.'
854
880
  )
855
- return response.json()
881
+ if response.headers.get('x-goog-upload-status') != 'final':
882
+ raise ValueError(
883
+ 'Failed to upload file: Upload status is not finalized.'
884
+ )
885
+ return response.json()
856
886
 
857
887
  async def async_download_file(self, path: str, http_options):
858
888
  """Downloads the file data.
@@ -875,19 +905,18 @@ class BaseApiClient:
875
905
  else:
876
906
  data = http_request.data
877
907
 
878
- async with httpx.AsyncClient(follow_redirects=True) as aclient:
879
- response = await aclient.request(
880
- method=http_request.method,
881
- url=http_request.url,
882
- headers=http_request.headers,
883
- content=data,
884
- timeout=http_request.timeout,
885
- )
886
- errors.APIError.raise_for_response(response)
908
+ response = await self._async_httpx_client.request(
909
+ method=http_request.method,
910
+ url=http_request.url,
911
+ headers=http_request.headers,
912
+ content=data,
913
+ timeout=http_request.timeout,
914
+ )
915
+ errors.APIError.raise_for_response(response)
887
916
 
888
- return HttpResponse(
889
- response.headers, byte_stream=[response.read()]
890
- ).byte_stream[0]
917
+ return HttpResponse(
918
+ response.headers, byte_stream=[response.read()]
919
+ ).byte_stream[0]
891
920
 
892
921
  # This method does nothing in the real api client. It is used in the
893
922
  # replay_api_client to verify the response from the SDK method matches the
google/genai/_common.py CHANGED
@@ -20,7 +20,7 @@ import datetime
20
20
  import enum
21
21
  import functools
22
22
  import typing
23
- from typing import Union
23
+ from typing import Any, Union
24
24
  import uuid
25
25
  import warnings
26
26
 
@@ -93,7 +93,7 @@ def set_value_by_path(data, keys, value):
93
93
  data[keys[-1]] = value
94
94
 
95
95
 
96
- def get_value_by_path(data: object, keys: list[str]):
96
+ def get_value_by_path(data: Any, keys: list[str]):
97
97
  """Examples:
98
98
 
99
99
  get_value_by_path({'a': {'b': v}}, ['a', 'b'])
@@ -128,7 +128,7 @@ def get_value_by_path(data: object, keys: list[str]):
128
128
  return data
129
129
 
130
130
 
131
- def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
131
+ def convert_to_dict(obj: object) -> Any:
132
132
  """Recursively converts a given object to a dictionary.
133
133
 
134
134
  If the object is a Pydantic model, it uses the model's `model_dump()` method.
@@ -137,7 +137,9 @@ def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
137
137
  obj: The object to convert.
138
138
 
139
139
  Returns:
140
- A dictionary representation of the object.
140
+ A dictionary representation of the object, a list of objects if a list is
141
+ passed, or the object itself if it is not a dictionary, list, or Pydantic
142
+ model.
141
143
  """
142
144
  if isinstance(obj, pydantic.BaseModel):
143
145
  return obj.model_dump(exclude_none=True)
@@ -150,7 +152,7 @@ def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
150
152
 
151
153
 
152
154
  def _remove_extra_fields(
153
- model: pydantic.BaseModel, response: dict[str, object]
155
+ model: Any, response: dict[str, object]
154
156
  ) -> None:
155
157
  """Removes extra fields from the response that are not in the model.
156
158
 
@@ -520,11 +520,14 @@ class ReplayApiClient(BaseApiClient):
520
520
  else:
521
521
  return self._build_response_from_replay(request).json
522
522
 
523
- def _download_file_request(self, request):
523
+ def download_file(self, path: str, http_options: HttpOptions):
524
524
  self._initialize_replay_session_if_not_loaded()
525
+ request = self._build_request(
526
+ 'get', path=path, request_dict={}, http_options=http_options
527
+ )
525
528
  if self._should_call_api():
526
529
  try:
527
- result = super()._download_file_request(request)
530
+ result = super().download_file(path, http_options)
528
531
  except HTTPError as e:
529
532
  result = HttpResponse(
530
533
  e.response.headers, [json.dumps({'reason': e.response.reason})]
@@ -534,7 +537,7 @@ class ReplayApiClient(BaseApiClient):
534
537
  self._record_interaction(request, result)
535
538
  return result
536
539
  else:
537
- return self._build_response_from_replay(request)
540
+ return self._build_response_from_replay(request).byte_stream[0]
538
541
 
539
542
  async def async_download_file(self, path: str, http_options):
540
543
  self._initialize_replay_session_if_not_loaded()
google/genai/chats.py CHANGED
@@ -13,12 +13,19 @@
13
13
  # limitations under the License.
14
14
  #
15
15
 
16
- from typing import AsyncIterator, Awaitable, Optional
17
- from typing import Union
16
+ import sys
17
+ from typing import AsyncIterator, Awaitable, Optional, Union, get_args
18
18
 
19
19
  from . import _transformers as t
20
+ from . import types
20
21
  from .models import AsyncModels, Models
21
- from .types import Content, ContentDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
22
+ from .types import Content, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
23
+
24
+
25
+ if sys.version_info >= (3, 10):
26
+ from typing import TypeGuard
27
+ else:
28
+ from typing_extensions import TypeGuard
22
29
 
23
30
 
24
31
  def _validate_content(content: Content) -> bool:
@@ -81,8 +88,7 @@ def _extract_curated_history(
81
88
  while i < length:
82
89
  if comprehensive_history[i].role not in ["user", "model"]:
83
90
  raise ValueError(
84
- "Role must be user or model, but got"
85
- f" {comprehensive_history[i].role}"
91
+ f"Role must be user or model, but got {comprehensive_history[i].role}"
86
92
  )
87
93
 
88
94
  if comprehensive_history[i].role == "user":
@@ -108,12 +114,10 @@ class _BaseChat:
108
114
  def __init__(
109
115
  self,
110
116
  *,
111
- modules: Union[Models, AsyncModels],
112
117
  model: str,
113
118
  config: Optional[GenerateContentConfigOrDict] = None,
114
119
  history: list[Content],
115
120
  ):
116
- self._modules = modules
117
121
  self._model = model
118
122
  self._config = config
119
123
  self._comprehensive_history = history
@@ -123,27 +127,32 @@ class _BaseChat:
123
127
  """Curated history is the set of valid turns that will be used in the subsequent send requests.
124
128
  """
125
129
 
126
-
127
- def record_history(self, user_input: Content,
128
- model_output: list[Content],
129
- automatic_function_calling_history: list[Content],
130
- is_valid: bool):
130
+ def record_history(
131
+ self,
132
+ user_input: Content,
133
+ model_output: list[Content],
134
+ automatic_function_calling_history: list[Content],
135
+ is_valid: bool,
136
+ ):
131
137
  """Records the chat history.
132
138
 
133
139
  Maintaining both comprehensive and curated histories.
134
140
 
135
141
  Args:
136
142
  user_input: The user's input content.
137
- model_output: A list of `Content` from the model's response.
138
- This can be an empty list if the model produced no output.
139
- automatic_function_calling_history: A list of `Content` representing
140
- the history of automatic function calls, including the user input as
141
- the first entry.
143
+ model_output: A list of `Content` from the model's response. This can be
144
+ an empty list if the model produced no output.
145
+ automatic_function_calling_history: A list of `Content` representing the
146
+ history of automatic function calls, including the user input as the
147
+ first entry.
142
148
  is_valid: A boolean flag indicating whether the current model output is
143
149
  considered valid.
144
150
  """
145
151
  input_contents = (
146
- automatic_function_calling_history
152
+ # Because the AFC input contains the entire curated chat history in
153
+ # addition to the new user input, we need to truncate the AFC history
154
+ # to deduplicate the existing chat history.
155
+ automatic_function_calling_history[len(self._curated_history):]
147
156
  if automatic_function_calling_history
148
157
  else [user_input]
149
158
  )
@@ -158,14 +167,13 @@ class _BaseChat:
158
167
  self._curated_history.extend(input_contents)
159
168
  self._curated_history.extend(output_contents)
160
169
 
161
-
162
170
  def get_history(self, curated: bool = False) -> list[Content]:
163
171
  """Returns the chat history.
164
172
 
165
173
  Args:
166
- curated: A boolean flag indicating whether to return the curated
167
- (valid) history or the comprehensive (all turns) history.
168
- Defaults to False (returns the comprehensive history).
174
+ curated: A boolean flag indicating whether to return the curated (valid)
175
+ history or the comprehensive (all turns) history. Defaults to False
176
+ (returns the comprehensive history).
169
177
 
170
178
  Returns:
171
179
  A list of `Content` objects representing the chat history.
@@ -176,9 +184,41 @@ class _BaseChat:
176
184
  return self._comprehensive_history
177
185
 
178
186
 
187
+ def _is_part_type(
188
+ contents: Union[list[PartUnionDict], PartUnionDict],
189
+ ) -> TypeGuard[t.ContentType]:
190
+ if isinstance(contents, list):
191
+ return all(_is_part_type(part) for part in contents)
192
+ else:
193
+ allowed_part_types = get_args(types.PartUnion)
194
+ if type(contents) in allowed_part_types:
195
+ return True
196
+ else:
197
+ # Some images don't pass isinstance(item, PIL.Image.Image)
198
+ # For example <class 'PIL.JpegImagePlugin.JpegImageFile'>
199
+ if types.PIL_Image is not None and isinstance(contents, types.PIL_Image):
200
+ return True
201
+ return False
202
+
203
+
179
204
  class Chat(_BaseChat):
180
205
  """Chat session."""
181
206
 
207
+ def __init__(
208
+ self,
209
+ *,
210
+ modules: Models,
211
+ model: str,
212
+ config: Optional[GenerateContentConfigOrDict] = None,
213
+ history: list[Content],
214
+ ):
215
+ self._modules = modules
216
+ super().__init__(
217
+ model=model,
218
+ config=config,
219
+ history=history,
220
+ )
221
+
182
222
  def send_message(
183
223
  self,
184
224
  message: Union[list[PartUnionDict], PartUnionDict],
@@ -202,10 +242,15 @@ class Chat(_BaseChat):
202
242
  response = chat.send_message('tell me a story')
203
243
  """
204
244
 
245
+ if not _is_part_type(message):
246
+ raise ValueError(
247
+ f"Message must be a valid part type: {types.PartUnion} or"
248
+ f" {types.PartUnionDict}, got {type(message)}"
249
+ )
205
250
  input_content = t.t_content(self._modules._api_client, message)
206
251
  response = self._modules.generate_content(
207
252
  model=self._model,
208
- contents=self._curated_history + [input_content],
253
+ contents=self._curated_history + [input_content], # type: ignore[arg-type]
209
254
  config=config if config else self._config,
210
255
  )
211
256
  model_output = (
@@ -213,10 +258,15 @@ class Chat(_BaseChat):
213
258
  if response.candidates and response.candidates[0].content
214
259
  else []
215
260
  )
261
+ automatic_function_calling_history = (
262
+ response.automatic_function_calling_history
263
+ if response.automatic_function_calling_history
264
+ else []
265
+ )
216
266
  self.record_history(
217
267
  user_input=input_content,
218
268
  model_output=model_output,
219
- automatic_function_calling_history=response.automatic_function_calling_history,
269
+ automatic_function_calling_history=automatic_function_calling_history,
220
270
  is_valid=_validate_response(response),
221
271
  )
222
272
  return response
@@ -245,29 +295,42 @@ class Chat(_BaseChat):
245
295
  print(chunk.text)
246
296
  """
247
297
 
298
+ if not _is_part_type(message):
299
+ raise ValueError(
300
+ f"Message must be a valid part type: {types.PartUnion} or"
301
+ f" {types.PartUnionDict}, got {type(message)}"
302
+ )
248
303
  input_content = t.t_content(self._modules._api_client, message)
249
304
  output_contents = []
250
305
  finish_reason = None
251
306
  is_valid = True
252
307
  chunk = None
253
- for chunk in self._modules.generate_content_stream(
254
- model=self._model,
255
- contents=self._curated_history + [input_content],
256
- config=config if config else self._config,
257
- ):
258
- if not _validate_response(chunk):
259
- is_valid = False
260
- if chunk.candidates and chunk.candidates[0].content:
261
- output_contents.append(chunk.candidates[0].content)
262
- if chunk.candidates and chunk.candidates[0].finish_reason:
263
- finish_reason = chunk.candidates[0].finish_reason
264
- yield chunk
265
- self.record_history(
266
- user_input=input_content,
267
- model_output=output_contents,
268
- automatic_function_calling_history=chunk.automatic_function_calling_history,
269
- is_valid=is_valid and output_contents and finish_reason,
270
- )
308
+ if isinstance(self._modules, Models):
309
+ for chunk in self._modules.generate_content_stream(
310
+ model=self._model,
311
+ contents=self._curated_history + [input_content], # type: ignore[arg-type]
312
+ config=config if config else self._config,
313
+ ):
314
+ if not _validate_response(chunk):
315
+ is_valid = False
316
+ if chunk.candidates and chunk.candidates[0].content:
317
+ output_contents.append(chunk.candidates[0].content)
318
+ if chunk.candidates and chunk.candidates[0].finish_reason:
319
+ finish_reason = chunk.candidates[0].finish_reason
320
+ yield chunk
321
+ automatic_function_calling_history = (
322
+ chunk.automatic_function_calling_history
323
+ if chunk.automatic_function_calling_history
324
+ else []
325
+ )
326
+ self.record_history(
327
+ user_input=input_content,
328
+ model_output=output_contents,
329
+ automatic_function_calling_history=automatic_function_calling_history,
330
+ is_valid=is_valid
331
+ and output_contents is not None
332
+ and finish_reason is not None,
333
+ )
271
334
 
272
335
 
273
336
  class Chats:
@@ -304,6 +367,21 @@ class Chats:
304
367
  class AsyncChat(_BaseChat):
305
368
  """Async chat session."""
306
369
 
370
+ def __init__(
371
+ self,
372
+ *,
373
+ modules: AsyncModels,
374
+ model: str,
375
+ config: Optional[GenerateContentConfigOrDict] = None,
376
+ history: list[Content],
377
+ ):
378
+ self._modules = modules
379
+ super().__init__(
380
+ model=model,
381
+ config=config,
382
+ history=history,
383
+ )
384
+
307
385
  async def send_message(
308
386
  self,
309
387
  message: Union[list[PartUnionDict], PartUnionDict],
@@ -326,11 +404,15 @@ class AsyncChat(_BaseChat):
326
404
  chat = client.aio.chats.create(model='gemini-1.5-flash')
327
405
  response = await chat.send_message('tell me a story')
328
406
  """
329
-
407
+ if not _is_part_type(message):
408
+ raise ValueError(
409
+ f"Message must be a valid part type: {types.PartUnion} or"
410
+ f" {types.PartUnionDict}, got {type(message)}"
411
+ )
330
412
  input_content = t.t_content(self._modules._api_client, message)
331
413
  response = await self._modules.generate_content(
332
414
  model=self._model,
333
- contents=self._curated_history + [input_content],
415
+ contents=self._curated_history + [input_content], # type: ignore[arg-type]
334
416
  config=config if config else self._config,
335
417
  )
336
418
  model_output = (
@@ -338,10 +420,15 @@ class AsyncChat(_BaseChat):
338
420
  if response.candidates and response.candidates[0].content
339
421
  else []
340
422
  )
423
+ automatic_function_calling_history = (
424
+ response.automatic_function_calling_history
425
+ if response.automatic_function_calling_history
426
+ else []
427
+ )
341
428
  self.record_history(
342
429
  user_input=input_content,
343
430
  model_output=model_output,
344
- automatic_function_calling_history=response.automatic_function_calling_history,
431
+ automatic_function_calling_history=automatic_function_calling_history,
345
432
  is_valid=_validate_response(response),
346
433
  )
347
434
  return response
@@ -369,6 +456,11 @@ class AsyncChat(_BaseChat):
369
456
  print(chunk.text)
370
457
  """
371
458
 
459
+ if not _is_part_type(message):
460
+ raise ValueError(
461
+ f"Message must be a valid part type: {types.PartUnion} or"
462
+ f" {types.PartUnionDict}, got {type(message)}"
463
+ )
372
464
  input_content = t.t_content(self._modules._api_client, message)
373
465
 
374
466
  async def async_generator():
@@ -394,7 +486,6 @@ class AsyncChat(_BaseChat):
394
486
  model_output=output_contents,
395
487
  automatic_function_calling_history=chunk.automatic_function_calling_history,
396
488
  is_valid=is_valid and output_contents and finish_reason,
397
-
398
489
  )
399
490
  return async_generator()
400
491
 
google/genai/files.py CHANGED
@@ -888,13 +888,13 @@ class Files(_api_module.BaseModule):
888
888
 
889
889
  if (
890
890
  response.http_headers is None
891
- or 'X-Goog-Upload-URL' not in response.http_headers
891
+ or 'x-goog-upload-url' not in response.http_headers
892
892
  ):
893
893
  raise KeyError(
894
894
  'Failed to create file. Upload URL did not returned from the create'
895
895
  ' file request.'
896
896
  )
897
- upload_url = response.http_headers['X-Goog-Upload-URL']
897
+ upload_url = response.http_headers['x-goog-upload-url']
898
898
 
899
899
  if isinstance(file, io.IOBase):
900
900
  return_file = self._api_client.upload_file(
google/genai/models.py CHANGED
@@ -4605,7 +4605,7 @@ class Models(_api_module.BaseModule):
4605
4605
  self._api_client._verify_response(return_value)
4606
4606
  return return_value
4607
4607
 
4608
- def edit_image(
4608
+ def _edit_image(
4609
4609
  self,
4610
4610
  *,
4611
4611
  model: str,
@@ -5558,6 +5558,62 @@ class Models(_api_module.BaseModule):
5558
5558
  automatic_function_calling_history.append(func_call_content)
5559
5559
  automatic_function_calling_history.append(func_response_content)
5560
5560
 
5561
+ def edit_image(
5562
+ self,
5563
+ *,
5564
+ model: str,
5565
+ prompt: str,
5566
+ reference_images: list[types._ReferenceImageAPIOrDict],
5567
+ config: Optional[types.EditImageConfigOrDict] = None,
5568
+ ) -> types.EditImageResponse:
5569
+ """Edits an image based on a text description and configuration.
5570
+
5571
+ Args:
5572
+ model (str): The model to use.
5573
+ prompt (str): A text description of the edit to apply to the image.
5574
+ reference_images (list[Union[RawReferenceImage, MaskReferenceImage,
5575
+ ControlReferenceImage, StyleReferenceImage, SubjectReferenceImage]): The
5576
+ reference images for editing.
5577
+ config (EditImageConfig): Configuration for editing.
5578
+
5579
+ Usage:
5580
+
5581
+ .. code-block:: python
5582
+
5583
+ from google.genai.types import RawReferenceImage, MaskReferenceImage
5584
+
5585
+ raw_ref_image = RawReferenceImage(
5586
+ reference_id=1,
5587
+ reference_image=types.Image.from_file(IMAGE_FILE_PATH),
5588
+ )
5589
+
5590
+ mask_ref_image = MaskReferenceImage(
5591
+ reference_id=2,
5592
+ config=types.MaskReferenceConfig(
5593
+ mask_mode='MASK_MODE_FOREGROUND',
5594
+ mask_dilation=0.06,
5595
+ ),
5596
+ )
5597
+ response = client.models.edit_image(
5598
+ model='imagen-3.0-capability-001',
5599
+ prompt='man with dog',
5600
+ reference_images=[raw_ref_image, mask_ref_image],
5601
+ config=types.EditImageConfig(
5602
+ edit_mode= "EDIT_MODE_INPAINT_INSERTION",
5603
+ number_of_images= 1,
5604
+ include_rai_reason= True,
5605
+ )
5606
+ )
5607
+ response.generated_images[0].image.show()
5608
+ # Shows a man with a dog instead of a cat.
5609
+ """
5610
+ return self._edit_image(
5611
+ model=model,
5612
+ prompt=prompt,
5613
+ reference_images=reference_images,
5614
+ config=config,
5615
+ )
5616
+
5561
5617
  def upscale_image(
5562
5618
  self,
5563
5619
  *,
@@ -5990,7 +6046,7 @@ class AsyncModels(_api_module.BaseModule):
5990
6046
  self._api_client._verify_response(return_value)
5991
6047
  return return_value
5992
6048
 
5993
- async def edit_image(
6049
+ async def _edit_image(
5994
6050
  self,
5995
6051
  *,
5996
6052
  model: str,
@@ -6923,6 +6979,62 @@ class AsyncModels(_api_module.BaseModule):
6923
6979
 
6924
6980
  return async_generator(model, contents, config)
6925
6981
 
6982
+ async def edit_image(
6983
+ self,
6984
+ *,
6985
+ model: str,
6986
+ prompt: str,
6987
+ reference_images: list[types._ReferenceImageAPIOrDict],
6988
+ config: Optional[types.EditImageConfigOrDict] = None,
6989
+ ) -> types.EditImageResponse:
6990
+ """Edits an image based on a text description and configuration.
6991
+
6992
+ Args:
6993
+ model (str): The model to use.
6994
+ prompt (str): A text description of the edit to apply to the image.
6995
+ reference_images (list[Union[RawReferenceImage, MaskReferenceImage,
6996
+ ControlReferenceImage, StyleReferenceImage, SubjectReferenceImage]): The
6997
+ reference images for editing.
6998
+ config (EditImageConfig): Configuration for editing.
6999
+
7000
+ Usage:
7001
+
7002
+ .. code-block:: python
7003
+
7004
+ from google.genai.types import RawReferenceImage, MaskReferenceImage
7005
+
7006
+ raw_ref_image = RawReferenceImage(
7007
+ reference_id=1,
7008
+ reference_image=types.Image.from_file(IMAGE_FILE_PATH),
7009
+ )
7010
+
7011
+ mask_ref_image = MaskReferenceImage(
7012
+ reference_id=2,
7013
+ config=types.MaskReferenceConfig(
7014
+ mask_mode='MASK_MODE_FOREGROUND',
7015
+ mask_dilation=0.06,
7016
+ ),
7017
+ )
7018
+ response = await client.aio.models.edit_image(
7019
+ model='imagen-3.0-capability-001',
7020
+ prompt='man with dog',
7021
+ reference_images=[raw_ref_image, mask_ref_image],
7022
+ config=types.EditImageConfig(
7023
+ edit_mode= "EDIT_MODE_INPAINT_INSERTION",
7024
+ number_of_images= 1,
7025
+ include_rai_reason= True,
7026
+ )
7027
+ )
7028
+ response.generated_images[0].image.show()
7029
+ # Shows a man with a dog instead of a cat.
7030
+ """
7031
+ return await self._edit_image(
7032
+ model=model,
7033
+ prompt=prompt,
7034
+ reference_images=reference_images,
7035
+ config=config,
7036
+ )
7037
+
6926
7038
  async def list(
6927
7039
  self,
6928
7040
  *,
google/genai/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  #
15
15
 
16
- __version__ = '1.6.0' # x-release-please-version
16
+ __version__ = '1.7.0' # x-release-please-version
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: google-genai
3
- Version: 1.6.0
3
+ Version: 1.7.0
4
4
  Summary: GenAI Python SDK
5
5
  Author-email: Google LLC <googleapis-packages@google.com>
6
6
  License: Apache-2.0
@@ -25,7 +25,7 @@ Requires-Dist: google-auth<3.0.0,>=2.14.1
25
25
  Requires-Dist: httpx<1.0.0,>=0.28.1
26
26
  Requires-Dist: pydantic<3.0.0,>=2.0.0
27
27
  Requires-Dist: requests<3.0.0,>=2.28.1
28
- Requires-Dist: websockets<15.0.0,>=13.0.0
28
+ Requires-Dist: websockets<15.1.0,>=13.0.0
29
29
  Requires-Dist: typing-extensions<5.0.0,>=4.11.0
30
30
 
31
31
  # Google Gen AI SDK
@@ -41,7 +41,7 @@ Google Gen AI Python SDK provides an interface for developers to integrate Googl
41
41
 
42
42
  ## Installation
43
43
 
44
- ```cmd
44
+ ```sh
45
45
  pip install google-genai
46
46
  ```
47
47
 
@@ -72,16 +72,16 @@ client = genai.Client(
72
72
  **(Optional) Using environment variables:**
73
73
 
74
74
  You can create a client by configuring the necessary environment variables.
75
- Configuration setup instructions depends on whether you're using the Gemini API
76
- on Vertex AI or the ML Dev Gemini API.
75
+ Configuration setup instructions depends on whether you're using the Gemini
76
+ Developer API or the Gemini API in Vertex AI.
77
77
 
78
- **ML Dev Gemini API:** Set `GOOGLE_API_KEY` as shown below:
78
+ **Gemini Developer API:** Set `GOOGLE_API_KEY` as shown below:
79
79
 
80
80
  ```bash
81
81
  export GOOGLE_API_KEY='your-api-key'
82
82
  ```
83
83
 
84
- **Vertex AI API:** Set `GOOGLE_GENAI_USE_VERTEXAI`, `GOOGLE_CLOUD_PROJECT`
84
+ **Gemini API on Vertex AI:** Set `GOOGLE_GENAI_USE_VERTEXAI`, `GOOGLE_CLOUD_PROJECT`
85
85
  and `GOOGLE_CLOUD_LOCATION`, as shown below:
86
86
 
87
87
  ```bash
@@ -142,7 +142,7 @@ response = client.models.generate_content(
142
142
  print(response.text)
143
143
  ```
144
144
 
145
- #### with uploaded file (Gemini API only)
145
+ #### with uploaded file (Gemini Developer API only)
146
146
  download the file in console.
147
147
 
148
148
  ```sh
@@ -347,6 +347,7 @@ The SDK will convert the list of parts into a content with a `user` role
347
347
  ```
348
348
 
349
349
  ##### Mix types in contents
350
+
350
351
  You can also provide a list of `types.ContentUnion`. The SDK leaves items of
351
352
  `types.Content` as is, it groups consecutive non function call parts into a
352
353
  single `types.UserContent`, and it groups consecutive function call parts into
@@ -1,27 +1,27 @@
1
1
  google/genai/__init__.py,sha256=IYw-PcsdgjSpS1mU_ZcYkTfPocsJ4aVmrDxP7vX7c6Y,709
2
- google/genai/_api_client.py,sha256=xEme7KhIrp5lCDlde_HECGzH4TppepR8YraSDiGvhPc,30593
2
+ google/genai/_api_client.py,sha256=X-ULRU6ZfI6WPsWFj3SG_6xncnjj6kK9DtUguReMzbE,31280
3
3
  google/genai/_api_module.py,sha256=66FsFq9N8PdTegDyx3am3NHpI0Bw7HBmifUMCrZsx_Q,902
4
4
  google/genai/_automatic_function_calling_util.py,sha256=xAH-96LIEmC-yefEIae8TrBPZZAI1UJrn0bAIZsISDE,10899
5
- google/genai/_common.py,sha256=u0qX3Uli_7qaYmoTZm9JnVzoMihD4ASPksmqmsjGSBs,10071
5
+ google/genai/_common.py,sha256=PNwxVUKCD93ICHJlwCTAItGH3Wjva5xHC7_7mc-p8oA,10153
6
6
  google/genai/_extra_utils.py,sha256=l9U0uaq4TWdfY7fOpGR3LcsA6-TMEblfQlEXdC0IGPY,12462
7
- google/genai/_replay_api_client.py,sha256=OxZIAyyyI4D9uj0avNO0QGf4DPWJ4Pqf_MCbUs5pvYc,18459
7
+ google/genai/_replay_api_client.py,sha256=_TH5E_hBu35hOJajjpUagsdhiBAbOTWWZDyP43bDjJE,18606
8
8
  google/genai/_test_api_client.py,sha256=XNOWq8AkYbqInv1aljNGlFXsv8slQIWTYy_hdcCetD0,4797
9
9
  google/genai/_transformers.py,sha256=9cVaRp1zLOG27D7iNJeW2rw2jbjlvuEUeVl5SUN6ljY,29863
10
10
  google/genai/batches.py,sha256=K6RgkNWkYBjknADWe4hrv6BtXxWTl1N8b91Pg9MUnAc,41545
11
11
  google/genai/caches.py,sha256=JymnKSaSZYGyTl201tR8PgbZF6fRKdXuCKe4BILKkTc,57551
12
- google/genai/chats.py,sha256=ds5iF4hqvyHbHE4OlP1b5s93SwD0hlMNpWxT7db2E48,13493
12
+ google/genai/chats.py,sha256=uv75f2uWa2cjgkZINCS9SYEH120hM61aZU8Ylcawuec,16379
13
13
  google/genai/client.py,sha256=jN4oNT5qQtX0UILuGcZqxmIL66DOJ-T5P086ygnSnSg,9877
14
14
  google/genai/errors.py,sha256=p_JbOU_eDKIIvWT6NBYGpZcxww622ChAi6eX1FuKKY0,3874
15
- google/genai/files.py,sha256=N9TQYyuHRQm4Gb2agjyOlWFljwK-hdPUr-mP0vF0Bc8,45532
15
+ google/genai/files.py,sha256=po136tfi5tRmMXG01S-NgxDwX-IpzPaXS9cJBWUwxUU,45532
16
16
  google/genai/live.py,sha256=Ftj_LxQ2zClK-2hbdRZNXkmnQQguduoNyVntIdPtTdM,32033
17
- google/genai/models.py,sha256=1ZlRfRqEl2OvaMvU1nIosGhX7g35A-hupUkrAYqwRpg,210141
17
+ google/genai/models.py,sha256=3Y0r-0jRP0AN6iU_k7in05u9N9R2ieUtvzc4qfXbWBY,213651
18
18
  google/genai/operations.py,sha256=Tvlbk_AuQyXGL9b0wJbzgC8QGHGLSVjb1evbe-ZuZN0,20781
19
19
  google/genai/pagers.py,sha256=1jxDjre7M_Udt0ntgOr_79iR0-axjdr_Q6tZZzVRli8,6784
20
20
  google/genai/tunings.py,sha256=2tWvIkofDDnt6VPsfHAbcrfDIo7jAbIlpdwcx6-cUvM,48239
21
21
  google/genai/types.py,sha256=sVOdjxmviV_oA5JQEKKOlg4nwrlTQjbWGqtgqQACP6Y,299883
22
- google/genai/version.py,sha256=5_cjSe7IVSPjk4bpIyj-Lx9SCIXuPbulXr4yMMQgdT8,626
23
- google_genai-1.6.0.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
24
- google_genai-1.6.0.dist-info/METADATA,sha256=_zLVCqlXM9ZDEOPgINpGTkytCEN8pFJ_7Gg4DyPKpJQ,32842
25
- google_genai-1.6.0.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
26
- google_genai-1.6.0.dist-info/top_level.txt,sha256=_1QvSJIhFAGfxb79D6DhB7SUw2X6T4rwnz_LLrbcD3c,7
27
- google_genai-1.6.0.dist-info/RECORD,,
22
+ google/genai/version.py,sha256=gMEV9Q_hiFtL9MTW8j-xSs_8ZfrVk5l9IpAHot0qkCA,626
23
+ google_genai-1.7.0.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
24
+ google_genai-1.7.0.dist-info/METADATA,sha256=DrE2OdIur-Cpp82AkcRds2mFxYbDEYjtWXXseRSM6wY,32868
25
+ google_genai-1.7.0.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
26
+ google_genai-1.7.0.dist-info/top_level.txt,sha256=_1QvSJIhFAGfxb79D6DhB7SUw2X6T4rwnz_LLrbcD3c,7
27
+ google_genai-1.7.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (76.0.0)
2
+ Generator: setuptools (76.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5