google-genai 1.7.0__py3-none-any.whl → 1.53.0__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in their public registries.
Files changed (42)
  1. google/genai/__init__.py +4 -2
  2. google/genai/_adapters.py +55 -0
  3. google/genai/_api_client.py +1301 -299
  4. google/genai/_api_module.py +1 -1
  5. google/genai/_automatic_function_calling_util.py +54 -33
  6. google/genai/_base_transformers.py +26 -0
  7. google/genai/_base_url.py +50 -0
  8. google/genai/_common.py +560 -59
  9. google/genai/_extra_utils.py +371 -38
  10. google/genai/_live_converters.py +1467 -0
  11. google/genai/_local_tokenizer_loader.py +214 -0
  12. google/genai/_mcp_utils.py +117 -0
  13. google/genai/_operations_converters.py +394 -0
  14. google/genai/_replay_api_client.py +204 -92
  15. google/genai/_test_api_client.py +1 -1
  16. google/genai/_tokens_converters.py +520 -0
  17. google/genai/_transformers.py +633 -233
  18. google/genai/batches.py +1733 -538
  19. google/genai/caches.py +678 -1012
  20. google/genai/chats.py +48 -38
  21. google/genai/client.py +142 -15
  22. google/genai/documents.py +532 -0
  23. google/genai/errors.py +141 -35
  24. google/genai/file_search_stores.py +1296 -0
  25. google/genai/files.py +312 -744
  26. google/genai/live.py +617 -367
  27. google/genai/live_music.py +197 -0
  28. google/genai/local_tokenizer.py +395 -0
  29. google/genai/models.py +3598 -3116
  30. google/genai/operations.py +201 -362
  31. google/genai/pagers.py +23 -7
  32. google/genai/py.typed +1 -0
  33. google/genai/tokens.py +362 -0
  34. google/genai/tunings.py +1274 -496
  35. google/genai/types.py +14535 -5454
  36. google/genai/version.py +2 -2
  37. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/METADATA +736 -234
  38. google_genai-1.53.0.dist-info/RECORD +41 -0
  39. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/WHEEL +1 -1
  40. google_genai-1.7.0.dist-info/RECORD +0 -27
  41. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info/licenses}/LICENSE +0 -0
  42. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ google/genai/live_music.py
@@ -0,0 +1,197 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""[Experimental] Live Music API client."""
+
+import contextlib
+import json
+import logging
+from typing import AsyncIterator
+
+from . import _api_module
+from . import _common
+from . import _live_converters as live_converters
+from . import _transformers as t
+from . import types
+from ._api_client import BaseApiClient
+from ._common import set_value_by_path as setv
+
+
+try:
+  from websockets.asyncio.client import ClientConnection
+  from websockets.asyncio.client import connect
+except ModuleNotFoundError:
+  from websockets.client import ClientConnection  # type: ignore
+  from websockets.client import connect  # type: ignore
+
+logger = logging.getLogger('google_genai.live_music')
+
+
+class AsyncMusicSession:
+  """[Experimental] AsyncMusicSession."""
+
+  def __init__(self, api_client: BaseApiClient, websocket: ClientConnection):
+    self._api_client = api_client
+    self._ws = websocket
+
+  async def set_weighted_prompts(
+      self, prompts: list[types.WeightedPrompt]
+  ) -> None:
+    if self._api_client.vertexai:
+      raise NotImplementedError(
+          'Live music generation is not supported in Vertex AI.'
+      )
+    else:
+      client_content_dict = {
+          'weightedPrompts': [
+              _common.convert_to_dict(prompt, convert_keys=True)
+              for prompt in prompts
+          ]
+      }
+
+      await self._ws.send(json.dumps({'clientContent': client_content_dict}))
+
+  async def set_music_generation_config(
+      self, config: types.LiveMusicGenerationConfig
+  ) -> None:
+    if self._api_client.vertexai:
+      raise NotImplementedError(
+          'Live music generation is not supported in Vertex AI.'
+      )
+    else:
+      config_dict = _common.convert_to_dict(config, convert_keys=True)
+      await self._ws.send(json.dumps({'musicGenerationConfig': config_dict}))
+
+  async def _send_control_signal(
+      self, playback_control: types.LiveMusicPlaybackControl
+  ) -> None:
+    if self._api_client.vertexai:
+      raise NotImplementedError(
+          'Live music generation is not supported in Vertex AI.'
+      )
+    else:
+      playback_control_dict = {'playbackControl': playback_control.value}
+      await self._ws.send(json.dumps(playback_control_dict))
+
+  async def play(self) -> None:
+    """Sends playback signal to start the music stream."""
+    return await self._send_control_signal(types.LiveMusicPlaybackControl.PLAY)
+
+  async def pause(self) -> None:
+    """Sends a playback signal to pause the music stream."""
+    return await self._send_control_signal(types.LiveMusicPlaybackControl.PAUSE)
+
+  async def stop(self) -> None:
+    """Sends a playback signal to stop the music stream.
+
+    Resets the music generation context while retaining the current config.
+    """
+    return await self._send_control_signal(types.LiveMusicPlaybackControl.STOP)
+
+  async def reset_context(self) -> None:
+    """Reset the context (prompts retained) without stopping the music generation."""
+    return await self._send_control_signal(
+        types.LiveMusicPlaybackControl.RESET_CONTEXT
+    )
+
+  async def receive(self) -> AsyncIterator[types.LiveMusicServerMessage]:
+    """Receive model responses from the server.
+
+    Yields:
+      The audio chunks from the server.
+    """
+    # TODO(b/365983264) Handle intermittent issues for the user.
+    while result := await self._receive():
+      yield result
+
+  async def _receive(self) -> types.LiveMusicServerMessage:
+    parameter_model = types.LiveMusicServerMessage()
+    try:
+      raw_response = await self._ws.recv(decode=False)
+    except TypeError:
+      raw_response = await self._ws.recv()  # type: ignore[assignment]
+    if raw_response:
+      try:
+        response = json.loads(raw_response)
+      except json.decoder.JSONDecodeError:
+        raise ValueError(f'Failed to parse response: {raw_response!r}')
+    else:
+      response = {}
+
+    if self._api_client.vertexai:
+      raise NotImplementedError(
+          'Live music generation is not supported in Vertex AI.'
+      )
+    else:
+      response_dict = response
+
+    return types.LiveMusicServerMessage._from_response(
+        response=response_dict, kwargs=parameter_model.model_dump()
+    )
+
+  async def close(self) -> None:
+    """Closes the bi-directional stream and terminates the session."""
+    await self._ws.close()
+
+
+class AsyncLiveMusic(_api_module.BaseModule):
+  """[Experimental] Live music module.
+
+  Live music can be accessed via `client.aio.live.music`.
+  """
+
+  @_common.experimental_warning(
+      'Realtime music generation is experimental and may change in future versions.'
+  )
+  @contextlib.asynccontextmanager
+  async def connect(self, *, model: str) -> AsyncIterator[AsyncMusicSession]:
+    """[Experimental] Connect to the live music server."""
+    base_url = self._api_client._websocket_base_url()
+    if isinstance(base_url, bytes):
+      base_url = base_url.decode('utf-8')
+    transformed_model = t.t_model(self._api_client, model)
+
+    if self._api_client.api_key:
+      api_key = self._api_client.api_key
+      version = self._api_client._http_options.api_version
+      uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic?key={api_key}'
+      headers = self._api_client._http_options.headers
+
+      # Only mldev supported
+      request_dict = _common.convert_to_dict(
+          live_converters._LiveMusicConnectParameters_to_mldev(
+              from_object=types.LiveMusicConnectParameters(
+                  model=transformed_model,
+              ).model_dump(exclude_none=True)
+          )
+      )
+
+      setv(request_dict, ['setup', 'model'], transformed_model)
+
+      request = json.dumps(request_dict)
+    else:
+      raise NotImplementedError(
+          'Live music generation is not supported in Vertex AI.'
+      )
+
+    try:
+      async with connect(uri, additional_headers=headers) as ws:
+        await ws.send(request)
+        logger.info(await ws.recv(decode=False))
+
+        yield AsyncMusicSession(api_client=self._api_client, websocket=ws)
+    except TypeError:
+      # Try with the older websockets API
+      async with connect(uri, extra_headers=headers) as ws:
+        await ws.send(request)
+        logger.info(await ws.recv())
+
+        yield AsyncMusicSession(api_client=self._api_client, websocket=ws)
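
Taken together, the methods above define the full client flow: connect, steer the generation, start playback, then stream messages. The minimal sketch below is editorial and not part of the diff; the model name and the `WeightedPrompt` field values are hypothetical placeholders, and only the calls (`connect`, `set_weighted_prompts`, `set_music_generation_config`, `play`, `receive`, `close`) come from the code shown.

import asyncio

from google import genai
from google.genai import types


async def main() -> None:
  # Gemini API key auth only; the Vertex AI path raises NotImplementedError.
  client = genai.Client()
  # 'models/music-model' is a placeholder name, not taken from this diff.
  async with client.aio.live.music.connect(model='models/music-model') as session:
    await session.set_weighted_prompts(
        [types.WeightedPrompt(text='minimal techno', weight=1.0)]
    )
    await session.set_music_generation_config(types.LiveMusicGenerationConfig())
    await session.play()
    async for message in session.receive():
      # Each message is a types.LiveMusicServerMessage carrying audio chunks.
      break
    await session.close()


asyncio.run(main())
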
--- /dev/null
+++ google/genai/local_tokenizer.py
@@ -0,0 +1,395 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""[Experimental] Text Only Local Tokenizer."""
+
+import logging
+from typing import Any, Iterable
+from typing import Optional, Union
+
+from sentencepiece import sentencepiece_model_pb2
+
+from . import _common
+from . import _local_tokenizer_loader as loader
+from . import _transformers as t
+from . import types
+
+logger = logging.getLogger("google_genai.local_tokenizer")
+
+__all__ = [
+    "_parse_hex_byte",
+    "_token_str_to_bytes",
+    "LocalTokenizer",
+    "_TextsAccumulator",
+]
+
+
+class _TextsAccumulator:
+  """Accumulates countable texts from `Content` and `Tool` objects.
+
+  This class is responsible for traversing complex `Content` and `Tool`
+  objects and extracting all the text content that should be included when
+  calculating token counts.
+
+  A key feature of this class is its ability to detect unsupported fields in
+  `Content` objects. If a user provides a `Content` object with fields that
+  this local tokenizer doesn't recognize (e.g., new fields added in a future
+  API update), this class will log a warning.
+
+  The detection mechanism for `Content` objects works by recursively building
+  a "counted" version of the input object. This "counted" object only
+  contains the data that was successfully processed and added to the text
+  list for tokenization. After traversing the input, the original `Content`
+  object is compared to the "counted" object. If they don't match, it
+  signifies the presence of unsupported fields, and a warning is logged.
+  """
+
+  def __init__(self) -> None:
+    self._texts: list[str] = []
+
+  def get_texts(self) -> Iterable[str]:
+    return self._texts
+
+  def add_contents(self, contents: Iterable[types.Content]) -> None:
+    for content in contents:
+      self.add_content(content)
+
+  def add_content(self, content: types.Content) -> None:
+    counted_content = types.Content(parts=[], role=content.role)
+    if content.parts:
+      for part in content.parts:
+        assert counted_content.parts is not None
+        counted_part = types.Part()
+        if part.file_data is not None or part.inline_data is not None:
+          raise ValueError(
+              "LocalTokenizers do not support non-text content types."
+          )
+        if part.video_metadata is not None:
+          counted_part.video_metadata = part.video_metadata
+        if part.function_call is not None:
+          self.add_function_call(part.function_call)
+          counted_part.function_call = part.function_call
+        if part.function_response is not None:
+          self.add_function_response(part.function_response)
+          counted_part.function_response = part.function_response
+        if part.text is not None:
+          counted_part.text = part.text
+          self._texts.append(part.text)
+        counted_content.parts.append(counted_part)
+
+    if content.model_dump(exclude_none=True) != counted_content.model_dump(
+        exclude_none=True
+    ):
+      logger.warning(
+          "Content contains unsupported types for token counting. Supported"
+          f" fields {counted_content}. Got {content}."
+      )
+
+  def add_function_call(self, function_call: types.FunctionCall) -> None:
+    """Processes a function call and adds relevant text to the accumulator.
+
+    Args:
+      function_call: The function call to process.
+    """
+    if function_call.name:
+      self._texts.append(function_call.name)
+    counted_function_call = types.FunctionCall(name=function_call.name)
+    if function_call.args:
+      counted_args = self._dict_traverse(function_call.args)
+      counted_function_call.args = counted_args
+
+  def add_tool(self, tool: types.Tool) -> types.Tool:
+    counted_tool = types.Tool(function_declarations=[])
+    if tool.function_declarations:
+      for function_declaration in tool.function_declarations:
+        counted_function_declaration = self._function_declaration_traverse(
+            function_declaration
+        )
+        if counted_tool.function_declarations is None:
+          counted_tool.function_declarations = []
+        counted_tool.function_declarations.append(counted_function_declaration)
+
+    return counted_tool
+
+  def add_tools(self, tools: Iterable[types.Tool]) -> None:
+    for tool in tools:
+      self.add_tool(tool)
+
+  def add_function_responses(
+      self, function_responses: Iterable[types.FunctionResponse]
+  ) -> None:
+    for function_response in function_responses:
+      self.add_function_response(function_response)
+
+  def add_function_response(
+      self, function_response: types.FunctionResponse
+  ) -> None:
+    counted_function_response = types.FunctionResponse()
+    if function_response.name:
+      self._texts.append(function_response.name)
+      counted_function_response.name = function_response.name
+    if function_response.response:
+      counted_response = self._dict_traverse(function_response.response)
+      counted_function_response.response = counted_response
+
+  def _function_declaration_traverse(
+      self, function_declaration: types.FunctionDeclaration
+  ) -> types.FunctionDeclaration:
+    counted_function_declaration = types.FunctionDeclaration()
+    if function_declaration.name:
+      self._texts.append(function_declaration.name)
+      counted_function_declaration.name = function_declaration.name
+    if function_declaration.description:
+      self._texts.append(function_declaration.description)
+      counted_function_declaration.description = (
+          function_declaration.description
+      )
+    if function_declaration.parameters:
+      counted_parameters = self.add_schema(function_declaration.parameters)
+      counted_function_declaration.parameters = counted_parameters
+    if function_declaration.response:
+      counted_response = self.add_schema(function_declaration.response)
+      counted_function_declaration.response = counted_response
+    return counted_function_declaration
+
+  def add_schema(self, schema: types.Schema) -> types.Schema:
+    """Processes a schema and adds relevant text to the accumulator.
+
+    Args:
+      schema: The schema to process.
+
+    Returns:
+      The new schema object with only countable fields.
+    """
+    counted_schema = types.Schema()
+    if schema.type:
+      counted_schema.type = schema.type
+    if schema.title:
+      counted_schema.title = schema.title
+    if schema.default is not None:
+      counted_schema.default = schema.default
+    if schema.format:
+      self._texts.append(schema.format)
+      counted_schema.format = schema.format
+    if schema.description:
+      self._texts.append(schema.description)
+      counted_schema.description = schema.description
+    if schema.enum:
+      self._texts.extend(schema.enum)
+      counted_schema.enum = schema.enum
+    if schema.required:
+      self._texts.extend(schema.required)
+      counted_schema.required = schema.required
+    if schema.property_ordering:
+      counted_schema.property_ordering = schema.property_ordering
+    if schema.items:
+      counted_schema_items = self.add_schema(schema.items)
+      counted_schema.items = counted_schema_items
+    if schema.properties:
+      d = {}
+      for key, value in schema.properties.items():
+        self._texts.append(key)
+        counted_value = self.add_schema(value)
+        d[key] = counted_value
+      counted_schema.properties = d
+    if schema.example:
+      counted_schema_example = self._any_traverse(schema.example)
+      counted_schema.example = counted_schema_example
+    return counted_schema
+
+  def _dict_traverse(self, d: dict[str, Any]) -> dict[str, Any]:
+    """Processes a dict and adds relevant text to the accumulator.
+
+    Args:
+      d: The dict to process.
+
+    Returns:
+      The new dict object with only countable fields.
+    """
+    counted_dict = {}
+    self._texts.extend(list(d.keys()))
+    for key, val in d.items():
+      counted_dict[key] = self._any_traverse(val)
+    return counted_dict
+
+  def _any_traverse(self, value: Any) -> Any:
+    """Processes a value and adds relevant text to the accumulator.
+
+    Args:
+      value: The value to process.
+
+    Returns:
+      The new value with only countable fields.
+    """
+    if isinstance(value, str):
+      self._texts.append(value)
+      return value
+    elif isinstance(value, dict):
+      return self._dict_traverse(value)
+    elif isinstance(value, list):
+      return [self._any_traverse(item) for item in value]
+    else:
+      return value
+
+
+def _token_str_to_bytes(
+    token: str, type: sentencepiece_model_pb2.ModelProto.SentencePiece.Type
+) -> bytes:
+  if type == sentencepiece_model_pb2.ModelProto.SentencePiece.Type.BYTE:
+    return _parse_hex_byte(token).to_bytes(length=1, byteorder="big")
+  else:
+    return token.replace("▁", " ").encode("utf-8")
+
+
+def _parse_hex_byte(token: str) -> int:
+  """Parses a hex byte string of the form '<0xXX>' and returns the integer value.
+
+  Raises ValueError if the input is malformed or the byte value is invalid.
+  """
+
+  if len(token) != 6:
+    raise ValueError(f"Invalid byte length: {token}")
+  if not token.startswith("<0x") or not token.endswith(">"):
+    raise ValueError(f"Invalid byte format: {token}")
+
+  try:
+    val = int(token[3:5], 16)  # Parse the hex part directly
+  except ValueError:
+    raise ValueError(f"Invalid hex value: {token}")
+
+  if val >= 256:
+    raise ValueError(f"Byte value out of range: {token}")
+
+  return val
+
+
+class LocalTokenizer:
+  """[Experimental] Text Only Local Tokenizer.
+
+  This class provides a local tokenizer for text only token counting.
+
+  LIMITATIONS:
+  - Only supports text based tokenization and no multimodal tokenization.
+  - Forward compatibility depends on the open-source tokenizer models for future
+    Gemini versions.
+  - For token counting of tools and response schemas, the `LocalTokenizer` only
+    supports `types.Tool` and `types.Schema` objects. Python functions or Pydantic
+    models cannot be passed directly.
+  """
+
+  def __init__(self, model_name: str):
+    self._tokenizer_name = loader.get_tokenizer_name(model_name)
+    self._model_proto = loader.load_model_proto(self._tokenizer_name)
+    self._tokenizer = loader.get_sentencepiece(self._tokenizer_name)
+
+  @_common.experimental_warning(
+      "The SDK's local tokenizer implementation is experimental and may change"
+      " in the future. It only supports text based tokenization."
+  )
+  def count_tokens(
+      self,
+      contents: Union[types.ContentListUnion, types.ContentListUnionDict],
+      *,
+      config: Optional[types.CountTokensConfigOrDict] = None,
+  ) -> types.CountTokensResult:
+    """Counts the number of tokens in a given text.
+
+    Args:
+      contents: The contents to tokenize.
+      config: The configuration for counting tokens.
+
+    Returns:
+      A `CountTokensResult` containing the total number of tokens.
+
+    Usage:
+
+    .. code-block:: python
+
+      from google import genai
+      tokenizer = genai.LocalTokenizer(model_name='gemini-2.0-flash-001')
+      result = tokenizer.count_tokens("What is your name?")
+      print(result)
+      # total_tokens=5
+    """
+    processed_contents = t.t_contents(contents)
+    text_accumulator = _TextsAccumulator()
+    config = types.CountTokensConfig.model_validate(config or {})
+    text_accumulator.add_contents(processed_contents)
+    if config.tools:
+      text_accumulator.add_tools(config.tools)
+    if config.generation_config and config.generation_config.response_schema:
+      text_accumulator.add_schema(config.generation_config.response_schema)
+    if config.system_instruction:
+      text_accumulator.add_contents(t.t_contents([config.system_instruction]))
+    tokens_list = self._tokenizer.encode(list(text_accumulator.get_texts()))
+    return types.CountTokensResult(
+        total_tokens=sum(len(tokens) for tokens in tokens_list)
+    )
+
+  @_common.experimental_warning(
+      "The SDK's local tokenizer implementation is experimental and may change"
+      " in the future. It only supports text based tokenization."
+  )
+  def compute_tokens(
+      self,
+      contents: Union[types.ContentListUnion, types.ContentListUnionDict],
+  ) -> types.ComputeTokensResult:
+    """Computes the tokens ids and string pieces in the input.
+
+    Args:
+      contents: The contents to tokenize.
+
+    Returns:
+      A `ComputeTokensResult` containing the token information.
+
+    Usage:
+
+    .. code-block:: python
+
+      from google import genai
+      tokenizer = genai.LocalTokenizer(model_name='gemini-2.0-flash-001')
+      result = tokenizer.compute_tokens("What is your name?")
+      print(result)
+      # tokens_info=[TokensInfo(token_ids=[279, 329, 1313, 2508, 13], tokens=[b' What', b' is', b' your', b' name', b'?'], role='user')]
+    """
+    processed_contents = t.t_contents(contents)
+    text_accumulator = _TextsAccumulator()
+    for content in processed_contents:
+      text_accumulator.add_content(content)
+    tokens_protos = self._tokenizer.EncodeAsImmutableProto(
+        text_accumulator.get_texts()
+    )
+
+    roles = []
+    for content in processed_contents:
+      if content.parts:
+        for _ in content.parts:
+          roles.append(content.role)
+
+    token_infos = []
+    for tokens_proto, role in zip(tokens_protos, roles):
+      token_infos.append(
+          types.TokensInfo(
+              token_ids=[piece.id for piece in tokens_proto.pieces],
+              tokens=[
+                  _token_str_to_bytes(
+                      piece.piece, self._model_proto.pieces[piece.id].type
+                  )
+                  for piece in tokens_proto.pieces
+              ],
+              role=role,
+          )
+      )
+    return types.ComputeTokensResult(tokens_info=token_infos)
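
As a hedged illustration of the LIMITATIONS note in the class docstring above (tools and schemas must be passed as explicit `types.Tool` and `types.Schema` objects, not as Python callables or Pydantic models), here is an editorial sketch of `count_tokens` with a tool; the `get_weather` declaration and its fields are invented for the example and are not part of this diff.

from google import genai
from google.genai import types

tokenizer = genai.LocalTokenizer(model_name='gemini-2.0-flash-001')

# Hypothetical declaration, used only to illustrate tool token counting.
get_weather = types.FunctionDeclaration(
    name='get_weather',
    description='Returns the weather for a city.',
    parameters=types.Schema(
        type=types.Type.OBJECT,
        properties={'city': types.Schema(type=types.Type.STRING)},
        required=['city'],
    ),
)

result = tokenizer.count_tokens(
    'What is the weather in Paris?',
    config=types.CountTokensConfig(
        tools=[types.Tool(function_declarations=[get_weather])],
    ),
)
print(result.total_tokens)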