google-genai 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/chats.py CHANGED
@@ -21,14 +21,10 @@ from .models import AsyncModels, Models
21
21
  from .types import Content, ContentDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
22
22
 
23
23
 
24
- def _validate_response(response: GenerateContentResponse) -> bool:
25
- if not response.candidates:
26
- return False
27
- if not response.candidates[0].content:
28
- return False
29
- if not response.candidates[0].content.parts:
24
+ def _validate_content(content: Content) -> bool:
25
+ if not content.parts:
30
26
  return False
31
- for part in response.candidates[0].content.parts:
27
+ for part in content.parts:
32
28
  if part == Part():
33
29
  return False
34
30
  if part.text is not None and part.text == "":
@@ -36,6 +32,76 @@ def _validate_response(response: GenerateContentResponse) -> bool:
36
32
  return True
37
33
 
38
34
 
35
def _validate_contents(contents: list[Content]) -> bool:
  """Checks that a non-empty sequence of contents is valid throughout.

  Args:
    contents: The contents to check.

  Returns:
    False when `contents` is empty (or None), or as soon as one element fails
    `_validate_content`; True otherwise.
  """
  return bool(contents) and all(
      _validate_content(content) for content in contents
  )
42
+
43
+
44
def _validate_response(response: GenerateContentResponse) -> bool:
  """Checks whether the first candidate of a response carries valid content.

  Args:
    response: The model response to inspect.

  Returns:
    False if the response has no candidates or the first candidate has no
    content; otherwise the result of `_validate_content` on that content.
  """
  candidates = response.candidates
  first_content = candidates[0].content if candidates else None
  if not first_content:
    return False
  return _validate_content(first_content)
50
+
51
+
52
def _extract_curated_history(
    comprehensive_history: list[Content],
) -> list[Content]:
  """Extracts the curated (valid) history from a comprehensive history.

  The comprehensive history contains all turns (user input and model
  responses), including any invalid or rejected model outputs. This function
  filters that history to return only the valid turns.

  A "turn" starts with one or more consecutive user contents and is followed
  by the corresponding model response (one or more consecutive model
  contents). A turn is kept only if every model content in it passes
  `_validate_content`. User contents that are not followed by any model
  response are not included.

  Args:
    comprehensive_history: A list representing the complete chat history,
      including invalid turns.

  Returns:
    The curated history: a new list containing only the valid turns.

  Raises:
    ValueError: If the history does not start with a user content, or if any
      content has a role other than "user" or "model".
  """
  if not comprehensive_history:
    return []
  if comprehensive_history[0].role != "user":
    raise ValueError("History must start with a user turn.")

  curated_history = []
  # User contents of the current turn that are still waiting for the model
  # response that decides whether the turn is kept.
  pending_inputs = []
  length = len(comprehensive_history)
  i = 0
  while i < length:
    role = comprehensive_history[i].role
    if role not in ["user", "model"]:
      raise ValueError(
          "Role must be user or model, but got"
          f" {role}"
      )

    if role == "user":
      # Keep every consecutive user content so a multi-content input (e.g.
      # text followed by a separate image content) is not truncated to its
      # last element.
      pending_inputs.append(comprehensive_history[i])
      i += 1
      continue

    current_output = []
    is_valid = True
    while i < length and comprehensive_history[i].role == "model":
      current_output.append(comprehensive_history[i])
      if is_valid and not _validate_content(comprehensive_history[i]):
        is_valid = False
      i += 1
    if is_valid:
      curated_history.extend(pending_inputs)
      curated_history.extend(current_output)
    # Whether kept or discarded, the inputs belong to this finished turn.
    pending_inputs = []
  return curated_history
103
+
104
+
39
105
  class _BaseChat:
40
106
  """Base chat session."""
41
107
 
@@ -44,13 +110,70 @@ class _BaseChat:
44
110
  *,
45
111
  modules: Union[Models, AsyncModels],
46
112
  model: str,
47
- config: GenerateContentConfigOrDict = None,
113
+ config: Optional[GenerateContentConfigOrDict] = None,
48
114
  history: list[Content],
49
115
  ):
50
116
  self._modules = modules
51
117
  self._model = model
52
118
  self._config = config
