google-genai 1.5.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/chats.py CHANGED
@@ -13,12 +13,19 @@
 # limitations under the License.
 #
 
-from typing import AsyncIterator, Awaitable, Optional
-from typing import Union
+import sys
+from typing import AsyncIterator, Awaitable, Optional, Union, get_args
 
 from . import _transformers as t
+from . import types
 from .models import AsyncModels, Models
-from .types import Content, ContentDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
+from .types import Content, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
+
+
+if sys.version_info >= (3, 10):
+  from typing import TypeGuard
+else:
+  from typing_extensions import TypeGuard
 
 
 def _validate_content(content: Content) -> bool:
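Note: the version gate above is the usual way to import `TypeGuard`, which was added to `typing` in Python 3.10 and otherwise comes from the `typing_extensions` backport. A minimal sketch of what a `TypeGuard` annotation buys (the `Item` alias and helpers below are hypothetical, not SDK code):

    import sys
    from typing import Union

    if sys.version_info >= (3, 10):
      from typing import TypeGuard
    else:
      from typing_extensions import TypeGuard

    Item = Union[str, bytes]  # hypothetical union, for illustration only

    def _is_str_list(values: list[Item]) -> TypeGuard[list[str]]:
      # A True result tells static checkers every element is a str.
      return all(isinstance(v, str) for v in values)

    def join_items(values: list[Item]) -> str:
      if _is_str_list(values):
        return ' '.join(values)  # values is narrowed to list[str] here
      return ' '.join(v.decode() if isinstance(v, bytes) else v for v in values)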
@@ -81,8 +88,7 @@ def _extract_curated_history(
   while i < length:
     if comprehensive_history[i].role not in ["user", "model"]:
       raise ValueError(
-          "Role must be user or model, but got"
-          f" {comprehensive_history[i].role}"
+          f"Role must be user or model, but got {comprehensive_history[i].role}"
       )
 
     if comprehensive_history[i].role == "user":
@@ -108,12 +114,10 @@ class _BaseChat:
   def __init__(
       self,
       *,
-      modules: Union[Models, AsyncModels],
       model: str,
       config: Optional[GenerateContentConfigOrDict] = None,
       history: list[Content],
   ):
-    self._modules = modules
     self._model = model
     self._config = config
     self._comprehensive_history = history
@@ -123,27 +127,32 @@ class _BaseChat:
    """Curated history is the set of valid turns that will be used in the subsequent send requests.
    """
 
-
-  def record_history(self, user_input: Content,
-                     model_output: list[Content],
-                     automatic_function_calling_history: list[Content],
-                     is_valid: bool):
+  def record_history(
+      self,
+      user_input: Content,
+      model_output: list[Content],
+      automatic_function_calling_history: list[Content],
+      is_valid: bool,
+  ):
     """Records the chat history.
 
     Maintaining both comprehensive and curated histories.
 
     Args:
       user_input: The user's input content.
-      model_output: A list of `Content` from the model's response.
-        This can be an empty list if the model produced no output.
-      automatic_function_calling_history: A list of `Content` representing
-        the history of automatic function calls, including the user input as
-        the first entry.
+      model_output: A list of `Content` from the model's response. This can be
+        an empty list if the model produced no output.
+      automatic_function_calling_history: A list of `Content` representing the
+        history of automatic function calls, including the user input as the
+        first entry.
       is_valid: A boolean flag indicating whether the current model output is
         considered valid.
     """
     input_contents = (
-        automatic_function_calling_history
+        # Because the AFC input contains the entire curated chat history in
+        # addition to the new user input, we need to truncate the AFC history
+        # to deduplicate the existing chat history.
+        automatic_function_calling_history[len(self._curated_history):]
         if automatic_function_calling_history
         else [user_input]
     )
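Note: the new slice in `record_history` exists because, per the comment above, an automatic function calling (AFC) trace repeats the entire curated history before the new turn; slicing at `len(self._curated_history)` keeps only the new suffix. A rough illustration with plain strings standing in for `Content` objects:

    # Stand-in strings; the real code works with types.Content objects.
    curated_history = ['user: hi', 'model: hello']

    # The AFC trace echoes the curated history, then the new turn's
    # function-calling round trip.
    afc_history = curated_history + [
        'user: weather?', 'model: call get_weather', 'user: sunny'
    ]

    # Slicing off the shared prefix leaves only what must be recorded.
    new_inputs = afc_history[len(curated_history):]
    assert new_inputs == ['user: weather?', 'model: call get_weather', 'user: sunny']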
@@ -158,14 +167,13 @@ class _BaseChat:
     self._curated_history.extend(input_contents)
     self._curated_history.extend(output_contents)
 
-
   def get_history(self, curated: bool = False) -> list[Content]:
     """Returns the chat history.
 
     Args:
-      curated: A boolean flag indicating whether to return the curated
-        (valid) history or the comprehensive (all turns) history.
-        Defaults to False (returns the comprehensive history).
+      curated: A boolean flag indicating whether to return the curated (valid)
+        history or the comprehensive (all turns) history. Defaults to False
+        (returns the comprehensive history).
 
     Returns:
       A list of `Content` objects representing the chat history.
@@ -176,9 +184,41 @@
       return self._comprehensive_history
 
 
+def _is_part_type(
+    contents: Union[list[PartUnionDict], PartUnionDict],
+) -> TypeGuard[t.ContentType]:
+  if isinstance(contents, list):
+    return all(_is_part_type(part) for part in contents)
+  else:
+    allowed_part_types = get_args(types.PartUnion)
+    if type(contents) in allowed_part_types:
+      return True
+    else:
+      # Some images don't pass isinstance(item, PIL.Image.Image)
+      # For example <class 'PIL.JpegImagePlugin.JpegImageFile'>
+      if types.PIL_Image is not None and isinstance(contents, types.PIL_Image):
+        return True
+  return False
+
+
 class Chat(_BaseChat):
   """Chat session."""
 
+  def __init__(
+      self,
+      *,
+      modules: Models,
+      model: str,
+      config: Optional[GenerateContentConfigOrDict] = None,
+      history: list[Content],
+  ):
+    self._modules = modules
+    super().__init__(
+        model=model,
+        config=config,
+        history=history,
+    )
+
   def send_message(
       self,
       message: Union[list[PartUnionDict], PartUnionDict],
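Note: `_is_part_type` tests `type(contents)` against `get_args(types.PartUnion)`, and exact-type membership ignores inheritance, which is why the PIL `isinstance` fallback is needed for subclasses like `PIL.JpegImagePlugin.JpegImageFile`. A small sketch of that distinction, using a hypothetical union in place of `types.PartUnion`:

    from typing import Union, get_args

    PartLike = Union[str, bytes, dict]  # hypothetical stand-in union

    class MyStr(str):  # a subclass, like JpegImageFile under PIL.Image.Image
      pass

    print(get_args(PartLike))                           # (str, bytes, dict)
    print(type('hi') in get_args(PartLike))             # True: exact type match
    print(type(MyStr('hi')) in get_args(PartLike))      # False: subclasses fail
    print(isinstance(MyStr('hi'), get_args(PartLike)))  # True: isinstance honors them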
@@ -202,10 +242,15 @@ class Chat(_BaseChat):
       response = chat.send_message('tell me a story')
     """
 
+    if not _is_part_type(message):
+      raise ValueError(
+          f"Message must be a valid part type: {types.PartUnion} or"
+          f" {types.PartUnionDict}, got {type(message)}"
+      )
     input_content = t.t_content(self._modules._api_client, message)
     response = self._modules.generate_content(
         model=self._model,
-        contents=self._curated_history + [input_content],
+        contents=self._curated_history + [input_content],  # type: ignore[arg-type]
         config=config if config else self._config,
     )
     model_output = (
@@ -213,10 +258,15 @@
         if response.candidates and response.candidates[0].content
         else []
     )
+    automatic_function_calling_history = (
+        response.automatic_function_calling_history
+        if response.automatic_function_calling_history
+        else []
+    )
     self.record_history(
         user_input=input_content,
         model_output=model_output,
-        automatic_function_calling_history=response.automatic_function_calling_history,
+        automatic_function_calling_history=automatic_function_calling_history,
         is_valid=_validate_response(response),
     )
     return response
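Note: with this guard, an unsupported message type now fails fast in `send_message` with a `ValueError` instead of surfacing a less obvious error from the transformer layer. A hedged usage sketch (assumes `GOOGLE_API_KEY` is set in the environment; the model name is illustrative):

    from google import genai

    client = genai.Client()
    chat = client.chats.create(model='gemini-1.5-flash')

    try:
      chat.send_message(12345)  # an int is not a valid part type
    except ValueError as e:
      print(e)  # Message must be a valid part type: ... got <class 'int'>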
@@ -245,29 +295,42 @@
           print(chunk.text)
     """
 
+    if not _is_part_type(message):
+      raise ValueError(
+          f"Message must be a valid part type: {types.PartUnion} or"
+          f" {types.PartUnionDict}, got {type(message)}"
+      )
     input_content = t.t_content(self._modules._api_client, message)
     output_contents = []
     finish_reason = None
     is_valid = True
     chunk = None
-    for chunk in self._modules.generate_content_stream(
-        model=self._model,
-        contents=self._curated_history + [input_content],
-        config=config if config else self._config,
-    ):
-      if not _validate_response(chunk):
-        is_valid = False
-      if chunk.candidates and chunk.candidates[0].content:
-        output_contents.append(chunk.candidates[0].content)
-      if chunk.candidates and chunk.candidates[0].finish_reason:
-        finish_reason = chunk.candidates[0].finish_reason
-      yield chunk
-    self.record_history(
-        user_input=input_content,
-        model_output=output_contents,
-        automatic_function_calling_history=chunk.automatic_function_calling_history,
-        is_valid=is_valid and output_contents and finish_reason,
-    )
+    if isinstance(self._modules, Models):
+      for chunk in self._modules.generate_content_stream(
+          model=self._model,
+          contents=self._curated_history + [input_content],  # type: ignore[arg-type]
+          config=config if config else self._config,
+      ):
+        if not _validate_response(chunk):
+          is_valid = False
+        if chunk.candidates and chunk.candidates[0].content:
+          output_contents.append(chunk.candidates[0].content)
+        if chunk.candidates and chunk.candidates[0].finish_reason:
+          finish_reason = chunk.candidates[0].finish_reason
+        yield chunk
+      automatic_function_calling_history = (
+          chunk.automatic_function_calling_history
+          if chunk.automatic_function_calling_history
+          else []
+      )
+      self.record_history(
+          user_input=input_content,
+          model_output=output_contents,
+          automatic_function_calling_history=automatic_function_calling_history,
+          is_valid=is_valid
+          and output_contents is not None
+          and finish_reason is not None,
+      )
 
 
 class Chats:
@@ -304,6 +367,21 @@ class Chats:
 class AsyncChat(_BaseChat):
   """Async chat session."""
 
+  def __init__(
+      self,
+      *,
+      modules: AsyncModels,
+      model: str,
+      config: Optional[GenerateContentConfigOrDict] = None,
+      history: list[Content],
+  ):
+    self._modules = modules
+    super().__init__(
+        model=model,
+        config=config,
+        history=history,
+    )
+
   async def send_message(
       self,
       message: Union[list[PartUnionDict], PartUnionDict],
@@ -326,11 +404,15 @@ class AsyncChat(_BaseChat):
       chat = client.aio.chats.create(model='gemini-1.5-flash')
       response = await chat.send_message('tell me a story')
     """
-
+    if not _is_part_type(message):
+      raise ValueError(
+          f"Message must be a valid part type: {types.PartUnion} or"
+          f" {types.PartUnionDict}, got {type(message)}"
+      )
     input_content = t.t_content(self._modules._api_client, message)
     response = await self._modules.generate_content(
         model=self._model,
-        contents=self._curated_history + [input_content],
+        contents=self._curated_history + [input_content],  # type: ignore[arg-type]
         config=config if config else self._config,
     )
     model_output = (
@@ -338,10 +420,15 @@ class AsyncChat(_BaseChat):
         if response.candidates and response.candidates[0].content
         else []
     )
+    automatic_function_calling_history = (
+        response.automatic_function_calling_history
+        if response.automatic_function_calling_history
+        else []
+    )
     self.record_history(
         user_input=input_content,
         model_output=model_output,
-        automatic_function_calling_history=response.automatic_function_calling_history,
+        automatic_function_calling_history=automatic_function_calling_history,
        is_valid=_validate_response(response),
     )
     return response
@@ -369,6 +456,11 @@ class AsyncChat(_BaseChat):
           print(chunk.text)
     """
 
+    if not _is_part_type(message):
+      raise ValueError(
+          f"Message must be a valid part type: {types.PartUnion} or"
+          f" {types.PartUnionDict}, got {type(message)}"
+      )
     input_content = t.t_content(self._modules._api_client, message)
 
     async def async_generator():
@@ -394,7 +486,6 @@ class AsyncChat(_BaseChat):
           model_output=output_contents,
           automatic_function_calling_history=chunk.automatic_function_calling_history,
           is_valid=is_valid and output_contents and finish_reason,
-
       )
     return async_generator()
 
google/genai/client.py CHANGED
@@ -130,8 +130,9 @@ class Client:
       from environment variables. Applies to the Vertex AI API only.
     debug_config: Config settings that control network behavior of the client.
       This is typically used when running test code.
-    http_options: Http options to use for the client. Response_payload can't be
-      set when passing to the client constructor.
+    http_options: Http options to use for the client. These options will be
+      applied to all requests made by the client. Example usage:
+      `client = genai.Client(http_options=types.HttpOptions(api_version='v1'))`.
 
   Usage for the Gemini Developer API:
 
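Note: the corrected docstring reflects that `http_options` applies to every request the client makes. Expanding the docstring's own example only slightly (a sketch, not the full option surface):

    from google import genai
    from google.genai import types

    # Pin the REST API version for all requests made by this client.
    client = genai.Client(
        http_options=types.HttpOptions(api_version='v1')
    )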
google/genai/errors.py CHANGED
@@ -18,7 +18,6 @@
 from typing import Any, Optional, TYPE_CHECKING, Union
 import httpx
 import json
-import requests
 
 
 if TYPE_CHECKING:
@@ -28,7 +27,7 @@ if TYPE_CHECKING:
 class APIError(Exception):
   """General errors raised by the GenAI API."""
   code: int
-  response: Union[requests.Response, 'ReplayResponse', httpx.Response]
+  response: Union['ReplayResponse', httpx.Response]
 
   status: Optional[str] = None
   message: Optional[str] = None
@@ -36,28 +35,21 @@ class APIError(Exception):
   def __init__(
       self,
       code: int,
-      response: Union[requests.Response, 'ReplayResponse', httpx.Response],
+      response: Union['ReplayResponse', httpx.Response],
   ):
     self.response = response
-
-    if isinstance(response, requests.Response):
+    message = None
+    if isinstance(response, httpx.Response):
       try:
-        # do not do any extra muanipulation on the response.
-        # return the raw response json as is.
         response_json = response.json()
-      except requests.exceptions.JSONDecodeError:
+      except (json.decoder.JSONDecodeError):
+        message = response.text
         response_json = {
-            'message': response.text,
-            'status': response.reason,
+            'message': message,
+            'status': response.reason_phrase,
         }
-    elif isinstance(response, httpx.Response):
-      try:
-        response_json = response.json()
-      except (json.decoder.JSONDecodeError, httpx.ResponseNotRead):
-        try:
-          message = response.text
-        except httpx.ResponseNotRead:
-          message = None
+      except httpx.ResponseNotRead:
+        message = 'Response not read'
         response_json = {
             'message': message,
            'status': response.reason_phrase,
@@ -103,7 +95,7 @@ class APIError(Exception):
 
   @classmethod
   def raise_for_response(
-      cls, response: Union[requests.Response, 'ReplayResponse', httpx.Response]
+      cls, response: Union['ReplayResponse', httpx.Response]
   ):
     """Raises an error with detailed error message if the response has an error status."""
     if response.status_code == 200:
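Note: with `requests` dropped, `APIError.response` is always an `httpx.Response` (or a replay stub in test mode), and a non-JSON body now falls back to `response.text` plus httpx's `reason_phrase`. A sketch of that fallback path against a synthetic response (the body and URL are made up; the expected values follow the new logic above, not verified output):

    import httpx
    from google.genai import errors

    response = httpx.Response(
        500,
        text='backend exploded',  # not JSON, so .json() raises
        request=httpx.Request('GET', 'https://example.com'),
    )

    try:
      errors.APIError.raise_for_response(response)
    except errors.APIError as e:
      print(e.code)     # 500
      print(e.message)  # 'backend exploded'
      print(e.status)   # 'Internal Server Error' (httpx reason_phrase)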
google/genai/files.py CHANGED
@@ -826,7 +826,7 @@ class Files(_api_module.BaseModule):
           'Vertex AI does not support creating files. You can upload files to'
           ' GCS files instead.'
       )
-    config_model = None
+    config_model = types.UploadFileConfig()
     if config:
       if isinstance(config, dict):
         config_model = types.UploadFileConfig(**config)
@@ -888,13 +888,13 @@ class Files(_api_module.BaseModule):
 
     if (
         response.http_headers is None
-        or 'X-Goog-Upload-URL' not in response.http_headers
+        or 'x-goog-upload-url' not in response.http_headers
     ):
       raise KeyError(
           'Failed to create file. Upload URL did not returned from the create'
           ' file request.'
       )
-    upload_url = response.http_headers['X-Goog-Upload-URL']
+    upload_url = response.http_headers['x-goog-upload-url']
 
     if isinstance(file, io.IOBase):
       return_file = self._api_client.upload_file(
@@ -907,7 +907,7 @@ class Files(_api_module.BaseModule):
 
     return types.File._from_response(
         response=_File_from_mldev(self._api_client, return_file['file']),
-        kwargs=None,
+        kwargs=config_model.model_dump() if config else {},
     )
 
   def list(
@@ -979,7 +979,7 @@ class Files(_api_module.BaseModule):
           'downloaded. You can tell which files are downloadable by checking '
           'the `source` or `download_uri` property.'
       )
-    name = t.t_file_name(self, file)
+    name = t.t_file_name(self._api_client, file)
 
     path = f'files/{name}:download'
 
@@ -996,7 +996,7 @@ class Files(_api_module.BaseModule):
 
     if isinstance(file, types.Video):
       file.video_bytes = data
-    elif isinstance(file, types.GeneratedVideo):
+    elif isinstance(file, types.GeneratedVideo) and file.video is not None:
       file.video.video_bytes = data
 
     return data
@@ -1293,7 +1293,7 @@ class AsyncFiles(_api_module.BaseModule):
           'Vertex AI does not support creating files. You can upload files to'
           ' GCS files instead.'
       )
-    config_model = None
+    config_model = types.UploadFileConfig()
     if config:
       if isinstance(config, dict):
         config_model = types.UploadFileConfig(**config)
@@ -1373,7 +1373,7 @@ class AsyncFiles(_api_module.BaseModule):
 
     return types.File._from_response(
         response=_File_from_mldev(self._api_client, return_file['file']),
-        kwargs=None,
+        kwargs=config_model.model_dump() if config else {},
     )
 
   async def list(
@@ -1433,7 +1433,7 @@ class AsyncFiles(_api_module.BaseModule):
     else:
       config_model = config
 
-    name = t.t_file_name(self, file)
+    name = t.t_file_name(self._api_client, file)
 
     path = f'files/{name}:download'
 
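Note: on the upload path, `config_model` now always holds a concrete `types.UploadFileConfig`, and the returned `File` is built from the config's `model_dump()` instead of `kwargs=None`; on the download path, `t.t_file_name` receives `self._api_client` rather than `self`, and a `GeneratedVideo` with no `video` no longer crashes. A hedged upload sketch against the Gemini Developer API (the path and display name are illustrative; `display_name` is assumed to be a valid `UploadFileConfig` field):

    from google import genai
    from google.genai import types

    client = genai.Client()  # Gemini Developer API; Vertex AI does not support files

    file = client.files.upload(
        file='notes.txt',  # illustrative local path
        config=types.UploadFileConfig(display_name='My notes'),
    )
    print(file.name)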