google-genai 1.5.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +196 -141
- google/genai/_automatic_function_calling_util.py +4 -14
- google/genai/_common.py +7 -5
- google/genai/_replay_api_client.py +6 -3
- google/genai/_transformers.py +61 -37
- google/genai/batches.py +4 -0
- google/genai/caches.py +20 -26
- google/genai/chats.py +137 -46
- google/genai/client.py +3 -2
- google/genai/errors.py +11 -19
- google/genai/files.py +9 -9
- google/genai/live.py +276 -93
- google/genai/models.py +245 -68
- google/genai/operations.py +30 -2
- google/genai/pagers.py +3 -5
- google/genai/tunings.py +31 -21
- google/genai/types.py +88 -33
- google/genai/version.py +1 -1
- {google_genai-1.5.0.dist-info → google_genai-1.7.0.dist-info}/METADATA +201 -31
- google_genai-1.7.0.dist-info/RECORD +27 -0
- {google_genai-1.5.0.dist-info → google_genai-1.7.0.dist-info}/WHEEL +1 -1
- google_genai-1.5.0.dist-info/RECORD +0 -27
- {google_genai-1.5.0.dist-info → google_genai-1.7.0.dist-info}/LICENSE +0 -0
- {google_genai-1.5.0.dist-info → google_genai-1.7.0.dist-info}/top_level.txt +0 -0
google/genai/chats.py
CHANGED
@@ -13,12 +13,19 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
#
|
15
15
|
|
16
|
-
|
17
|
-
from typing import Union
|
16
|
+
import sys
|
17
|
+
from typing import AsyncIterator, Awaitable, Optional, Union, get_args
|
18
18
|
|
19
19
|
from . import _transformers as t
|
20
|
+
from . import types
|
20
21
|
from .models import AsyncModels, Models
|
21
|
-
from .types import Content,
|
22
|
+
from .types import Content, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
|
23
|
+
|
24
|
+
|
25
|
+
if sys.version_info >= (3, 10):
|
26
|
+
from typing import TypeGuard
|
27
|
+
else:
|
28
|
+
from typing_extensions import TypeGuard
|
22
29
|
|
23
30
|
|
24
31
|
def _validate_content(content: Content) -> bool:
|
@@ -81,8 +88,7 @@ def _extract_curated_history(
|
|
81
88
|
while i < length:
|
82
89
|
if comprehensive_history[i].role not in ["user", "model"]:
|
83
90
|
raise ValueError(
|
84
|
-
"Role must be user or model, but got"
|
85
|
-
f" {comprehensive_history[i].role}"
|
91
|
+
f"Role must be user or model, but got {comprehensive_history[i].role}"
|
86
92
|
)
|
87
93
|
|
88
94
|
if comprehensive_history[i].role == "user":
|
@@ -108,12 +114,10 @@ class _BaseChat:
|
|
108
114
|
def __init__(
|
109
115
|
self,
|
110
116
|
*,
|
111
|
-
modules: Union[Models, AsyncModels],
|
112
117
|
model: str,
|
113
118
|
config: Optional[GenerateContentConfigOrDict] = None,
|
114
119
|
history: list[Content],
|
115
120
|
):
|
116
|
-
self._modules = modules
|
117
121
|
self._model = model
|
118
122
|
self._config = config
|
119
123
|
self._comprehensive_history = history
|
@@ -123,27 +127,32 @@ class _BaseChat:
|
|
123
127
|
"""Curated history is the set of valid turns that will be used in the subsequent send requests.
|
124
128
|
"""
|
125
129
|
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
130
|
+
def record_history(
|
131
|
+
self,
|
132
|
+
user_input: Content,
|
133
|
+
model_output: list[Content],
|
134
|
+
automatic_function_calling_history: list[Content],
|
135
|
+
is_valid: bool,
|
136
|
+
):
|
131
137
|
"""Records the chat history.
|
132
138
|
|
133
139
|
Maintaining both comprehensive and curated histories.
|
134
140
|
|
135
141
|
Args:
|
136
142
|
user_input: The user's input content.
|
137
|
-
model_output: A list of `Content` from the model's response.
|
138
|
-
|
139
|
-
automatic_function_calling_history: A list of `Content` representing
|
140
|
-
|
141
|
-
|
143
|
+
model_output: A list of `Content` from the model's response. This can be
|
144
|
+
an empty list if the model produced no output.
|
145
|
+
automatic_function_calling_history: A list of `Content` representing the
|
146
|
+
history of automatic function calls, including the user input as the
|
147
|
+
first entry.
|
142
148
|
is_valid: A boolean flag indicating whether the current model output is
|
143
149
|
considered valid.
|
144
150
|
"""
|
145
151
|
input_contents = (
|
146
|
-
|
152
|
+
# Because the AFC input contains the entire curated chat history in
|
153
|
+
# addition to the new user input, we need to truncate the AFC history
|
154
|
+
# to deduplicate the existing chat history.
|
155
|
+
automatic_function_calling_history[len(self._curated_history):]
|
147
156
|
if automatic_function_calling_history
|
148
157
|
else [user_input]
|
149
158
|
)
|
@@ -158,14 +167,13 @@ class _BaseChat:
|
|
158
167
|
self._curated_history.extend(input_contents)
|
159
168
|
self._curated_history.extend(output_contents)
|
160
169
|
|
161
|
-
|
162
170
|
def get_history(self, curated: bool = False) -> list[Content]:
|
163
171
|
"""Returns the chat history.
|
164
172
|
|
165
173
|
Args:
|
166
|
-
curated: A boolean flag indicating whether to return the curated
|
167
|
-
|
168
|
-
|
174
|
+
curated: A boolean flag indicating whether to return the curated (valid)
|
175
|
+
history or the comprehensive (all turns) history. Defaults to False
|
176
|
+
(returns the comprehensive history).
|
169
177
|
|
170
178
|
Returns:
|
171
179
|
A list of `Content` objects representing the chat history.
|
@@ -176,9 +184,41 @@ class _BaseChat:
|
|
176
184
|
return self._comprehensive_history
|
177
185
|
|
178
186
|
|
187
|
+
def _is_part_type(
|
188
|
+
contents: Union[list[PartUnionDict], PartUnionDict],
|
189
|
+
) -> TypeGuard[t.ContentType]:
|
190
|
+
if isinstance(contents, list):
|
191
|
+
return all(_is_part_type(part) for part in contents)
|
192
|
+
else:
|
193
|
+
allowed_part_types = get_args(types.PartUnion)
|
194
|
+
if type(contents) in allowed_part_types:
|
195
|
+
return True
|
196
|
+
else:
|
197
|
+
# Some images don't pass isinstance(item, PIL.Image.Image)
|
198
|
+
# For example <class 'PIL.JpegImagePlugin.JpegImageFile'>
|
199
|
+
if types.PIL_Image is not None and isinstance(contents, types.PIL_Image):
|
200
|
+
return True
|
201
|
+
return False
|
202
|
+
|
203
|
+
|
179
204
|
class Chat(_BaseChat):
|
180
205
|
"""Chat session."""
|
181
206
|
|
207
|
+
def __init__(
|
208
|
+
self,
|
209
|
+
*,
|
210
|
+
modules: Models,
|
211
|
+
model: str,
|
212
|
+
config: Optional[GenerateContentConfigOrDict] = None,
|
213
|
+
history: list[Content],
|
214
|
+
):
|
215
|
+
self._modules = modules
|
216
|
+
super().__init__(
|
217
|
+
model=model,
|
218
|
+
config=config,
|
219
|
+
history=history,
|
220
|
+
)
|
221
|
+
|
182
222
|
def send_message(
|
183
223
|
self,
|
184
224
|
message: Union[list[PartUnionDict], PartUnionDict],
|
@@ -202,10 +242,15 @@ class Chat(_BaseChat):
|
|
202
242
|
response = chat.send_message('tell me a story')
|
203
243
|
"""
|
204
244
|
|
245
|
+
if not _is_part_type(message):
|
246
|
+
raise ValueError(
|
247
|
+
f"Message must be a valid part type: {types.PartUnion} or"
|
248
|
+
f" {types.PartUnionDict}, got {type(message)}"
|
249
|
+
)
|
205
250
|
input_content = t.t_content(self._modules._api_client, message)
|
206
251
|
response = self._modules.generate_content(
|
207
252
|
model=self._model,
|
208
|
-
contents=self._curated_history + [input_content],
|
253
|
+
contents=self._curated_history + [input_content], # type: ignore[arg-type]
|
209
254
|
config=config if config else self._config,
|
210
255
|
)
|
211
256
|
model_output = (
|
@@ -213,10 +258,15 @@ class Chat(_BaseChat):
|
|
213
258
|
if response.candidates and response.candidates[0].content
|
214
259
|
else []
|
215
260
|
)
|
261
|
+
automatic_function_calling_history = (
|
262
|
+
response.automatic_function_calling_history
|
263
|
+
if response.automatic_function_calling_history
|
264
|
+
else []
|
265
|
+
)
|
216
266
|
self.record_history(
|
217
267
|
user_input=input_content,
|
218
268
|
model_output=model_output,
|
219
|
-
automatic_function_calling_history=
|
269
|
+
automatic_function_calling_history=automatic_function_calling_history,
|
220
270
|
is_valid=_validate_response(response),
|
221
271
|
)
|
222
272
|
return response
|
@@ -245,29 +295,42 @@ class Chat(_BaseChat):
|
|
245
295
|
print(chunk.text)
|
246
296
|
"""
|
247
297
|
|
298
|
+
if not _is_part_type(message):
|
299
|
+
raise ValueError(
|
300
|
+
f"Message must be a valid part type: {types.PartUnion} or"
|
301
|
+
f" {types.PartUnionDict}, got {type(message)}"
|
302
|
+
)
|
248
303
|
input_content = t.t_content(self._modules._api_client, message)
|
249
304
|
output_contents = []
|
250
305
|
finish_reason = None
|
251
306
|
is_valid = True
|
252
307
|
chunk = None
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
308
|
+
if isinstance(self._modules, Models):
|
309
|
+
for chunk in self._modules.generate_content_stream(
|
310
|
+
model=self._model,
|
311
|
+
contents=self._curated_history + [input_content], # type: ignore[arg-type]
|
312
|
+
config=config if config else self._config,
|
313
|
+
):
|
314
|
+
if not _validate_response(chunk):
|
315
|
+
is_valid = False
|
316
|
+
if chunk.candidates and chunk.candidates[0].content:
|
317
|
+
output_contents.append(chunk.candidates[0].content)
|
318
|
+
if chunk.candidates and chunk.candidates[0].finish_reason:
|
319
|
+
finish_reason = chunk.candidates[0].finish_reason
|
320
|
+
yield chunk
|
321
|
+
automatic_function_calling_history = (
|
322
|
+
chunk.automatic_function_calling_history
|
323
|
+
if chunk.automatic_function_calling_history
|
324
|
+
else []
|
325
|
+
)
|
326
|
+
self.record_history(
|
327
|
+
user_input=input_content,
|
328
|
+
model_output=output_contents,
|
329
|
+
automatic_function_calling_history=automatic_function_calling_history,
|
330
|
+
is_valid=is_valid
|
331
|
+
and output_contents is not None
|
332
|
+
and finish_reason is not None,
|
333
|
+
)
|
271
334
|
|
272
335
|
|
273
336
|
class Chats:
|
@@ -304,6 +367,21 @@ class Chats:
|
|
304
367
|
class AsyncChat(_BaseChat):
|
305
368
|
"""Async chat session."""
|
306
369
|
|
370
|
+
def __init__(
|
371
|
+
self,
|
372
|
+
*,
|
373
|
+
modules: AsyncModels,
|
374
|
+
model: str,
|
375
|
+
config: Optional[GenerateContentConfigOrDict] = None,
|
376
|
+
history: list[Content],
|
377
|
+
):
|
378
|
+
self._modules = modules
|
379
|
+
super().__init__(
|
380
|
+
model=model,
|
381
|
+
config=config,
|
382
|
+
history=history,
|
383
|
+
)
|
384
|
+
|
307
385
|
async def send_message(
|
308
386
|
self,
|
309
387
|
message: Union[list[PartUnionDict], PartUnionDict],
|
@@ -326,11 +404,15 @@ class AsyncChat(_BaseChat):
|
|
326
404
|
chat = client.aio.chats.create(model='gemini-1.5-flash')
|
327
405
|
response = await chat.send_message('tell me a story')
|
328
406
|
"""
|
329
|
-
|
407
|
+
if not _is_part_type(message):
|
408
|
+
raise ValueError(
|
409
|
+
f"Message must be a valid part type: {types.PartUnion} or"
|
410
|
+
f" {types.PartUnionDict}, got {type(message)}"
|
411
|
+
)
|
330
412
|
input_content = t.t_content(self._modules._api_client, message)
|
331
413
|
response = await self._modules.generate_content(
|
332
414
|
model=self._model,
|
333
|
-
contents=self._curated_history + [input_content],
|
415
|
+
contents=self._curated_history + [input_content], # type: ignore[arg-type]
|
334
416
|
config=config if config else self._config,
|
335
417
|
)
|
336
418
|
model_output = (
|
@@ -338,10 +420,15 @@ class AsyncChat(_BaseChat):
|
|
338
420
|
if response.candidates and response.candidates[0].content
|
339
421
|
else []
|
340
422
|
)
|
423
|
+
automatic_function_calling_history = (
|
424
|
+
response.automatic_function_calling_history
|
425
|
+
if response.automatic_function_calling_history
|
426
|
+
else []
|
427
|
+
)
|
341
428
|
self.record_history(
|
342
429
|
user_input=input_content,
|
343
430
|
model_output=model_output,
|
344
|
-
automatic_function_calling_history=
|
431
|
+
automatic_function_calling_history=automatic_function_calling_history,
|
345
432
|
is_valid=_validate_response(response),
|
346
433
|
)
|
347
434
|
return response
|
@@ -369,6 +456,11 @@ class AsyncChat(_BaseChat):
|
|
369
456
|
print(chunk.text)
|
370
457
|
"""
|
371
458
|
|
459
|
+
if not _is_part_type(message):
|
460
|
+
raise ValueError(
|
461
|
+
f"Message must be a valid part type: {types.PartUnion} or"
|
462
|
+
f" {types.PartUnionDict}, got {type(message)}"
|
463
|
+
)
|
372
464
|
input_content = t.t_content(self._modules._api_client, message)
|
373
465
|
|
374
466
|
async def async_generator():
|
@@ -394,7 +486,6 @@ class AsyncChat(_BaseChat):
|
|
394
486
|
model_output=output_contents,
|
395
487
|
automatic_function_calling_history=chunk.automatic_function_calling_history,
|
396
488
|
is_valid=is_valid and output_contents and finish_reason,
|
397
|
-
|
398
489
|
)
|
399
490
|
return async_generator()
|
400
491
|
|
google/genai/client.py
CHANGED
@@ -130,8 +130,9 @@ class Client:
|
|
130
130
|
from environment variables. Applies to the Vertex AI API only.
|
131
131
|
debug_config: Config settings that control network behavior of the client.
|
132
132
|
This is typically used when running test code.
|
133
|
-
http_options: Http options to use for the client.
|
134
|
-
|
133
|
+
http_options: Http options to use for the client. These options will be
|
134
|
+
applied to all requests made by the client. Example usage:
|
135
|
+
`client = genai.Client(http_options=types.HttpOptions(api_version='v1'))`.
|
135
136
|
|
136
137
|
Usage for the Gemini Developer API:
|
137
138
|
|
google/genai/errors.py
CHANGED
@@ -18,7 +18,6 @@
|
|
18
18
|
from typing import Any, Optional, TYPE_CHECKING, Union
|
19
19
|
import httpx
|
20
20
|
import json
|
21
|
-
import requests
|
22
21
|
|
23
22
|
|
24
23
|
if TYPE_CHECKING:
|
@@ -28,7 +27,7 @@ if TYPE_CHECKING:
|
|
28
27
|
class APIError(Exception):
|
29
28
|
"""General errors raised by the GenAI API."""
|
30
29
|
code: int
|
31
|
-
response: Union[
|
30
|
+
response: Union['ReplayResponse', httpx.Response]
|
32
31
|
|
33
32
|
status: Optional[str] = None
|
34
33
|
message: Optional[str] = None
|
@@ -36,28 +35,21 @@ class APIError(Exception):
|
|
36
35
|
def __init__(
|
37
36
|
self,
|
38
37
|
code: int,
|
39
|
-
response: Union[
|
38
|
+
response: Union['ReplayResponse', httpx.Response],
|
40
39
|
):
|
41
40
|
self.response = response
|
42
|
-
|
43
|
-
if isinstance(response,
|
41
|
+
message = None
|
42
|
+
if isinstance(response, httpx.Response):
|
44
43
|
try:
|
45
|
-
# do not do any extra muanipulation on the response.
|
46
|
-
# return the raw response json as is.
|
47
44
|
response_json = response.json()
|
48
|
-
except
|
45
|
+
except (json.decoder.JSONDecodeError):
|
46
|
+
message = response.text
|
49
47
|
response_json = {
|
50
|
-
'message':
|
51
|
-
'status': response.
|
48
|
+
'message': message,
|
49
|
+
'status': response.reason_phrase,
|
52
50
|
}
|
53
|
-
|
54
|
-
|
55
|
-
response_json = response.json()
|
56
|
-
except (json.decoder.JSONDecodeError, httpx.ResponseNotRead):
|
57
|
-
try:
|
58
|
-
message = response.text
|
59
|
-
except httpx.ResponseNotRead:
|
60
|
-
message = None
|
51
|
+
except httpx.ResponseNotRead:
|
52
|
+
message = 'Response not read'
|
61
53
|
response_json = {
|
62
54
|
'message': message,
|
63
55
|
'status': response.reason_phrase,
|
@@ -103,7 +95,7 @@ class APIError(Exception):
|
|
103
95
|
|
104
96
|
@classmethod
|
105
97
|
def raise_for_response(
|
106
|
-
cls, response: Union[
|
98
|
+
cls, response: Union['ReplayResponse', httpx.Response]
|
107
99
|
):
|
108
100
|
"""Raises an error with detailed error message if the response has an error status."""
|
109
101
|
if response.status_code == 200:
|
google/genai/files.py
CHANGED
@@ -826,7 +826,7 @@ class Files(_api_module.BaseModule):
|
|
826
826
|
'Vertex AI does not support creating files. You can upload files to'
|
827
827
|
' GCS files instead.'
|
828
828
|
)
|
829
|
-
config_model =
|
829
|
+
config_model = types.UploadFileConfig()
|
830
830
|
if config:
|
831
831
|
if isinstance(config, dict):
|
832
832
|
config_model = types.UploadFileConfig(**config)
|
@@ -888,13 +888,13 @@ class Files(_api_module.BaseModule):
|
|
888
888
|
|
889
889
|
if (
|
890
890
|
response.http_headers is None
|
891
|
-
or '
|
891
|
+
or 'x-goog-upload-url' not in response.http_headers
|
892
892
|
):
|
893
893
|
raise KeyError(
|
894
894
|
'Failed to create file. Upload URL did not returned from the create'
|
895
895
|
' file request.'
|
896
896
|
)
|
897
|
-
upload_url = response.http_headers['
|
897
|
+
upload_url = response.http_headers['x-goog-upload-url']
|
898
898
|
|
899
899
|
if isinstance(file, io.IOBase):
|
900
900
|
return_file = self._api_client.upload_file(
|
@@ -907,7 +907,7 @@ class Files(_api_module.BaseModule):
|
|
907
907
|
|
908
908
|
return types.File._from_response(
|
909
909
|
response=_File_from_mldev(self._api_client, return_file['file']),
|
910
|
-
kwargs=
|
910
|
+
kwargs=config_model.model_dump() if config else {},
|
911
911
|
)
|
912
912
|
|
913
913
|
def list(
|
@@ -979,7 +979,7 @@ class Files(_api_module.BaseModule):
|
|
979
979
|
'downloaded. You can tell which files are downloadable by checking '
|
980
980
|
'the `source` or `download_uri` property.'
|
981
981
|
)
|
982
|
-
name = t.t_file_name(self, file)
|
982
|
+
name = t.t_file_name(self._api_client, file)
|
983
983
|
|
984
984
|
path = f'files/{name}:download'
|
985
985
|
|
@@ -996,7 +996,7 @@ class Files(_api_module.BaseModule):
|
|
996
996
|
|
997
997
|
if isinstance(file, types.Video):
|
998
998
|
file.video_bytes = data
|
999
|
-
elif isinstance(file, types.GeneratedVideo):
|
999
|
+
elif isinstance(file, types.GeneratedVideo) and file.video is not None:
|
1000
1000
|
file.video.video_bytes = data
|
1001
1001
|
|
1002
1002
|
return data
|
@@ -1293,7 +1293,7 @@ class AsyncFiles(_api_module.BaseModule):
|
|
1293
1293
|
'Vertex AI does not support creating files. You can upload files to'
|
1294
1294
|
' GCS files instead.'
|
1295
1295
|
)
|
1296
|
-
config_model =
|
1296
|
+
config_model = types.UploadFileConfig()
|
1297
1297
|
if config:
|
1298
1298
|
if isinstance(config, dict):
|
1299
1299
|
config_model = types.UploadFileConfig(**config)
|
@@ -1373,7 +1373,7 @@ class AsyncFiles(_api_module.BaseModule):
|
|
1373
1373
|
|
1374
1374
|
return types.File._from_response(
|
1375
1375
|
response=_File_from_mldev(self._api_client, return_file['file']),
|
1376
|
-
kwargs=
|
1376
|
+
kwargs=config_model.model_dump() if config else {},
|
1377
1377
|
)
|
1378
1378
|
|
1379
1379
|
async def list(
|
@@ -1433,7 +1433,7 @@ class AsyncFiles(_api_module.BaseModule):
|
|
1433
1433
|
else:
|
1434
1434
|
config_model = config
|
1435
1435
|
|
1436
|
-
name = t.t_file_name(self, file)
|
1436
|
+
name = t.t_file_name(self._api_client, file)
|
1437
1437
|
|
1438
1438
|
path = f'files/{name}:download'
|
1439
1439
|
|