53
- self._curated_history = history
119
+ self._comprehensive_history = history
120
+ """Comprehensive history is the full history of the chat, including turns of the invalid contents from the model and their associated inputs.
121
+ """
122
+ self._curated_history = _extract_curated_history(history)
123
+ """Curated history is the set of valid turns that will be used in the subsequent send requests.
124
+ """
125
+
126
+
127
def record_history(self, user_input: Content,
                   model_output: list[Content],
                   automatic_function_calling_history: Optional[list[Content]],
                   is_valid: bool):
  """Records one exchange in the chat history.

  The exchange is always appended to the comprehensive history. It is also
  appended to the curated history when `is_valid` is truthy.

  Args:
    user_input: The user's input content.
    model_output: A list of `Content` from the model's response. This can be
      an empty list if the model produced no output.
    automatic_function_calling_history: A list of `Content` representing the
      history of automatic function calls, including the user input. It may
      be empty or None when no automatic function calling happened. If it
      begins with the current curated history (i.e. it echoes the request
      contents), that prefix is dropped so earlier turns are not recorded a
      second time.
    is_valid: Whether the current model output is considered valid; only
      valid exchanges are added to the curated history. Any truthy value is
      accepted.
  """
  input_contents = list(automatic_function_calling_history or [])
  # The request was built as `curated_history + [input]`; if the AFC history
  # echoes that request, strip the already-recorded prefix to avoid
  # duplicating every previous turn in both histories.
  history_len = len(self._curated_history)
  if history_len and input_contents[:history_len] == self._curated_history:
    input_contents = input_contents[history_len:]
  if not input_contents:
    input_contents = [user_input]
  # Appends an empty content when model returns empty response, so that the
  # history is always alternating between user and model.
  output_contents = (
      model_output if model_output else [Content(role="model", parts=[])]
  )
  self._comprehensive_history.extend(input_contents)
  self._comprehensive_history.extend(output_contents)
  if is_valid:
    self._curated_history.extend(input_contents)
    self._curated_history.extend(output_contents)
160
+
161
+
162
def get_history(self, curated: bool = False) -> list[Content]:
  """Returns one of the two chat histories kept by this session.

  Args:
    curated: When truthy, return the curated history (only turns judged
      valid); otherwise return the comprehensive history (every recorded
      turn, valid or not). Defaults to False.

  Returns:
    The selected internal list of `Content` objects. The list itself is
    returned rather than a copy, so mutating it mutates the session history.
  """
  return self._curated_history if curated else self._comprehensive_history
54
177
 
55
178
 
56
179
  class Chat(_BaseChat):
@@ -85,14 +208,17 @@ class Chat(_BaseChat):
85
208
  contents=self._curated_history + [input_content],
86
209
  config=config if config else self._config,
87
210
  )
88
- if _validate_response(response):
89
- if response.automatic_function_calling_history:
90
- self._curated_history.extend(
91
- response.automatic_function_calling_history
92
- )
93
- else:
94
- self._curated_history.append(input_content)
95
- self._curated_history.append(response.candidates[0].content)
211
+ model_output = (
212
+ [response.candidates[0].content]
213
+ if response.candidates and response.candidates[0].content
214
+ else []
215
+ )
216
+ self.record_history(
217
+ user_input=input_content,
218
+ model_output=model_output,
219
+ automatic_function_calling_history=response.automatic_function_calling_history,
220
+ is_valid=_validate_response(response),
221
+ )
96
222
  return response
97
223
 
98
224
  def send_message_stream(
@@ -122,19 +248,26 @@ class Chat(_BaseChat):
122
248
  input_content = t.t_content(self._modules._api_client, message)
123
249
  output_contents = []
124
250
  finish_reason = None
251
+ is_valid = True
252
+ chunk = None
125
253
  for chunk in self._modules.generate_content_stream(
126
254
  model=self._model,
127
255
  contents=self._curated_history + [input_content],
128
256
  config=config if config else self._config,
129
257
  ):
130
- if _validate_response(chunk):
258
+ if not _validate_response(chunk):
259
+ is_valid = False
260
+ if chunk.candidates and chunk.candidates[0].content:
131
261
  output_contents.append(chunk.candidates[0].content)
132
262
  if chunk.candidates and chunk.candidates[0].finish_reason:
133
263
  finish_reason = chunk.candidates[0].finish_reason
134
264
  yield chunk
135
- if output_contents and finish_reason:
136
- self._curated_history.append(input_content)
137
- self._curated_history.extend(output_contents)
265
+ self.record_history(
266
+ user_input=input_content,
267
+ model_output=output_contents,
268
+ automatic_function_calling_history=chunk.automatic_function_calling_history,
269
+ is_valid=is_valid and output_contents and finish_reason,
270
+ )
138
271
 
139
272
 
140
273
  class Chats:
@@ -147,7 +280,7 @@ class Chats:
147
280
  self,
148
281
  *,
149
282
  model: str,
150
- config: GenerateContentConfigOrDict = None,
283
+ config: Optional[GenerateContentConfigOrDict] = None,
151
284
  history: Optional[list[Content]] = None,
152
285
  ) -> Chat:
