google-genai 1.33.0__py3-none-any.whl → 1.53.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +361 -208
- google/genai/_common.py +260 -69
- google/genai/_extra_utils.py +142 -12
- google/genai/_live_converters.py +691 -2746
- google/genai/_local_tokenizer_loader.py +0 -9
- google/genai/_operations_converters.py +186 -99
- google/genai/_replay_api_client.py +48 -51
- google/genai/_tokens_converters.py +169 -489
- google/genai/_transformers.py +193 -90
- google/genai/batches.py +1014 -1307
- google/genai/caches.py +458 -1107
- google/genai/client.py +101 -0
- google/genai/documents.py +532 -0
- google/genai/errors.py +58 -4
- google/genai/file_search_stores.py +1296 -0
- google/genai/files.py +108 -358
- google/genai/live.py +90 -32
- google/genai/live_music.py +24 -27
- google/genai/local_tokenizer.py +36 -3
- google/genai/models.py +2308 -3375
- google/genai/operations.py +129 -21
- google/genai/pagers.py +7 -1
- google/genai/tokens.py +2 -12
- google/genai/tunings.py +770 -436
- google/genai/types.py +4341 -1218
- google/genai/version.py +1 -1
- {google_genai-1.33.0.dist-info → google_genai-1.53.0.dist-info}/METADATA +359 -201
- google_genai-1.53.0.dist-info/RECORD +41 -0
- google_genai-1.33.0.dist-info/RECORD +0 -39
- {google_genai-1.33.0.dist-info → google_genai-1.53.0.dist-info}/WHEEL +0 -0
- {google_genai-1.33.0.dist-info → google_genai-1.53.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.33.0.dist-info → google_genai-1.53.0.dist-info}/top_level.txt +0 -0
google/genai/live.py
CHANGED
|
@@ -21,7 +21,7 @@ import contextlib
|
|
|
21
21
|
import json
|
|
22
22
|
import logging
|
|
23
23
|
import typing
|
|
24
|
-
from typing import Any, AsyncIterator,
|
|
24
|
+
from typing import Any, AsyncIterator, Optional, Sequence, Union, get_args
|
|
25
25
|
import warnings
|
|
26
26
|
|
|
27
27
|
import google.auth
|
|
@@ -40,7 +40,6 @@ from ._common import get_value_by_path as getv
|
|
|
40
40
|
from ._common import set_value_by_path as setv
|
|
41
41
|
from .live_music import AsyncLiveMusic
|
|
42
42
|
from .models import _Content_to_mldev
|
|
43
|
-
from .models import _Content_to_vertex
|
|
44
43
|
|
|
45
44
|
|
|
46
45
|
try:
|
|
@@ -51,6 +50,11 @@ except ModuleNotFoundError:
|
|
|
51
50
|
from websockets.client import ClientConnection # type: ignore
|
|
52
51
|
from websockets.client import connect as ws_connect # type: ignore
|
|
53
52
|
|
|
53
|
+
try:
|
|
54
|
+
from google.auth.transport import requests
|
|
55
|
+
except ImportError:
|
|
56
|
+
requests = None # type: ignore[assignment]
|
|
57
|
+
|
|
54
58
|
if typing.TYPE_CHECKING:
|
|
55
59
|
from mcp import ClientSession as McpClientSession
|
|
56
60
|
from mcp.types import Tool as McpTool
|
|
@@ -82,9 +86,15 @@ _FUNCTION_RESPONSE_REQUIRES_ID = (
|
|
|
82
86
|
class AsyncSession:
|
|
83
87
|
"""[Preview] AsyncSession."""
|
|
84
88
|
|
|
85
|
-
def __init__(
|
|
89
|
+
def __init__(
|
|
90
|
+
self,
|
|
91
|
+
api_client: BaseApiClient,
|
|
92
|
+
websocket: ClientConnection,
|
|
93
|
+
session_id: Optional[str] = None,
|
|
94
|
+
):
|
|
86
95
|
self._api_client = api_client
|
|
87
96
|
self._ws = websocket
|
|
97
|
+
self.session_id = session_id
|
|
88
98
|
|
|
89
99
|
async def send(
|
|
90
100
|
self,
|
|
@@ -217,8 +227,8 @@ class AsyncSession:
|
|
|
217
227
|
)
|
|
218
228
|
|
|
219
229
|
if self._api_client.vertexai:
|
|
220
|
-
client_content_dict =
|
|
221
|
-
|
|
230
|
+
client_content_dict = _common.convert_to_dict(
|
|
231
|
+
client_content, convert_keys=True
|
|
222
232
|
)
|
|
223
233
|
else:
|
|
224
234
|
client_content_dict = live_converters._LiveClientContent_to_mldev(
|
|
@@ -404,12 +414,12 @@ class AsyncSession:
|
|
|
404
414
|
"""
|
|
405
415
|
tool_response = t.t_tool_response(function_responses)
|
|
406
416
|
if self._api_client.vertexai:
|
|
407
|
-
tool_response_dict =
|
|
408
|
-
|
|
417
|
+
tool_response_dict = _common.convert_to_dict(
|
|
418
|
+
tool_response, convert_keys=True
|
|
409
419
|
)
|
|
410
420
|
else:
|
|
411
|
-
tool_response_dict =
|
|
412
|
-
|
|
421
|
+
tool_response_dict = _common.convert_to_dict(
|
|
422
|
+
tool_response, convert_keys=True
|
|
413
423
|
)
|
|
414
424
|
for response in tool_response_dict.get('functionResponses', []):
|
|
415
425
|
if response.get('id') is None:
|
|
@@ -535,7 +545,7 @@ class AsyncSession:
|
|
|
535
545
|
if self._api_client.vertexai:
|
|
536
546
|
response_dict = live_converters._LiveServerMessage_from_vertex(response)
|
|
537
547
|
else:
|
|
538
|
-
response_dict =
|
|
548
|
+
response_dict = response
|
|
539
549
|
|
|
540
550
|
return types.LiveServerMessage._from_response(
|
|
541
551
|
response=response_dict, kwargs=parameter_model.model_dump()
|
|
@@ -649,7 +659,7 @@ class AsyncSession:
|
|
|
649
659
|
content_input_parts.append(item)
|
|
650
660
|
if self._api_client.vertexai:
|
|
651
661
|
contents = [
|
|
652
|
-
|
|
662
|
+
_common.convert_to_dict(item, convert_keys=True)
|
|
653
663
|
for item in t.t_contents(content_input_parts)
|
|
654
664
|
]
|
|
655
665
|
else:
|
|
@@ -975,7 +985,8 @@ class AsyncLive(_api_module.BaseModule):
|
|
|
975
985
|
api_key = self._api_client.api_key
|
|
976
986
|
version = self._api_client._http_options.api_version
|
|
977
987
|
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
|
|
978
|
-
|
|
988
|
+
original_headers = self._api_client._http_options.headers
|
|
989
|
+
headers = original_headers.copy() if original_headers is not None else {}
|
|
979
990
|
|
|
980
991
|
request_dict = _common.convert_to_dict(
|
|
981
992
|
live_converters._LiveConnectParameters_to_vertex(
|
|
@@ -992,27 +1003,51 @@ class AsyncLive(_api_module.BaseModule):
|
|
|
992
1003
|
|
|
993
1004
|
request = json.dumps(request_dict)
|
|
994
1005
|
else:
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
1006
|
+
version = self._api_client._http_options.api_version
|
|
1007
|
+
has_sufficient_auth = (
|
|
1008
|
+
self._api_client.project and self._api_client.location
|
|
1009
|
+
)
|
|
1010
|
+
if self._api_client.custom_base_url and not has_sufficient_auth:
|
|
1011
|
+
# API gateway proxy can use the auth in custom headers, not url.
|
|
1012
|
+
# Enable custom url if auth is not sufficient.
|
|
1013
|
+
uri = self._api_client.custom_base_url
|
|
1014
|
+
# Keep the model as is.
|
|
1015
|
+
transformed_model = model
|
|
1016
|
+
# Do not get credentials for custom url.
|
|
1017
|
+
original_headers = self._api_client._http_options.headers
|
|
1018
|
+
headers = (
|
|
1019
|
+
original_headers.copy() if original_headers is not None else {}
|
|
999
1020
|
)
|
|
1021
|
+
|
|
1000
1022
|
else:
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1023
|
+
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
|
|
1024
|
+
|
|
1025
|
+
if not self._api_client._credentials:
|
|
1026
|
+
# Get bearer token through Application Default Credentials.
|
|
1027
|
+
creds, _ = google.auth.default( # type: ignore
|
|
1028
|
+
scopes=['https://www.googleapis.com/auth/cloud-platform']
|
|
1029
|
+
)
|
|
1030
|
+
else:
|
|
1031
|
+
creds = self._api_client._credentials
|
|
1032
|
+
# creds.valid is False, and creds.token is None
|
|
1033
|
+
# Need to refresh credentials to populate those
|
|
1034
|
+
if not (creds.token and creds.valid):
|
|
1035
|
+
if requests is None:
|
|
1036
|
+
raise ValueError('The requests module is required to refresh google-auth credentials. Please install with `pip install google-auth[requests]`')
|
|
1037
|
+
auth_req = requests.Request() # type: ignore
|
|
1038
|
+
creds.refresh(auth_req)
|
|
1039
|
+
bearer_token = creds.token
|
|
1040
|
+
|
|
1041
|
+
original_headers = self._api_client._http_options.headers
|
|
1042
|
+
headers = (
|
|
1043
|
+
original_headers.copy() if original_headers is not None else {}
|
|
1044
|
+
)
|
|
1045
|
+
if not headers.get('Authorization'):
|
|
1046
|
+
headers['Authorization'] = f'Bearer {bearer_token}'
|
|
1047
|
+
|
|
1013
1048
|
location = self._api_client.location
|
|
1014
1049
|
project = self._api_client.project
|
|
1015
|
-
if transformed_model.startswith('publishers/'):
|
|
1050
|
+
if transformed_model.startswith('publishers/') and project and location:
|
|
1016
1051
|
transformed_model = (
|
|
1017
1052
|
f'projects/{project}/locations/{location}/' + transformed_model
|
|
1018
1053
|
)
|
|
@@ -1054,11 +1089,34 @@ class AsyncLive(_api_module.BaseModule):
|
|
|
1054
1089
|
await ws.send(request)
|
|
1055
1090
|
try:
|
|
1056
1091
|
# websockets 14.0+
|
|
1057
|
-
|
|
1092
|
+
raw_response = await ws.recv(decode=False)
|
|
1058
1093
|
except TypeError:
|
|
1059
|
-
|
|
1094
|
+
raw_response = await ws.recv() # type: ignore[assignment]
|
|
1095
|
+
if raw_response:
|
|
1096
|
+
try:
|
|
1097
|
+
response = json.loads(raw_response)
|
|
1098
|
+
except json.decoder.JSONDecodeError:
|
|
1099
|
+
raise ValueError(f'Failed to parse response: {raw_response!r}')
|
|
1100
|
+
else:
|
|
1101
|
+
response = {}
|
|
1060
1102
|
|
|
1061
|
-
|
|
1103
|
+
if self._api_client.vertexai:
|
|
1104
|
+
response_dict = live_converters._LiveServerMessage_from_vertex(response)
|
|
1105
|
+
else:
|
|
1106
|
+
response_dict = response
|
|
1107
|
+
|
|
1108
|
+
setup_response = types.LiveServerMessage._from_response(
|
|
1109
|
+
response=response_dict, kwargs=parameter_model.model_dump()
|
|
1110
|
+
)
|
|
1111
|
+
if setup_response.setup_complete:
|
|
1112
|
+
session_id = setup_response.setup_complete.session_id
|
|
1113
|
+
else:
|
|
1114
|
+
session_id = None
|
|
1115
|
+
yield AsyncSession(
|
|
1116
|
+
api_client=self._api_client,
|
|
1117
|
+
websocket=ws,
|
|
1118
|
+
session_id=session_id,
|
|
1119
|
+
)
|
|
1062
1120
|
|
|
1063
1121
|
|
|
1064
1122
|
async def _t_live_connect_config(
|
google/genai/live_music.py
CHANGED
|
@@ -22,13 +22,11 @@ from typing import AsyncIterator
|
|
|
22
22
|
|
|
23
23
|
from . import _api_module
|
|
24
24
|
from . import _common
|
|
25
|
+
from . import _live_converters as live_converters
|
|
25
26
|
from . import _transformers as t
|
|
26
27
|
from . import types
|
|
27
28
|
from ._api_client import BaseApiClient
|
|
28
29
|
from ._common import set_value_by_path as setv
|
|
29
|
-
from . import _live_converters as live_converters
|
|
30
|
-
from .models import _Content_to_mldev
|
|
31
|
-
from .models import _Content_to_vertex
|
|
32
30
|
|
|
33
31
|
|
|
34
32
|
try:
|
|
@@ -44,46 +42,47 @@ logger = logging.getLogger('google_genai.live_music')
|
|
|
44
42
|
class AsyncMusicSession:
|
|
45
43
|
"""[Experimental] AsyncMusicSession."""
|
|
46
44
|
|
|
47
|
-
def __init__(
|
|
48
|
-
self, api_client: BaseApiClient, websocket: ClientConnection
|
|
49
|
-
):
|
|
45
|
+
def __init__(self, api_client: BaseApiClient, websocket: ClientConnection):
|
|
50
46
|
self._api_client = api_client
|
|
51
47
|
self._ws = websocket
|
|
52
48
|
|
|
53
49
|
async def set_weighted_prompts(
|
|
54
|
-
self,
|
|
55
|
-
prompts: list[types.WeightedPrompt]
|
|
50
|
+
self, prompts: list[types.WeightedPrompt]
|
|
56
51
|
) -> None:
|
|
57
52
|
if self._api_client.vertexai:
|
|
58
|
-
raise NotImplementedError(
|
|
59
|
-
|
|
60
|
-
client_content_dict = live_converters._LiveMusicClientContent_to_mldev(
|
|
61
|
-
from_object={'weighted_prompts': prompts}
|
|
53
|
+
raise NotImplementedError(
|
|
54
|
+
'Live music generation is not supported in Vertex AI.'
|
|
62
55
|
)
|
|
56
|
+
else:
|
|
57
|
+
client_content_dict = {
|
|
58
|
+
'weightedPrompts': [
|
|
59
|
+
_common.convert_to_dict(prompt, convert_keys=True)
|
|
60
|
+
for prompt in prompts
|
|
61
|
+
]
|
|
62
|
+
}
|
|
63
|
+
|
|
63
64
|
await self._ws.send(json.dumps({'clientContent': client_content_dict}))
|
|
64
65
|
|
|
65
66
|
async def set_music_generation_config(
|
|
66
|
-
self,
|
|
67
|
-
config: types.LiveMusicGenerationConfig
|
|
67
|
+
self, config: types.LiveMusicGenerationConfig
|
|
68
68
|
) -> None:
|
|
69
69
|
if self._api_client.vertexai:
|
|
70
|
-
raise NotImplementedError(
|
|
71
|
-
|
|
72
|
-
config_dict = live_converters._LiveMusicGenerationConfig_to_mldev(
|
|
73
|
-
from_object=config
|
|
70
|
+
raise NotImplementedError(
|
|
71
|
+
'Live music generation is not supported in Vertex AI.'
|
|
74
72
|
)
|
|
73
|
+
else:
|
|
74
|
+
config_dict = _common.convert_to_dict(config, convert_keys=True)
|
|
75
75
|
await self._ws.send(json.dumps({'musicGenerationConfig': config_dict}))
|
|
76
76
|
|
|
77
77
|
async def _send_control_signal(
|
|
78
|
-
self,
|
|
79
|
-
playback_control: types.LiveMusicPlaybackControl
|
|
78
|
+
self, playback_control: types.LiveMusicPlaybackControl
|
|
80
79
|
) -> None:
|
|
81
80
|
if self._api_client.vertexai:
|
|
82
|
-
raise NotImplementedError(
|
|
83
|
-
|
|
84
|
-
playback_control_dict = live_converters._LiveMusicClientMessage_to_mldev(
|
|
85
|
-
from_object={'playback_control': playback_control}
|
|
81
|
+
raise NotImplementedError(
|
|
82
|
+
'Live music generation is not supported in Vertex AI.'
|
|
86
83
|
)
|
|
84
|
+
else:
|
|
85
|
+
playback_control_dict = {'playbackControl': playback_control.value}
|
|
87
86
|
await self._ws.send(json.dumps(playback_control_dict))
|
|
88
87
|
|
|
89
88
|
async def play(self) -> None:
|
|
@@ -134,9 +133,7 @@ class AsyncMusicSession:
|
|
|
134
133
|
if self._api_client.vertexai:
|
|
135
134
|
raise NotImplementedError('Live music generation is not supported in Vertex AI.')
|
|
136
135
|
else:
|
|
137
|
-
response_dict =
|
|
138
|
-
response
|
|
139
|
-
)
|
|
136
|
+
response_dict = response
|
|
140
137
|
|
|
141
138
|
return types.LiveMusicServerMessage._from_response(
|
|
142
139
|
response=response_dict, kwargs=parameter_model.model_dump()
|
google/genai/local_tokenizer.py
CHANGED
|
@@ -25,11 +25,16 @@ from . import _common
|
|
|
25
25
|
from . import _local_tokenizer_loader as loader
|
|
26
26
|
from . import _transformers as t
|
|
27
27
|
from . import types
|
|
28
|
-
from . import types
|
|
29
|
-
from ._transformers import t_contents
|
|
30
28
|
|
|
31
29
|
logger = logging.getLogger("google_genai.local_tokenizer")
|
|
32
30
|
|
|
31
|
+
__all__ = [
|
|
32
|
+
"_parse_hex_byte",
|
|
33
|
+
"_token_str_to_bytes",
|
|
34
|
+
"LocalTokenizer",
|
|
35
|
+
"_TextsAccumulator",
|
|
36
|
+
]
|
|
37
|
+
|
|
33
38
|
|
|
34
39
|
class _TextsAccumulator:
|
|
35
40
|
"""Accumulates countable texts from `Content` and `Tool` objects.
|
|
@@ -303,9 +308,20 @@ class LocalTokenizer:
|
|
|
303
308
|
|
|
304
309
|
Args:
|
|
305
310
|
contents: The contents to tokenize.
|
|
311
|
+
config: The configuration for counting tokens.
|
|
306
312
|
|
|
307
313
|
Returns:
|
|
308
314
|
A `CountTokensResult` containing the total number of tokens.
|
|
315
|
+
|
|
316
|
+
Usage:
|
|
317
|
+
|
|
318
|
+
.. code-block:: python
|
|
319
|
+
|
|
320
|
+
from google import genai
|
|
321
|
+
tokenizer = genai.LocalTokenizer(model_name='gemini-2.0-flash-001')
|
|
322
|
+
result = tokenizer.count_tokens("What is your name?")
|
|
323
|
+
print(result)
|
|
324
|
+
# total_tokens=5
|
|
309
325
|
"""
|
|
310
326
|
processed_contents = t.t_contents(contents)
|
|
311
327
|
text_accumulator = _TextsAccumulator()
|
|
@@ -330,7 +346,24 @@ class LocalTokenizer:
|
|
|
330
346
|
self,
|
|
331
347
|
contents: Union[types.ContentListUnion, types.ContentListUnionDict],
|
|
332
348
|
) -> types.ComputeTokensResult:
|
|
333
|
-
"""Computes the tokens ids and string pieces in the input.
|
|
349
|
+
"""Computes the tokens ids and string pieces in the input.
|
|
350
|
+
|
|
351
|
+
Args:
|
|
352
|
+
contents: The contents to tokenize.
|
|
353
|
+
|
|
354
|
+
Returns:
|
|
355
|
+
A `ComputeTokensResult` containing the token information.
|
|
356
|
+
|
|
357
|
+
Usage:
|
|
358
|
+
|
|
359
|
+
.. code-block:: python
|
|
360
|
+
|
|
361
|
+
from google import genai
|
|
362
|
+
tokenizer = genai.LocalTokenizer(model_name='gemini-2.0-flash-001')
|
|
363
|
+
result = tokenizer.compute_tokens("What is your name?")
|
|
364
|
+
print(result)
|
|
365
|
+
# tokens_info=[TokensInfo(token_ids=[279, 329, 1313, 2508, 13], tokens=[b' What', b' is', b' your', b' name', b'?'], role='user')]
|
|
366
|
+
"""
|
|
334
367
|
processed_contents = t.t_contents(contents)
|
|
335
368
|
text_accumulator = _TextsAccumulator()
|
|
336
369
|
for content in processed_contents:
|