google-genai 1.7.0__py3-none-any.whl → 1.53.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +4 -2
- google/genai/_adapters.py +55 -0
- google/genai/_api_client.py +1301 -299
- google/genai/_api_module.py +1 -1
- google/genai/_automatic_function_calling_util.py +54 -33
- google/genai/_base_transformers.py +26 -0
- google/genai/_base_url.py +50 -0
- google/genai/_common.py +560 -59
- google/genai/_extra_utils.py +371 -38
- google/genai/_live_converters.py +1467 -0
- google/genai/_local_tokenizer_loader.py +214 -0
- google/genai/_mcp_utils.py +117 -0
- google/genai/_operations_converters.py +394 -0
- google/genai/_replay_api_client.py +204 -92
- google/genai/_test_api_client.py +1 -1
- google/genai/_tokens_converters.py +520 -0
- google/genai/_transformers.py +633 -233
- google/genai/batches.py +1733 -538
- google/genai/caches.py +678 -1012
- google/genai/chats.py +48 -38
- google/genai/client.py +142 -15
- google/genai/documents.py +532 -0
- google/genai/errors.py +141 -35
- google/genai/file_search_stores.py +1296 -0
- google/genai/files.py +312 -744
- google/genai/live.py +617 -367
- google/genai/live_music.py +197 -0
- google/genai/local_tokenizer.py +395 -0
- google/genai/models.py +3598 -3116
- google/genai/operations.py +201 -362
- google/genai/pagers.py +23 -7
- google/genai/py.typed +1 -0
- google/genai/tokens.py +362 -0
- google/genai/tunings.py +1274 -496
- google/genai/types.py +14535 -5454
- google/genai/version.py +2 -2
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/METADATA +736 -234
- google_genai-1.53.0.dist-info/RECORD +41 -0
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/WHEEL +1 -1
- google_genai-1.7.0.dist-info/RECORD +0 -27
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info/licenses}/LICENSE +0 -0
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
|
|
16
|
+
"""[Experimental] Live Music API client."""
|
|
17
|
+
|
|
18
|
+
import contextlib
|
|
19
|
+
import json
|
|
20
|
+
import logging
|
|
21
|
+
from typing import AsyncIterator
|
|
22
|
+
|
|
23
|
+
from . import _api_module
|
|
24
|
+
from . import _common
|
|
25
|
+
from . import _live_converters as live_converters
|
|
26
|
+
from . import _transformers as t
|
|
27
|
+
from . import types
|
|
28
|
+
from ._api_client import BaseApiClient
|
|
29
|
+
from ._common import set_value_by_path as setv
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
try:
|
|
33
|
+
from websockets.asyncio.client import ClientConnection
|
|
34
|
+
from websockets.asyncio.client import connect
|
|
35
|
+
except ModuleNotFoundError:
|
|
36
|
+
from websockets.client import ClientConnection # type: ignore
|
|
37
|
+
from websockets.client import connect # type: ignore
|
|
38
|
+
|
|
39
|
+
# Module logger; `AsyncLiveMusic.connect` logs the server's setup reply here.
logger = logging.getLogger('google_genai.live_music')
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class AsyncMusicSession:
  """[Experimental] AsyncMusicSession.

  Wraps an already-open websocket connection to the live music service and
  exposes helpers to steer generation (prompts, config, playback control) and
  to read server messages. Every operation raises `NotImplementedError` when
  the client is configured for Vertex AI.
  """

  def __init__(self, api_client: BaseApiClient, websocket: ClientConnection):
    self._api_client = api_client
    self._ws = websocket

  def _raise_if_vertexai(self) -> None:
    """Raises NotImplementedError when the client is configured for Vertex AI."""
    if self._api_client.vertexai:
      raise NotImplementedError(
          'Live music generation is not supported in Vertex AI.'
      )

  async def set_weighted_prompts(
      self, prompts: list[types.WeightedPrompt]
  ) -> None:
    """Sends weighted prompts to the server.

    Each prompt is converted with `_common.convert_to_dict(...,
    convert_keys=True)` and sent as `{'clientContent': {'weightedPrompts':
    [...]}}`.

    Args:
      prompts: The weighted prompts to send.

    Raises:
      NotImplementedError: If the client is configured for Vertex AI.
    """
    self._raise_if_vertexai()
    client_content_dict = {
        'weightedPrompts': [
            _common.convert_to_dict(prompt, convert_keys=True)
            for prompt in prompts
        ]
    }
    await self._ws.send(json.dumps({'clientContent': client_content_dict}))

  async def set_music_generation_config(
      self, config: types.LiveMusicGenerationConfig
  ) -> None:
    """Sends a music generation config as `{'musicGenerationConfig': ...}`.

    Args:
      config: The generation config to send.

    Raises:
      NotImplementedError: If the client is configured for Vertex AI.
    """
    self._raise_if_vertexai()
    config_dict = _common.convert_to_dict(config, convert_keys=True)
    await self._ws.send(json.dumps({'musicGenerationConfig': config_dict}))

  async def _send_control_signal(
      self, playback_control: types.LiveMusicPlaybackControl
  ) -> None:
    """Sends `{'playbackControl': <enum value>}` to the server.

    Raises:
      NotImplementedError: If the client is configured for Vertex AI.
    """
    self._raise_if_vertexai()
    playback_control_dict = {'playbackControl': playback_control.value}
    await self._ws.send(json.dumps(playback_control_dict))

  async def play(self) -> None:
    """Sends playback signal to start the music stream."""
    return await self._send_control_signal(types.LiveMusicPlaybackControl.PLAY)

  async def pause(self) -> None:
    """Sends a playback signal to pause the music stream."""
    return await self._send_control_signal(types.LiveMusicPlaybackControl.PAUSE)

  async def stop(self) -> None:
    """Sends a playback signal to stop the music stream.

    Resets the music generation context while retaining the current config.
    """
    return await self._send_control_signal(types.LiveMusicPlaybackControl.STOP)

  async def reset_context(self) -> None:
    """Reset the context (prompts retained) without stopping the music generation."""
    return await self._send_control_signal(
        types.LiveMusicPlaybackControl.RESET_CONTEXT
    )

  async def receive(self) -> AsyncIterator[types.LiveMusicServerMessage]:
    """Receive model responses from the server.

    Yields:
      The audio chunks from the server.
    """
    # TODO(b/365983264) Handle intermittent issues for the user.
    # NOTE(review): `_receive` returns a pydantic model, which is presumably
    # always truthy, so this loop likely only ends when `_receive` raises
    # (e.g. when the connection closes) — confirm.
    while result := await self._receive():
      yield result

  async def _receive(self) -> types.LiveMusicServerMessage:
    """Reads one message from the websocket and parses it.

    An empty frame is treated as an empty message.

    Raises:
      ValueError: If a non-empty frame is not valid JSON.
      NotImplementedError: If the client is configured for Vertex AI.
    """
    parameter_model = types.LiveMusicServerMessage()
    try:
      raw_response = await self._ws.recv(decode=False)
    except TypeError:
      # Older websockets releases do not accept the `decode` keyword.
      raw_response = await self._ws.recv()  # type: ignore[assignment]
    if raw_response:
      try:
        response = json.loads(raw_response)
      except json.decoder.JSONDecodeError as e:
        # Chain the decode error so the parse position stays visible.
        raise ValueError(f'Failed to parse response: {raw_response!r}') from e
    else:
      response = {}

    self._raise_if_vertexai()
    response_dict = response

    return types.LiveMusicServerMessage._from_response(
        response=response_dict, kwargs=parameter_model.model_dump()
    )

  async def close(self) -> None:
    """Closes the bi-directional stream and terminates the session."""
    await self._ws.close()
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class AsyncLiveMusic(_api_module.BaseModule):
  """[Experimental] Live music module.

  Live music can be accessed via `client.aio.live.music`.
  """

  @_common.experimental_warning(
      'Realtime music generation is experimental and may change in future versions.'
  )
  @contextlib.asynccontextmanager
  async def connect(self, *, model: str) -> AsyncIterator[AsyncMusicSession]:
    """[Experimental] Connect to the live music server.

    Opens a websocket to the `BidiGenerateMusic` endpoint and sends the setup
    request for `model`. It logs the server's first reply, then yields an
    `AsyncMusicSession` bound to the connection. The connection is closed when
    the `async with` block exits.

    Args:
      model: The model name; normalized with `t.t_model`.

    Yields:
      An `AsyncMusicSession` for the open connection.

    Raises:
      NotImplementedError: If the client has no API key. Only the API-key
        (mldev) path is implemented here.
    """
    base_url = self._api_client._websocket_base_url()
    if isinstance(base_url, bytes):
      base_url = base_url.decode('utf-8')
    transformed_model = t.t_model(self._api_client, model)

    if self._api_client.api_key:
      api_key = self._api_client.api_key
      version = self._api_client._http_options.api_version
      uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic?key={api_key}'
      headers = self._api_client._http_options.headers

      # Only mldev supported
      request_dict = _common.convert_to_dict(
          live_converters._LiveMusicConnectParameters_to_mldev(
              from_object=types.LiveMusicConnectParameters(
                  model=transformed_model,
              ).model_dump(exclude_none=True)
          )
      )

      setv(request_dict, ['setup', 'model'], transformed_model)

      request = json.dumps(request_dict)
    else:
      raise NotImplementedError('Live music generation is not supported in Vertex AI.')

    # Set once control has been handed to the caller. After that point a
    # TypeError comes from the caller's code inside the `async with` body, not
    # from an incompatible websockets API. It must propagate rather than
    # trigger the legacy fallback: re-yielding there would mask the error
    # behind a RuntimeError from asynccontextmanager.
    session_yielded = False
    try:
      async with connect(uri, additional_headers=headers) as ws:
        await ws.send(request)
        logger.info(await ws.recv(decode=False))

        session_yielded = True
        yield AsyncMusicSession(api_client=self._api_client, websocket=ws)
    except TypeError:
      if session_yielded:
        raise
      # Try with the older websockets API
      async with connect(uri, extra_headers=headers) as ws:
        await ws.send(request)
        logger.info(await ws.recv())

        yield AsyncMusicSession(api_client=self._api_client, websocket=ws)
|
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
|
|
16
|
+
"""[Experimental] Text Only Local Tokenizer."""
|
|
17
|
+
|
|
18
|
+
import logging
|
|
19
|
+
from typing import Any, Iterable
|
|
20
|
+
from typing import Optional, Union
|
|
21
|
+
|
|
22
|
+
from sentencepiece import sentencepiece_model_pb2
|
|
23
|
+
|
|
24
|
+
from . import _common
|
|
25
|
+
from . import _local_tokenizer_loader as loader
|
|
26
|
+
from . import _transformers as t
|
|
27
|
+
from . import types
|
|
28
|
+
|
|
29
|
+
# Module logger; `_TextsAccumulator.add_content` warns here about unsupported
# content fields.
logger = logging.getLogger("google_genai.local_tokenizer")

# Names exported by `from ... import *`. The private helpers are included,
# presumably so tests can reach them — confirm with the test suite.
__all__ = [
    "_parse_hex_byte",
    "_token_str_to_bytes",
    "LocalTokenizer",
    "_TextsAccumulator",
]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class _TextsAccumulator:
  """Accumulates countable texts from `Content` and `Tool` objects.

  This class is responsible for traversing complex `Content` and `Tool`
  objects and extracting all the text content that should be included when
  calculating token counts.

  A key feature of this class is its ability to detect unsupported fields in
  `Content` objects. If a user provides a `Content` object with fields that
  this local tokenizer doesn't recognize (e.g., new fields added in a future
  API update), this class will log a warning.

  The detection mechanism for `Content` objects works by recursively building
  a "counted" version of the input object. This "counted" object only
  contains the data that was successfully processed and added to the text
  list for tokenization. After traversing the input, the original `Content`
  object is compared to the "counted" object. If they don't match, it
  signifies the presence of unsupported fields, and a warning is logged.
  """

  def __init__(self) -> None:
    self._texts: list[str] = []

  def get_texts(self) -> Iterable[str]:
    """Returns the accumulated texts in the order they were added.

    This is the accumulator's internal list, not a copy.
    """
    return self._texts

  def add_contents(self, contents: Iterable[types.Content]) -> None:
    """Adds the countable texts of every content in `contents`."""
    for content in contents:
      self.add_content(content)

  def add_content(self, content: types.Content) -> None:
    """Adds the countable texts of one content.

    Logs a warning when `content` carries fields this tokenizer does not
    count.

    Args:
      content: The content to process.

    Raises:
      ValueError: If any part has `file_data` or `inline_data`.
    """
    # Mirror a missing `parts` as None. Seeding with [] made a content whose
    # parts is None compare unequal below (the dump with exclude_none drops
    # None but keeps []), which logged a spurious "unsupported" warning.
    counted_content = types.Content(
        parts=None if content.parts is None else [], role=content.role
    )
    if content.parts:
      for part in content.parts:
        assert counted_content.parts is not None
        counted_part = types.Part()
        if part.file_data is not None or part.inline_data is not None:
          raise ValueError(
              "LocalTokenizers do not support non-text content types."
          )
        if part.video_metadata is not None:
          counted_part.video_metadata = part.video_metadata
        if part.function_call is not None:
          self.add_function_call(part.function_call)
          # The original object is copied over, so function calls never
          # trigger the mismatch warning below.
          counted_part.function_call = part.function_call
        if part.function_response is not None:
          self.add_function_response(part.function_response)
          counted_part.function_response = part.function_response
        if part.text is not None:
          counted_part.text = part.text
          self._texts.append(part.text)
        counted_content.parts.append(counted_part)

    if content.model_dump(exclude_none=True) != counted_content.model_dump(
        exclude_none=True
    ):
      logger.warning(
          "Content contains unsupported types for token counting. Supported"
          f" fields {counted_content}. Got {content}."
      )

  def add_function_call(self, function_call: types.FunctionCall) -> None:
    """Processes a function call and adds relevant text to the accumulator.

    Adds the name, then every key and string value found in `args`.

    Args:
        function_call: The function call to process.
    """
    if function_call.name:
      self._texts.append(function_call.name)
    if function_call.args:
      self._dict_traverse(function_call.args)

  def add_tool(self, tool: types.Tool) -> types.Tool:
    """Adds the countable texts of a tool's function declarations.

    Args:
      tool: The tool to process.

    Returns:
      A new `Tool` holding only the counted function declarations.
    """
    counted_tool = types.Tool(function_declarations=[])
    if tool.function_declarations:
      counted_tool.function_declarations = [
          self._function_declaration_traverse(function_declaration)
          for function_declaration in tool.function_declarations
      ]
    return counted_tool

  def add_tools(self, tools: Iterable[types.Tool]) -> None:
    """Adds the countable texts of every tool in `tools`."""
    for tool in tools:
      self.add_tool(tool)

  def add_function_responses(
      self, function_responses: Iterable[types.FunctionResponse]
  ) -> None:
    """Adds the countable texts of every function response."""
    for function_response in function_responses:
      self.add_function_response(function_response)

  def add_function_response(
      self, function_response: types.FunctionResponse
  ) -> None:
    """Adds a function response's name and the keys/strings in its response."""
    if function_response.name:
      self._texts.append(function_response.name)
    if function_response.response:
      self._dict_traverse(function_response.response)

  def _function_declaration_traverse(
      self, function_declaration: types.FunctionDeclaration
  ) -> types.FunctionDeclaration:
    """Adds a declaration's name, description and schema texts.

    Returns:
      A new `FunctionDeclaration` holding only the counted fields.
    """
    counted_function_declaration = types.FunctionDeclaration()
    if function_declaration.name:
      self._texts.append(function_declaration.name)
      counted_function_declaration.name = function_declaration.name
    if function_declaration.description:
      self._texts.append(function_declaration.description)
      counted_function_declaration.description = (
          function_declaration.description
      )
    if function_declaration.parameters:
      counted_parameters = self.add_schema(function_declaration.parameters)
      counted_function_declaration.parameters = counted_parameters
    if function_declaration.response:
      counted_response = self.add_schema(function_declaration.response)
      counted_function_declaration.response = counted_response
    return counted_function_declaration

  def add_schema(self, schema: types.Schema) -> types.Schema:
    """Processes a schema and adds relevant text to the accumulator.

    Texts are added for `format`, `description`, `enum`, `required`, property
    names, and string values inside `example`. `items` and `properties` are
    processed recursively. `type`, `title`, `default` and `property_ordering`
    are copied to the result without being counted.

    Args:
        schema: The schema to process.

    Returns:
        The new schema object with only countable fields.
    """
    counted_schema = types.Schema()
    if schema.type:
      counted_schema.type = schema.type
    if schema.title:
      counted_schema.title = schema.title
    if schema.default is not None:
      counted_schema.default = schema.default
    if schema.format:
      self._texts.append(schema.format)
      counted_schema.format = schema.format
    if schema.description:
      self._texts.append(schema.description)
      counted_schema.description = schema.description
    if schema.enum:
      self._texts.extend(schema.enum)
      counted_schema.enum = schema.enum
    if schema.required:
      self._texts.extend(schema.required)
      counted_schema.required = schema.required
    if schema.property_ordering:
      counted_schema.property_ordering = schema.property_ordering
    if schema.items:
      counted_schema_items = self.add_schema(schema.items)
      counted_schema.items = counted_schema_items
    if schema.properties:
      d = {}
      for key, value in schema.properties.items():
        self._texts.append(key)
        counted_value = self.add_schema(value)
        d[key] = counted_value
      counted_schema.properties = d
    if schema.example:
      counted_schema_example = self._any_traverse(schema.example)
      counted_schema.example = counted_schema_example
    return counted_schema

  def _dict_traverse(self, d: dict[str, Any]) -> dict[str, Any]:
    """Processes a dict and adds relevant text to the accumulator.

    All keys are added first, then values are traversed in order.

    Args:
        d: The dict to process.

    Returns:
        The new dict object with only countable fields.
    """
    counted_dict = {}
    self._texts.extend(d)
    for key, val in d.items():
      counted_dict[key] = self._any_traverse(val)
    return counted_dict

  def _any_traverse(self, value: Any) -> Any:
    """Processes a value and adds relevant text to the accumulator.

    Strings are added; dicts and lists are traversed recursively. Other values
    (numbers, booleans, None) are returned unchanged and not counted.

    Args:
        value: The value to process.

    Returns:
        The new value with only countable fields.
    """
    if isinstance(value, str):
      self._texts.append(value)
      return value
    elif isinstance(value, dict):
      return self._dict_traverse(value)
    elif isinstance(value, list):
      return [self._any_traverse(item) for item in value]
    else:
      return value
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _token_str_to_bytes(
    token: str, type: sentencepiece_model_pb2.ModelProto.SentencePiece.Type
) -> bytes:
  """Converts a SentencePiece piece string to the bytes it represents.

  Args:
    token: The piece string as stored in the model.
    type: The piece's type from the model proto. BYTE pieces are spelled as
      `<0xXX>` and decode to that single byte.

  Returns:
    One byte for BYTE pieces. For every other piece, the UTF-8 encoding of
    the piece with each "▁" replaced by a space.
  """
  byte_piece_type = sentencepiece_model_pb2.ModelProto.SentencePiece.Type.BYTE
  if type != byte_piece_type:
    spaced = token.replace("▁", " ")
    return spaced.encode("utf-8")
  byte_value = _parse_hex_byte(token)
  return byte_value.to_bytes(length=1, byteorder="big")
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _parse_hex_byte(token: str) -> int:
  """Parses a hex byte string of the form '<0xXX>' and returns the integer value.

  `XX` must be exactly two ASCII hex digits (either case), so the result is
  always in 0..255.

  Raises ValueError if the input is malformed or the byte value is invalid.
  """

  if len(token) != 6:
    raise ValueError(f"Invalid byte length: {token}")
  if not token.startswith("<0x") or not token.endswith(">"):
    raise ValueError(f"Invalid byte format: {token}")

  hex_part = token[3:5]
  # `int(..., 16)` on its own is too lenient: it accepts a sign ("<0x-1>" ->
  # -1, which later breaks `int.to_bytes`) and surrounding whitespace
  # ("<0x f>" -> 15). Require two ASCII hex digits explicitly. Two hex digits
  # can never exceed 0xFF, so no separate range check is needed.
  if any(ch not in "0123456789abcdefABCDEF" for ch in hex_part):
    raise ValueError(f"Invalid hex value: {token}")

  return int(hex_part, 16)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class LocalTokenizer:
  """[Experimental] Text Only Local Tokenizer.

  This class provides a local tokenizer for text only token counting.

  LIMITATIONS:
  - Only supports text based tokenization and no multimodal tokenization.
  - Forward compatibility depends on the open-source tokenizer models for future
  Gemini versions.
  - For token counting of tools and response schemas, the `LocalTokenizer` only
  supports `types.Tool` and `types.Schema` objects. Python functions or Pydantic
  models cannot be passed directly.
  """

  def __init__(self, model_name: str):
    self._tokenizer_name = loader.get_tokenizer_name(model_name)
    # The raw model proto is kept to look up each piece's type by id when
    # converting pieces to bytes in `compute_tokens`.
    self._model_proto = loader.load_model_proto(self._tokenizer_name)
    self._tokenizer = loader.get_sentencepiece(self._tokenizer_name)

  @_common.experimental_warning(
      "The SDK's local tokenizer implementation is experimental and may change"
      " in the future. It only supports text based tokenization."
  )
  def count_tokens(
      self,
      contents: Union[types.ContentListUnion, types.ContentListUnionDict],
      *,
      config: Optional[types.CountTokensConfigOrDict] = None,
  ) -> types.CountTokensResult:
    """Counts the number of tokens in a given text.

    Args:
      contents: The contents to tokenize.
      config: The configuration for counting tokens.

    Returns:
      A `CountTokensResult` containing the total number of tokens.

    Usage:

    .. code-block:: python

      from google import genai
      tokenizer = genai.LocalTokenizer(model_name='gemini-2.0-flash-001')
      result = tokenizer.count_tokens("What is your name?")
      print(result)
      # total_tokens=5
    """
    processed_contents = t.t_contents(contents)
    text_accumulator = _TextsAccumulator()
    config = types.CountTokensConfig.model_validate(config or {})
    text_accumulator.add_contents(processed_contents)
    if config.tools:
      text_accumulator.add_tools(config.tools)
    if config.generation_config and config.generation_config.response_schema:
      text_accumulator.add_schema(config.generation_config.response_schema)
    if config.system_instruction:
      text_accumulator.add_contents(t.t_contents([config.system_instruction]))
    tokens_list = self._tokenizer.encode(list(text_accumulator.get_texts()))
    return types.CountTokensResult(
        total_tokens=sum(len(tokens) for tokens in tokens_list)
    )

  @_common.experimental_warning(
      "The SDK's local tokenizer implementation is experimental and may change"
      " in the future. It only supports text based tokenization."
  )
  def compute_tokens(
      self,
      contents: Union[types.ContentListUnion, types.ContentListUnionDict],
  ) -> types.ComputeTokensResult:
    """Computes the tokens ids and string pieces in the input.

    One `TokensInfo` is produced per accumulated text, tagged with the role of
    the content that text came from.

    Args:
      contents: The contents to tokenize.

    Returns:
      A `ComputeTokensResult` containing the token information.

    Usage:

    .. code-block:: python

      from google import genai
      tokenizer = genai.LocalTokenizer(model_name='gemini-2.0-flash-001')
      result = tokenizer.compute_tokens("What is your name?")
      print(result)
      # tokens_info=[TokensInfo(token_ids=[279, 329, 1313, 2508, 13], tokens=[b' What', b' is', b' your', b' name', b'?'], role='user')]
    """
    processed_contents = t.t_contents(contents)

    # Track one role per accumulated text rather than per part. A part can
    # contribute zero texts (e.g. only `video_metadata`) or several (a
    # function call's name plus its argument keys and string values). Pairing
    # roles with parts therefore misattributed roles and let `zip` silently
    # drop trailing tokens.
    texts: list[str] = []
    roles: list[Optional[str]] = []
    for content in processed_contents:
      content_accumulator = _TextsAccumulator()
      content_accumulator.add_content(content)
      content_texts = list(content_accumulator.get_texts())
      texts.extend(content_texts)
      roles.extend([content.role] * len(content_texts))

    tokens_protos = self._tokenizer.EncodeAsImmutableProto(texts)

    token_infos = []
    for tokens_proto, role in zip(tokens_protos, roles):
      token_infos.append(
          types.TokensInfo(
              token_ids=[piece.id for piece in tokens_proto.pieces],
              tokens=[
                  _token_str_to_bytes(
                      piece.piece, self._model_proto.pieces[piece.id].type
                  )
                  for piece in tokens_proto.pieces
              ],
              role=role,
          )
      )
    return types.ComputeTokensResult(tokens_info=token_infos)
|