153
286
  """Creates a new chat session.
@@ -200,14 +333,17 @@ class AsyncChat(_BaseChat):
200
333
  contents=self._curated_history + [input_content],
201
334
  config=config if config else self._config,
202
335
  )
203
- if _validate_response(response):
204
- if response.automatic_function_calling_history:
205
- self._curated_history.extend(
206
- response.automatic_function_calling_history
207
- )
208
- else:
209
- self._curated_history.append(input_content)
210
- self._curated_history.append(response.candidates[0].content)
336
+ model_output = (
337
+ [response.candidates[0].content]
338
+ if response.candidates and response.candidates[0].content
339
+ else []
340
+ )
341
+ self.record_history(
342
+ user_input=input_content,
343
+ model_output=model_output,
344
+ automatic_function_calling_history=response.automatic_function_calling_history,
345
+ is_valid=_validate_response(response),
346
+ )
211
347
  return response
212
348
 
213
349
  async def send_message_stream(
@@ -238,20 +374,28 @@ class AsyncChat(_BaseChat):
238
374
  async def async_generator():
239
375
  output_contents = []
240
376
  finish_reason = None
377
+ is_valid = True
378
+ chunk = None
241
379
  async for chunk in await self._modules.generate_content_stream(
242
380
  model=self._model,
243
381
  contents=self._curated_history + [input_content],
244
382
  config=config if config else self._config,
245
383
  ):
246
- if _validate_response(chunk):
384
+ if not _validate_response(chunk):
385
+ is_valid = False
386
+ if chunk.candidates and chunk.candidates[0].content:
247
387
  output_contents.append(chunk.candidates[0].content)
248
388
  if chunk.candidates and chunk.candidates[0].finish_reason:
249
389
  finish_reason = chunk.candidates[0].finish_reason
250
390
  yield chunk
251
391
 
252
- if output_contents and finish_reason:
253
- self._curated_history.append(input_content)
254
- self._curated_history.extend(output_contents)
392
+ self.record_history(
393
+ user_input=input_content,
394
+ model_output=output_contents,
395
+ automatic_function_calling_history=chunk.automatic_function_calling_history,
396
+ is_valid=is_valid and output_contents and finish_reason,
397
+
398
+ )
255
399
  return async_generator()
256
400
 
257
401
 
@@ -265,7 +409,7 @@ class AsyncChats:
265
409
  self,
266
410
  *,
267
411
  model: str,
268
- config: GenerateContentConfigOrDict = None,
412
+ config: Optional[GenerateContentConfigOrDict] = None,
269
413
  history: Optional[list[Content]] = None,
270
414
  ) -> AsyncChat:
271
415
  """Creates a new chat session.
google/genai/client.py CHANGED
@@ -19,7 +19,7 @@ from typing import Optional, Union
19
19
  import google.auth
20
20
  import pydantic
21
21
 
22
- from ._api_client import ApiClient, HttpOptions, HttpOptionsDict
22
+ from ._api_client import BaseApiClient, HttpOptions, HttpOptionsDict
23
23
  from ._replay_api_client import ReplayApiClient
24
24
  from .batches import AsyncBatches, Batches
25
25
  from .caches import AsyncCaches, Caches
@@ -27,13 +27,14 @@ from .chats import AsyncChats, Chats
27
27
  from .files import AsyncFiles, Files
28
28
  from .live import AsyncLive
29
29
  from .models import AsyncModels, Models
30
+ from .operations import AsyncOperations, Operations
30
31
  from .tunings import AsyncTunings, Tunings
31
32
 
32
33
 
33
34
  class AsyncClient:
34
35
  """Client for making asynchronous (non-blocking) requests."""
35
36
 
36
- def __init__(self, api_client: ApiClient):
37
+ def __init__(self, api_client: BaseApiClient):
37
38
 
38
39
  self._api_client = api_client
39
40
  self._models = AsyncModels(self._api_client)
@@ -42,6 +43,7 @@ class AsyncClient:
42
43
  self._batches = AsyncBatches(self._api_client)
43
44
  self._files = AsyncFiles(self._api_client)
44
45
  self._live = AsyncLive(self._api_client)
46
+ self._operations = AsyncOperations(self._api_client)
45
47
 
46
48
  @property
47
49
  def models(self) -> AsyncModels:
