google-genai 1.6.0__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
google/genai/chats.py CHANGED
@@ -13,12 +13,19 @@
13
13
  # limitations under the License.
14
14
  #
15
15
 
16
- from typing import AsyncIterator, Awaitable, Optional
17
- from typing import Union
16
+ import sys
17
+ from typing import AsyncIterator, Awaitable, Optional, Union, get_args
18
18
 
19
19
  from . import _transformers as t
20
+ from . import types
20
21
  from .models import AsyncModels, Models
21
- from .types import Content, ContentDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
22
+ from .types import Content, ContentOrDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
23
+
24
+
25
+ if sys.version_info >= (3, 10):
26
+ from typing import TypeGuard
27
+ else:
28
+ from typing_extensions import TypeGuard
22
29
 
23
30
 
24
31
  def _validate_content(content: Content) -> bool:
@@ -81,8 +88,7 @@ def _extract_curated_history(
81
88
  while i < length:
82
89
  if comprehensive_history[i].role not in ["user", "model"]:
83
90
  raise ValueError(
84
- "Role must be user or model, but got"
85
- f" {comprehensive_history[i].role}"
91
+ f"Role must be user or model, but got {comprehensive_history[i].role}"
86
92
  )
87
93
 
88
94
  if comprehensive_history[i].role == "user":
@@ -108,42 +114,52 @@ class _BaseChat:
108
114
  def __init__(
109
115
  self,
110
116
  *,
111
- modules: Union[Models, AsyncModels],
112
117
  model: str,
113
118
  config: Optional[GenerateContentConfigOrDict] = None,
114
- history: list[Content],
119
+ history: list[ContentOrDict],
115
120
  ):
116
- self._modules = modules
117
121
  self._model = model
118
122
  self._config = config
119
- self._comprehensive_history = history
123
+ content_models = []
124
+ for content in history:
125
+ if not isinstance(content, Content):
126
+ content_model = Content.model_validate(content)
127
+ else:
128
+ content_model = content
129
+ content_models.append(content_model)
130
+ self._comprehensive_history = content_models
120
131
  """Comprehensive history is the full history of the chat, including turns of the invalid contents from the model and their associated inputs.
121
132
  """
122
- self._curated_history = _extract_curated_history(history)
133
+ self._curated_history = _extract_curated_history(content_models)
123
134
  """Curated history is the set of valid turns that will be used in the subsequent send requests.
124
135
  """
125
136
 
126
-
127
- def record_history(self, user_input: Content,
128
- model_output: list[Content],
129
- automatic_function_calling_history: list[Content],
130
- is_valid: bool):
137
+ def record_history(
138
+ self,
139
+ user_input: Content,
140
+ model_output: list[Content],
141
+ automatic_function_calling_history: list[Content],
142
+ is_valid: bool,
143
+ ):
131
144
  """Records the chat history.
132
145
 
133
146
  Maintaining both comprehensive and curated histories.
134
147
 
135
148
  Args:
136
149
  user_input: The user's input content.
137
- model_output: A list of `Content` from the model's response.
138
- This can be an empty list if the model produced no output.
139
- automatic_function_calling_history: A list of `Content` representing
140
- the history of automatic function calls, including the user input as
141
- the first entry.
150
+ model_output: A list of `Content` from the model's response. This can be
151
+ an empty list if the model produced no output.
152
+ automatic_function_calling_history: A list of `Content` representing the
153
+ history of automatic function calls, including the user input as the
154
+ first entry.
142
155
  is_valid: A boolean flag indicating whether the current model output is
143
156
  considered valid.
144
157
  """
145
158
  input_contents = (
146
- automatic_function_calling_history
159
+ # Because the AFC input contains the entire curated chat history in
160
+ # addition to the new user input, we need to truncate the AFC history
161
+ # to deduplicate the existing chat history.
162
+ automatic_function_calling_history[len(self._curated_history):]
147
163
  if automatic_function_calling_history
148
164
  else [user_input]
149
165
  )
@@ -158,14 +174,13 @@ class _BaseChat:
158
174
  self._curated_history.extend(input_contents)
159
175
  self._curated_history.extend(output_contents)
160
176
 
161
-
162
177
  def get_history(self, curated: bool = False) -> list[Content]:
163
178
  """Returns the chat history.
164
179
 
165
180
  Args:
166
- curated: A boolean flag indicating whether to return the curated
167
- (valid) history or the comprehensive (all turns) history.
168
- Defaults to False (returns the comprehensive history).
181
+ curated: A boolean flag indicating whether to return the curated (valid)
182
+ history or the comprehensive (all turns) history. Defaults to False
183
+ (returns the comprehensive history).
169
184
 
170
185
  Returns:
171
186
  A list of `Content` objects representing the chat history.
@@ -176,9 +191,41 @@ class _BaseChat:
176
191
  return self._comprehensive_history
177
192
 
178
193
 
194
+ def _is_part_type(
195
+ contents: Union[list[PartUnionDict], PartUnionDict],
196
+ ) -> TypeGuard[t.ContentType]:
197
+ if isinstance(contents, list):
198
+ return all(_is_part_type(part) for part in contents)
199
+ else:
200
+ allowed_part_types = get_args(types.PartUnion)
201
+ if type(contents) in allowed_part_types:
202
+ return True
203
+ else:
204
+ # Some images don't pass isinstance(item, PIL.Image.Image)
205
+ # For example <class 'PIL.JpegImagePlugin.JpegImageFile'>
206
+ if types.PIL_Image is not None and isinstance(contents, types.PIL_Image):
207
+ return True
208
+ return False
209
+
210
+
179
211
  class Chat(_BaseChat):
180
212
  """Chat session."""
181
213
 
214
+ def __init__(
215
+ self,
216
+ *,
217
+ modules: Models,
218
+ model: str,
219
+ config: Optional[GenerateContentConfigOrDict] = None,
220
+ history: list[ContentOrDict],
221
+ ):
222
+ self._modules = modules
223
+ super().__init__(
224
+ model=model,
225
+ config=config,
226
+ history=history,
227
+ )
228
+
182
229
  def send_message(
183
230
  self,
184
231
  message: Union[list[PartUnionDict], PartUnionDict],
@@ -202,10 +249,15 @@ class Chat(_BaseChat):
202
249
  response = chat.send_message('tell me a story')
203
250
  """
204
251
 
252
+ if not _is_part_type(message):
253
+ raise ValueError(
254
+ f"Message must be a valid part type: {types.PartUnion} or"
255
+ f" {types.PartUnionDict}, got {type(message)}"
256
+ )
205
257
  input_content = t.t_content(self._modules._api_client, message)
206
258
  response = self._modules.generate_content(
207
259
  model=self._model,
208
- contents=self._curated_history + [input_content],
260
+ contents=self._curated_history + [input_content], # type: ignore[arg-type]
209
261
  config=config if config else self._config,
210
262
  )
211
263
  model_output = (
@@ -213,10 +265,15 @@ class Chat(_BaseChat):
213
265
  if response.candidates and response.candidates[0].content
214
266
  else []
215
267
  )
268
+ automatic_function_calling_history = (
269
+ response.automatic_function_calling_history
270
+ if response.automatic_function_calling_history
271
+ else []
272
+ )
216
273
  self.record_history(
217
274
  user_input=input_content,
218
275
  model_output=model_output,
219
- automatic_function_calling_history=response.automatic_function_calling_history,
276
+ automatic_function_calling_history=automatic_function_calling_history,
220
277
  is_valid=_validate_response(response),
221
278
  )
222
279
  return response
@@ -245,29 +302,42 @@ class Chat(_BaseChat):
245
302
  print(chunk.text)
246
303
  """
247
304
 
305
+ if not _is_part_type(message):
306
+ raise ValueError(
307
+ f"Message must be a valid part type: {types.PartUnion} or"
308
+ f" {types.PartUnionDict}, got {type(message)}"
309
+ )
248
310
  input_content = t.t_content(self._modules._api_client, message)
249
311
  output_contents = []
250
312
  finish_reason = None
251
313
  is_valid = True
252
314
  chunk = None
253
- for chunk in self._modules.generate_content_stream(
254
- model=self._model,
255
- contents=self._curated_history + [input_content],
256
- config=config if config else self._config,
257
- ):
258
- if not _validate_response(chunk):
259
- is_valid = False
260
- if chunk.candidates and chunk.candidates[0].content:
261
- output_contents.append(chunk.candidates[0].content)
262
- if chunk.candidates and chunk.candidates[0].finish_reason:
263
- finish_reason = chunk.candidates[0].finish_reason
264
- yield chunk
265
- self.record_history(
266
- user_input=input_content,
267
- model_output=output_contents,
268
- automatic_function_calling_history=chunk.automatic_function_calling_history,
269
- is_valid=is_valid and output_contents and finish_reason,
270
- )
315
+ if isinstance(self._modules, Models):
316
+ for chunk in self._modules.generate_content_stream(
317
+ model=self._model,
318
+ contents=self._curated_history + [input_content], # type: ignore[arg-type]
319
+ config=config if config else self._config,
320
+ ):
321
+ if not _validate_response(chunk):
322
+ is_valid = False
323
+ if chunk.candidates and chunk.candidates[0].content:
324
+ output_contents.append(chunk.candidates[0].content)
325
+ if chunk.candidates and chunk.candidates[0].finish_reason:
326
+ finish_reason = chunk.candidates[0].finish_reason
327
+ yield chunk
328
+ automatic_function_calling_history = (
329
+ chunk.automatic_function_calling_history
330
+ if chunk.automatic_function_calling_history
331
+ else []
332
+ )
333
+ self.record_history(
334
+ user_input=input_content,
335
+ model_output=output_contents,
336
+ automatic_function_calling_history=automatic_function_calling_history,
337
+ is_valid=is_valid
338
+ and output_contents is not None
339
+ and finish_reason is not None,
340
+ )
271
341
 
272
342
 
273
343
  class Chats:
@@ -281,7 +351,7 @@ class Chats:
281
351
  *,
282
352
  model: str,
283
353
  config: Optional[GenerateContentConfigOrDict] = None,
284
- history: Optional[list[Content]] = None,
354
+ history: Optional[list[ContentOrDict]] = None,
285
355
  ) -> Chat:
286
356
  """Creates a new chat session.
287
357
 
@@ -304,6 +374,21 @@ class Chats:
304
374
  class AsyncChat(_BaseChat):
305
375
  """Async chat session."""
306
376
 
377
+ def __init__(
378
+ self,
379
+ *,
380
+ modules: AsyncModels,
381
+ model: str,
382
+ config: Optional[GenerateContentConfigOrDict] = None,
383
+ history: list[ContentOrDict],
384
+ ):
385
+ self._modules = modules
386
+ super().__init__(
387
+ model=model,
388
+ config=config,
389
+ history=history,
390
+ )
391
+
307
392
  async def send_message(
308
393
  self,
309
394
  message: Union[list[PartUnionDict], PartUnionDict],
@@ -326,11 +411,15 @@ class AsyncChat(_BaseChat):
326
411
  chat = client.aio.chats.create(model='gemini-1.5-flash')
327
412
  response = await chat.send_message('tell me a story')
328
413
  """
329
-
414
+ if not _is_part_type(message):
415
+ raise ValueError(
416
+ f"Message must be a valid part type: {types.PartUnion} or"
417
+ f" {types.PartUnionDict}, got {type(message)}"
418
+ )
330
419
  input_content = t.t_content(self._modules._api_client, message)
331
420
  response = await self._modules.generate_content(
332
421
  model=self._model,
333
- contents=self._curated_history + [input_content],
422
+ contents=self._curated_history + [input_content], # type: ignore[arg-type]
334
423
  config=config if config else self._config,
335
424
  )
336
425
  model_output = (
@@ -338,10 +427,15 @@ class AsyncChat(_BaseChat):
338
427
  if response.candidates and response.candidates[0].content
339
428
  else []
340
429
  )
430
+ automatic_function_calling_history = (
431
+ response.automatic_function_calling_history
432
+ if response.automatic_function_calling_history
433
+ else []
434
+ )
341
435
  self.record_history(
342
436
  user_input=input_content,
343
437
  model_output=model_output,
344
- automatic_function_calling_history=response.automatic_function_calling_history,
438
+ automatic_function_calling_history=automatic_function_calling_history,
345
439
  is_valid=_validate_response(response),
346
440
  )
347
441
  return response
@@ -369,6 +463,11 @@ class AsyncChat(_BaseChat):
369
463
  print(chunk.text)
370
464
  """
371
465
 
466
+ if not _is_part_type(message):
467
+ raise ValueError(
468
+ f"Message must be a valid part type: {types.PartUnion} or"
469
+ f" {types.PartUnionDict}, got {type(message)}"
470
+ )
372
471
  input_content = t.t_content(self._modules._api_client, message)
373
472
 
374
473
  async def async_generator():
@@ -394,7 +493,6 @@ class AsyncChat(_BaseChat):
394
493
  model_output=output_contents,
395
494
  automatic_function_calling_history=chunk.automatic_function_calling_history,
396
495
  is_valid=is_valid and output_contents and finish_reason,
397
-
398
496
  )
399
497
  return async_generator()
400
498
 
@@ -410,7 +508,7 @@ class AsyncChats:
410
508
  *,
411
509
  model: str,
412
510
  config: Optional[GenerateContentConfigOrDict] = None,
413
- history: Optional[list[Content]] = None,
511
+ history: Optional[list[ContentOrDict]] = None,
414
512
  ) -> AsyncChat:
415
513
  """Creates a new chat session.
416
514
 
google/genai/client.py CHANGED
@@ -194,6 +194,8 @@ class Client:
194
194
  """
195
195
 
196
196
  self._debug_config = debug_config or DebugConfig()
197
+ if isinstance(http_options, dict):
198
+ http_options = HttpOptions(**http_options)
197
199
 
198
200
  self._api_client = self._get_api_client(
199
201
  vertexai=vertexai,
@@ -229,10 +231,10 @@ class Client:
229
231
  'auto',
230
232
  ]:
231
233
  return ReplayApiClient(
232
- mode=debug_config.client_mode,
233
- replay_id=debug_config.replay_id,
234
+ mode=debug_config.client_mode, # type: ignore[arg-type]
235
+ replay_id=debug_config.replay_id, # type: ignore[arg-type]
234
236
  replays_directory=debug_config.replays_directory,
235
- vertexai=vertexai,
237
+ vertexai=vertexai, # type: ignore[arg-type]
236
238
  api_key=api_key,
237
239
  credentials=credentials,
238
240
  project=project,
google/genai/errors.py CHANGED
@@ -35,28 +35,10 @@ class APIError(Exception):
35
35
  def __init__(
36
36
  self,
37
37
  code: int,
38
+ response_json: Any,
38
39
  response: Union['ReplayResponse', httpx.Response],
39
40
  ):
40
41
  self.response = response
41
- message = None
42
- if isinstance(response, httpx.Response):
43
- try:
44
- response_json = response.json()
45
- except (json.decoder.JSONDecodeError):
46
- message = response.text
47
- response_json = {
48
- 'message': message,
49
- 'status': response.reason_phrase,
50
- }
51
- except httpx.ResponseNotRead:
52
- message = 'Response not read'
53
- response_json = {
54
- 'message': message,
55
- 'status': response.reason_phrase,
56
- }
57
- else:
58
- response_json = response.body_segments[0].get('error', {})
59
-
60
42
  self.details = response_json
61
43
  self.message = self._get_message(response_json)
62
44
  self.status = self._get_status(response_json)
@@ -101,13 +83,54 @@ class APIError(Exception):
101
83
  if response.status_code == 200:
102
84
  return
103
85
 
86
+ if isinstance(response, httpx.Response):
87
+ try:
88
+ response.read()
89
+ response_json = response.json()
90
+ except json.decoder.JSONDecodeError:
91
+ message = response.text
92
+ response_json = {
93
+ 'message': message,
94
+ 'status': response.reason_phrase,
95
+ }
96
+ else:
97
+ response_json = response.body_segments[0].get('error', {})
98
+
99
+ status_code = response.status_code
100
+ if 400 <= status_code < 500:
101
+ raise ClientError(status_code, response_json, response)
102
+ elif 500 <= status_code < 600:
103
+ raise ServerError(status_code, response_json, response)
104
+ else:
105
+ raise cls(status_code, response_json, response)
106
+
107
+ @classmethod
108
+ async def raise_for_async_response(
109
+ cls, response: Union['ReplayResponse', httpx.Response]
110
+ ):
111
+ """Raises an error with detailed error message if the response has an error status."""
112
+ if response.status_code == 200:
113
+ return
114
+ if isinstance(response, httpx.Response):
115
+ try:
116
+ await response.aread()
117
+ response_json = response.json()
118
+ except json.decoder.JSONDecodeError:
119
+ message = response.text
120
+ response_json = {
121
+ 'message': message,
122
+ 'status': response.reason_phrase,
123
+ }
124
+ else:
125
+ response_json = response.body_segments[0].get('error', {})
126
+
104
127
  status_code = response.status_code
105
128
  if 400 <= status_code < 500:
106
- raise ClientError(status_code, response)
129
+ raise ClientError(status_code, response_json, response)
107
130
  elif 500 <= status_code < 600:
108
- raise ServerError(status_code, response)
131
+ raise ServerError(status_code, response_json, response)
109
132
  else:
110
- raise cls(status_code, response)
133
+ raise cls(status_code, response_json, response)
111
134
 
112
135
 
113
136
  class ClientError(APIError):
@@ -137,4 +160,4 @@ class FunctionInvocationError(ValueError):
137
160
 
138
161
 
139
162
  class ExperimentalWarning(Warning):
140
- """Warning for experimental features."""
163
+ """Warning for experimental features."""