google-genai 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
google/genai/chats.py CHANGED
@@ -21,14 +21,10 @@ from .models import AsyncModels, Models
 from .types import Content, ContentDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
 
 
-def _validate_response(response: GenerateContentResponse) -> bool:
-  if not response.candidates:
-    return False
-  if not response.candidates[0].content:
-    return False
-  if not response.candidates[0].content.parts:
+def _validate_content(content: Content) -> bool:
+  if not content.parts:
     return False
-  for part in response.candidates[0].content.parts:
+  for part in content.parts:
     if part == Part():
       return False
     if part.text is not None and part.text == "":
@@ -36,6 +32,76 @@ def _validate_response(response: GenerateContentResponse) -> bool:
   return True
 
 
+def _validate_contents(contents: list[Content]) -> bool:
+  if not contents:
+    return False
+  for content in contents:
+    if not _validate_content(content):
+      return False
+  return True
+
+
+def _validate_response(response: GenerateContentResponse) -> bool:
+  if not response.candidates:
+    return False
+  if not response.candidates[0].content:
+    return False
+  return _validate_content(response.candidates[0].content)
+
+
+def _extract_curated_history(
+    comprehensive_history: list[Content],
+) -> list[Content]:
+  """Extracts the curated (valid) history from a comprehensive history.
+
+  The comprehensive history contains all turns (user input and model responses),
+  including any invalid or rejected model outputs. This function filters
+  that history to return only the valid turns.
+
+  A "turn" starts with one user input (a single content) and then follows by
+  corresponding model response (which may consist of multiple contents).
+  Turns are assumed to alternate: user input, model output, user input, model
+  output, etc.
+
+  Args:
+    comprehensive_history: A list representing the complete chat history.
+      Including invalid turns.
+
+  Returns:
+    curated history, which is a list of valid turns.
+  """
+  if not comprehensive_history:
+    return []
+  curated_history = []
+  length = len(comprehensive_history)
+  i = 0
+  current_input = comprehensive_history[i]
+  if current_input.role != "user":
+    raise ValueError("History must start with a user turn.")
+  while i < length:
+    if comprehensive_history[i].role not in ["user", "model"]:
+      raise ValueError(
+          "Role must be user or model, but got"
+          f" {comprehensive_history[i].role}"
+      )
+
+    if comprehensive_history[i].role == "user":
+      current_input = comprehensive_history[i]
+      i += 1
+    else:
+      current_output = []
+      is_valid = True
+      while i < length and comprehensive_history[i].role == "model":
+        current_output.append(comprehensive_history[i])
+        if is_valid and not _validate_content(comprehensive_history[i]):
+          is_valid = False
+        i += 1
+      if is_valid:
+        curated_history.append(current_input)
+        curated_history.extend(current_output)
+  return curated_history
+
+
 class _BaseChat:
   """Base chat session."""
 
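
The refactor above splits validation so that `_validate_content` can be applied to a single `Content` (per history entry) as well as to a full response, and `_extract_curated_history` rebuilds the valid-turn view from a complete history. To illustrate the filtering behavior, here is a minimal sketch; these helpers are private (underscore-prefixed), so this is for reading the diff rather than a supported API, and the example contents are made up:

    from google.genai.chats import _extract_curated_history
    from google.genai.types import Content, Part

    history = [
        Content(role="user", parts=[Part(text="Hello")]),
        Content(role="model", parts=[Part(text="Hi!")]),  # valid model turn
        Content(role="user", parts=[Part(text="And again?")]),
        Content(role="model", parts=[]),  # invalid: no parts
    ]

    # The second turn is dropped because its model output fails
    # _validate_content; only the first user/model pair survives.
    assert _extract_curated_history(history) == history[:2]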
@@ -44,13 +110,70 @@ class _BaseChat:
       *,
       modules: Union[Models, AsyncModels],
       model: str,
-      config: GenerateContentConfigOrDict = None,
+      config: Optional[GenerateContentConfigOrDict] = None,
       history: list[Content],
   ):
     self._modules = modules
     self._model = model
     self._config = config
-    self._curated_history = history
+    self._comprehensive_history = history
+    """Comprehensive history is the full history of the chat, including turns of the invalid contents from the model and their associated inputs.
+    """
+    self._curated_history = _extract_curated_history(history)
+    """Curated history is the set of valid turns that will be used in the subsequent send requests.
+    """
+
+
+  def record_history(self, user_input: Content,
+                     model_output: list[Content],
+                     automatic_function_calling_history: list[Content],
+                     is_valid: bool):
+    """Records the chat history.
+
+    Maintaining both comprehensive and curated histories.
+
+    Args:
+      user_input: The user's input content.
+      model_output: A list of `Content` from the model's response.
+        This can be an empty list if the model produced no output.
+      automatic_function_calling_history: A list of `Content` representing
+        the history of automatic function calls, including the user input as
+        the first entry.
+      is_valid: A boolean flag indicating whether the current model output is
+        considered valid.
+    """
+    input_contents = (
+        automatic_function_calling_history
+        if automatic_function_calling_history
+        else [user_input]
+    )
+    # Appends an empty content when model returns empty response, so that the
+    # history is always alternating between user and model.
+    output_contents = (
+        model_output if model_output else [Content(role="model", parts=[])]
+    )
+    self._comprehensive_history.extend(input_contents)
+    self._comprehensive_history.extend(output_contents)
+    if is_valid:
+      self._curated_history.extend(input_contents)
+      self._curated_history.extend(output_contents)
+
+
+  def get_history(self, curated: bool = False) -> list[Content]:
+    """Returns the chat history.
+
+    Args:
+      curated: A boolean flag indicating whether to return the curated
+        (valid) history or the comprehensive (all turns) history.
+        Defaults to False (returns the comprehensive history).
+
+    Returns:
+      A list of `Content` objects representing the chat history.
+    """
+    if curated:
+      return self._curated_history
+    else:
+      return self._comprehensive_history
 
 
 class Chat(_BaseChat):
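
`get_history` is the new public way to read the two views that `record_history` maintains. A minimal usage sketch, assuming an API key is available in the environment (the model name is illustrative):

    from google import genai

    client = genai.Client()  # assumes GOOGLE_API_KEY is set in the environment
    chat = client.chats.create(model="gemini-2.0-flash")
    chat.send_message("Hello")

    comprehensive = chat.get_history()        # every turn, valid or not
    curated = chat.get_history(curated=True)  # only the turns reused in later requests

Note the comment in `record_history`: when the model returns nothing, a placeholder `Content(role="model", parts=[])` is appended so the comprehensive history keeps alternating between user and model.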
@@ -85,14 +208,17 @@ class Chat(_BaseChat):
         contents=self._curated_history + [input_content],
         config=config if config else self._config,
     )
-    if _validate_response(response):
-      if response.automatic_function_calling_history:
-        self._curated_history.extend(
-            response.automatic_function_calling_history
-        )
-      else:
-        self._curated_history.append(input_content)
-      self._curated_history.append(response.candidates[0].content)
+    model_output = (
+        [response.candidates[0].content]
+        if response.candidates and response.candidates[0].content
+        else []
+    )
+    self.record_history(
+        user_input=input_content,
+        model_output=model_output,
+        automatic_function_calling_history=response.automatic_function_calling_history,
+        is_valid=_validate_response(response),
+    )
     return response
 
   def send_message_stream(
@@ -122,19 +248,26 @@
     input_content = t.t_content(self._modules._api_client, message)
     output_contents = []
     finish_reason = None
+    is_valid = True
+    chunk = None
     for chunk in self._modules.generate_content_stream(
         model=self._model,
         contents=self._curated_history + [input_content],
         config=config if config else self._config,
     ):
-      if _validate_response(chunk):
+      if not _validate_response(chunk):
+        is_valid = False
+      if chunk.candidates and chunk.candidates[0].content:
         output_contents.append(chunk.candidates[0].content)
       if chunk.candidates and chunk.candidates[0].finish_reason:
         finish_reason = chunk.candidates[0].finish_reason
       yield chunk
-    if output_contents and finish_reason:
-      self._curated_history.append(input_content)
-      self._curated_history.extend(output_contents)
+    self.record_history(
+        user_input=input_content,
+        model_output=output_contents,
+        automatic_function_calling_history=chunk.automatic_function_calling_history,
+        is_valid=is_valid and output_contents and finish_reason,
+    )
 
 
 class Chats:
@@ -147,7 +280,7 @@ class Chats:
       self,
       *,
       model: str,
-      config: GenerateContentConfigOrDict = None,
+      config: Optional[GenerateContentConfigOrDict] = None,
       history: Optional[list[Content]] = None,
   ) -> Chat:
     """Creates a new chat session.
@@ -200,14 +333,17 @@ class AsyncChat(_BaseChat):
         contents=self._curated_history + [input_content],
         config=config if config else self._config,
     )
-    if _validate_response(response):
-      if response.automatic_function_calling_history:
-        self._curated_history.extend(
-            response.automatic_function_calling_history
-        )
-      else:
-        self._curated_history.append(input_content)
-      self._curated_history.append(response.candidates[0].content)
+    model_output = (
+        [response.candidates[0].content]
+        if response.candidates and response.candidates[0].content
+        else []
+    )
+    self.record_history(
+        user_input=input_content,
+        model_output=model_output,
+        automatic_function_calling_history=response.automatic_function_calling_history,
+        is_valid=_validate_response(response),
+    )
     return response
 
   async def send_message_stream(
@@ -238,20 +374,28 @@ class AsyncChat(_BaseChat):
     async def async_generator():
       output_contents = []
       finish_reason = None
+      is_valid = True
+      chunk = None
       async for chunk in await self._modules.generate_content_stream(
           model=self._model,
           contents=self._curated_history + [input_content],
           config=config if config else self._config,
       ):
-        if _validate_response(chunk):
+        if not _validate_response(chunk):
+          is_valid = False
+        if chunk.candidates and chunk.candidates[0].content:
           output_contents.append(chunk.candidates[0].content)
         if chunk.candidates and chunk.candidates[0].finish_reason:
           finish_reason = chunk.candidates[0].finish_reason
         yield chunk
 
-      if output_contents and finish_reason:
-        self._curated_history.append(input_content)
-        self._curated_history.extend(output_contents)
+      self.record_history(
+          user_input=input_content,
+          model_output=output_contents,
+          automatic_function_calling_history=chunk.automatic_function_calling_history,
+          is_valid=is_valid and output_contents and finish_reason,
+
+      )
     return async_generator()
 
 
@@ -265,7 +409,7 @@ class AsyncChats:
       self,
       *,
       model: str,
-      config: GenerateContentConfigOrDict = None,
+      config: Optional[GenerateContentConfigOrDict] = None,
       history: Optional[list[Content]] = None,
   ) -> AsyncChat:
     """Creates a new chat session.
google/genai/client.py CHANGED
@@ -19,7 +19,7 @@ from typing import Optional, Union
 import google.auth
 import pydantic
 
-from ._api_client import ApiClient, HttpOptions, HttpOptionsDict
+from ._api_client import BaseApiClient, HttpOptions, HttpOptionsDict
 from ._replay_api_client import ReplayApiClient
 from .batches import AsyncBatches, Batches
 from .caches import AsyncCaches, Caches
@@ -34,7 +34,7 @@ from .tunings import AsyncTunings, Tunings
 class AsyncClient:
   """Client for making asynchronous (non-blocking) requests."""
 
-  def __init__(self, api_client: ApiClient):
+  def __init__(self, api_client: BaseApiClient):
 
     self._api_client = api_client
     self._models = AsyncModels(self._api_client)
@@ -239,7 +239,7 @@ class Client:
         http_options=http_options,
     )
 
-    return ApiClient(
+    return BaseApiClient(
         vertexai=vertexai,
         api_key=api_key,
         credentials=credentials,
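
The `ApiClient` → `BaseApiClient` rename stays inside the private `_api_client` module; the public `Client` surface is unchanged. For reference, construction still looks like this (key, project, and location values are placeholders):

    from google import genai

    # Gemini Developer API
    client = genai.Client(api_key="YOUR_API_KEY")

    # Vertex AI
    vertex_client = genai.Client(
        vertexai=True, project="your-project", location="us-central1"
    )

Only code that reached into the private module (e.g. `from google.genai._api_client import ApiClient`) needs updating to the new name.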
google/genai/errors.py CHANGED
@@ -28,11 +28,10 @@ if TYPE_CHECKING:
 class APIError(Exception):
   """General errors raised by the GenAI API."""
   code: int
-  response: requests.Response
+  response: Union[requests.Response, 'ReplayResponse', httpx.Response]
 
   status: Optional[str] = None
   message: Optional[str] = None
-  response: Optional[Any] = None
 
   def __init__(
       self,
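
The `response` attribute was previously declared twice on `APIError` (once as `requests.Response`, once as `Optional[Any]`); the duplicate is removed and the annotation widened to cover the sync, async, and replay transports. Error-handling code is unaffected; a minimal sketch (model name and key are placeholders):

    from google import genai
    from google.genai import errors

    client = genai.Client(api_key="YOUR_API_KEY")
    try:
        client.models.generate_content(model="gemini-2.0-flash", contents="Hi")
    except errors.APIError as e:
        # e.response may now be a requests.Response, an httpx.Response, or a
        # replay response, depending on the transport in use.
        print(e.code, e.message)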