@@ -71,6 +73,9 @@ class AsyncClient:
71
73
  def live(self) -> AsyncLive:
72
74
  return self._live
73
75
 
76
@property
def operations(self) -> AsyncOperations:
  """Returns the `AsyncOperations` module built from this client's API client."""
  return self._operations
74
79
 
75
80
  class DebugConfig(pydantic.BaseModel):
76
81
  """Configuration options that change client network behavior when testing."""
@@ -100,9 +105,9 @@ class Client:
100
105
  `api_key="your-api-key"` or by defining `GOOGLE_API_KEY="your-api-key"` as an
101
106
  environment variable
102
107
 
103
- Vertex AI API users can provide inputs argument as `vertexai=false,
108
+ Vertex AI API users can provide inputs argument as `vertexai=True,
104
109
  project="your-project-id", location="us-central1"` or by defining
105
- `GOOGLE_GENAI_USE_VERTEXAI=false`, `GOOGLE_CLOUD_PROJECT` and
110
+ `GOOGLE_GENAI_USE_VERTEXAI=true`, `GOOGLE_CLOUD_PROJECT` and
106
111
  `GOOGLE_CLOUD_LOCATION` environment variables.
107
112
 
108
113
  Attributes:
@@ -205,6 +210,7 @@ class Client:
205
210
  self._caches = Caches(self._api_client)
206
211
  self._batches = Batches(self._api_client)
207
212
  self._files = Files(self._api_client)
213
+ self._operations = Operations(self._api_client)
208
214
 
209
215
  @staticmethod
210
216
  def _get_api_client(
@@ -233,7 +239,7 @@ class Client:
233
239
  http_options=http_options,
234
240
  )
235
241
 
236
- return ApiClient(
242
+ return BaseApiClient(
237
243
  vertexai=vertexai,
238
244
  api_key=api_key,
239
245
  credentials=credentials,
@@ -270,7 +276,11 @@ class Client:
270
276
  def files(self) -> Files:
271
277
  return self._files
272
278
 
279
@property
def operations(self) -> Operations:
  """Returns the `Operations` module built from this client's API client."""
  return self._operations
282
+
273
283
  @property
274
284
  def vertexai(self) -> bool:
275
285
  """Returns whether the client is using the Vertex AI API."""
276
- return self._api_client.vertexai or False
286
+ return self._api_client.vertexai or False
google/genai/errors.py CHANGED
@@ -16,7 +16,8 @@
16
16
  """Error classes for the GenAI SDK."""
17
17
 
18
18
  from typing import Any, Optional, TYPE_CHECKING, Union
19
-
19
+ import httpx
20
+ import json
20
21
  import requests
21
22
 
22
23
 
@@ -27,14 +28,15 @@ if TYPE_CHECKING:
27
28
  class APIError(Exception):
28
29
  """General errors raised by the GenAI API."""
29
30
  code: int
30
- response: requests.Response
31
+ response: Union[requests.Response, 'ReplayResponse', httpx.Response]
31
32
 
32
33
  status: Optional[str] = None
33
34
  message: Optional[str] = None
34
- response: Optional[Any] = None
35
35
 
36
36
  def __init__(
37
- self, code: int, response: Union[requests.Response, 'ReplayResponse']
37
+ self,
38
+ code: int,
39
+ response: Union[requests.Response, 'ReplayResponse', httpx.Response],
38
40
  ):
39
41
  self.response = response
40
42
 
@@ -48,6 +50,18 @@ class APIError(Exception):
48
50
  'message': response.text,
49
51
  'status': response.reason,
50
52
  }
53
+ elif isinstance(response, httpx.Response):
54
+ try:
55
+ response_json = response.json()
56
+ except (json.decoder.JSONDecodeError, httpx.ResponseNotRead):
57
+ try:
58
+ message = response.text
59
+ except httpx.ResponseNotRead:
60
+ message = None
61
+ response_json = {
62
+ 'message': message,
63
+ 'status': response.reason_phrase,
64
+ }
51
65
  else:
52
66
  response_json = response.body_segments[0].get('error', {})
53
67
 
@@ -89,7 +103,7 @@ class APIError(Exception):
89
103
 
90
104
  @classmethod
91
105
  def raise_for_response(
92
- cls, response: Union[requests.Response, 'ReplayResponse']
106
+ cls, response: Union[requests.Response, 'ReplayResponse', httpx.Response]
93
107
  ):
94
108
  """Raises an error with detailed error message if the response has an error status."""
95
109
  if response.status_code == 200: