google-genai 1.33.0__py3-none-any.whl → 1.53.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/live.py CHANGED
@@ -21,7 +21,7 @@ import contextlib
21
21
  import json
22
22
  import logging
23
23
  import typing
24
- from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, get_args
24
+ from typing import Any, AsyncIterator, Optional, Sequence, Union, get_args
25
25
  import warnings
26
26
 
27
27
  import google.auth
@@ -40,7 +40,6 @@ from ._common import get_value_by_path as getv
40
40
  from ._common import set_value_by_path as setv
41
41
  from .live_music import AsyncLiveMusic
42
42
  from .models import _Content_to_mldev
43
- from .models import _Content_to_vertex
44
43
 
45
44
 
46
45
  try:
@@ -51,6 +50,11 @@ except ModuleNotFoundError:
51
50
  from websockets.client import ClientConnection # type: ignore
52
51
  from websockets.client import connect as ws_connect # type: ignore
53
52
 
53
+ try:
54
+ from google.auth.transport import requests
55
+ except ImportError:
56
+ requests = None # type: ignore[assignment]
57
+
54
58
  if typing.TYPE_CHECKING:
55
59
  from mcp import ClientSession as McpClientSession
56
60
  from mcp.types import Tool as McpTool
@@ -82,9 +86,15 @@ _FUNCTION_RESPONSE_REQUIRES_ID = (
82
86
  class AsyncSession:
83
87
  """[Preview] AsyncSession."""
84
88
 
85
- def __init__(self, api_client: BaseApiClient, websocket: ClientConnection):
89
+ def __init__(
90
+ self,
91
+ api_client: BaseApiClient,
92
+ websocket: ClientConnection,
93
+ session_id: Optional[str] = None,
94
+ ):
86
95
  self._api_client = api_client
87
96
  self._ws = websocket
97
+ self.session_id = session_id
88
98
 
89
99
  async def send(
90
100
  self,
@@ -217,8 +227,8 @@ class AsyncSession:
217
227
  )
218
228
 
219
229
  if self._api_client.vertexai:
220
- client_content_dict = live_converters._LiveClientContent_to_vertex(
221
- from_object=client_content
230
+ client_content_dict = _common.convert_to_dict(
231
+ client_content, convert_keys=True
222
232
  )
223
233
  else:
224
234
  client_content_dict = live_converters._LiveClientContent_to_mldev(
@@ -404,12 +414,12 @@ class AsyncSession:
404
414
  """
405
415
  tool_response = t.t_tool_response(function_responses)
406
416
  if self._api_client.vertexai:
407
- tool_response_dict = live_converters._LiveClientToolResponse_to_vertex(
408
- from_object=tool_response
417
+ tool_response_dict = _common.convert_to_dict(
418
+ tool_response, convert_keys=True
409
419
  )
410
420
  else:
411
- tool_response_dict = live_converters._LiveClientToolResponse_to_mldev(
412
- from_object=tool_response
421
+ tool_response_dict = _common.convert_to_dict(
422
+ tool_response, convert_keys=True
413
423
  )
414
424
  for response in tool_response_dict.get('functionResponses', []):
415
425
  if response.get('id') is None:
@@ -535,7 +545,7 @@ class AsyncSession:
535
545
  if self._api_client.vertexai:
536
546
  response_dict = live_converters._LiveServerMessage_from_vertex(response)
537
547
  else:
538
- response_dict = live_converters._LiveServerMessage_from_mldev(response)
548
+ response_dict = response
539
549
 
540
550
  return types.LiveServerMessage._from_response(
541
551
  response=response_dict, kwargs=parameter_model.model_dump()
@@ -649,7 +659,7 @@ class AsyncSession:
649
659
  content_input_parts.append(item)
650
660
  if self._api_client.vertexai:
651
661
  contents = [
652
- _Content_to_vertex(item, to_object)
662
+ _common.convert_to_dict(item, convert_keys=True)
653
663
  for item in t.t_contents(content_input_parts)
654
664
  ]
655
665
  else:
@@ -975,7 +985,8 @@ class AsyncLive(_api_module.BaseModule):
975
985
  api_key = self._api_client.api_key
976
986
  version = self._api_client._http_options.api_version
977
987
  uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
978
- headers = self._api_client._http_options.headers or {}
988
+ original_headers = self._api_client._http_options.headers
989
+ headers = original_headers.copy() if original_headers is not None else {}
979
990
 
980
991
  request_dict = _common.convert_to_dict(
981
992
  live_converters._LiveConnectParameters_to_vertex(
@@ -992,27 +1003,51 @@ class AsyncLive(_api_module.BaseModule):
992
1003
 
993
1004
  request = json.dumps(request_dict)
994
1005
  else:
995
- if not self._api_client._credentials:
996
- # Get bearer token through Application Default Credentials.
997
- creds, _ = google.auth.default( # type: ignore
998
- scopes=['https://www.googleapis.com/auth/cloud-platform']
1006
+ version = self._api_client._http_options.api_version
1007
+ has_sufficient_auth = (
1008
+ self._api_client.project and self._api_client.location
1009
+ )
1010
+ if self._api_client.custom_base_url and not has_sufficient_auth:
1011
+ # API gateway proxy can use the auth in custom headers, not url.
1012
+ # Enable custom url if auth is not sufficient.
1013
+ uri = self._api_client.custom_base_url
1014
+ # Keep the model as is.
1015
+ transformed_model = model
1016
+ # Do not get credentials for custom url.
1017
+ original_headers = self._api_client._http_options.headers
1018
+ headers = (
1019
+ original_headers.copy() if original_headers is not None else {}
999
1020
  )
1021
+
1000
1022
  else:
1001
- creds = self._api_client._credentials
1002
- # creds.valid is False, and creds.token is None
1003
- # Need to refresh credentials to populate those
1004
- if not (creds.token and creds.valid):
1005
- auth_req = google.auth.transport.requests.Request() # type: ignore
1006
- creds.refresh(auth_req)
1007
- bearer_token = creds.token
1008
- original_headers = self._api_client._http_options.headers
1009
- headers = original_headers.copy() if original_headers is not None else {}
1010
- headers['Authorization'] = f'Bearer {bearer_token}'
1011
- version = self._api_client._http_options.api_version
1012
- uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
1023
+ uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
1024
+
1025
+ if not self._api_client._credentials:
1026
+ # Get bearer token through Application Default Credentials.
1027
+ creds, _ = google.auth.default( # type: ignore
1028
+ scopes=['https://www.googleapis.com/auth/cloud-platform']
1029
+ )
1030
+ else:
1031
+ creds = self._api_client._credentials
1032
+ # creds.valid is False, and creds.token is None
1033
+ # Need to refresh credentials to populate those
1034
+ if not (creds.token and creds.valid):
1035
+ if requests is None:
1036
+ raise ValueError('The requests module is required to refresh google-auth credentials. Please install with `pip install google-auth[requests]`')
1037
+ auth_req = requests.Request() # type: ignore
1038
+ creds.refresh(auth_req)
1039
+ bearer_token = creds.token
1040
+
1041
+ original_headers = self._api_client._http_options.headers
1042
+ headers = (
1043
+ original_headers.copy() if original_headers is not None else {}
1044
+ )
1045
+ if not headers.get('Authorization'):
1046
+ headers['Authorization'] = f'Bearer {bearer_token}'
1047
+
1013
1048
  location = self._api_client.location
1014
1049
  project = self._api_client.project
1015
- if transformed_model.startswith('publishers/'):
1050
+ if transformed_model.startswith('publishers/') and project and location:
1016
1051
  transformed_model = (
1017
1052
  f'projects/{project}/locations/{location}/' + transformed_model
1018
1053
  )
@@ -1054,11 +1089,34 @@ class AsyncLive(_api_module.BaseModule):
1054
1089
  await ws.send(request)
1055
1090
  try:
1056
1091
  # websockets 14.0+
1057
- logger.info(await ws.recv(decode=False))
1092
+ raw_response = await ws.recv(decode=False)
1058
1093
  except TypeError:
1059
- logger.info(await ws.recv())
1094
+ raw_response = await ws.recv() # type: ignore[assignment]
1095
+ if raw_response:
1096
+ try:
1097
+ response = json.loads(raw_response)
1098
+ except json.decoder.JSONDecodeError:
1099
+ raise ValueError(f'Failed to parse response: {raw_response!r}')
1100
+ else:
1101
+ response = {}
1060
1102
 
1061
- yield AsyncSession(api_client=self._api_client, websocket=ws)
1103
+ if self._api_client.vertexai:
1104
+ response_dict = live_converters._LiveServerMessage_from_vertex(response)
1105
+ else:
1106
+ response_dict = response
1107
+
1108
+ setup_response = types.LiveServerMessage._from_response(
1109
+ response=response_dict, kwargs=parameter_model.model_dump()
1110
+ )
1111
+ if setup_response.setup_complete:
1112
+ session_id = setup_response.setup_complete.session_id
1113
+ else:
1114
+ session_id = None
1115
+ yield AsyncSession(
1116
+ api_client=self._api_client,
1117
+ websocket=ws,
1118
+ session_id=session_id,
1119
+ )
1062
1120
 
1063
1121
 
1064
1122
  async def _t_live_connect_config(
google/genai/live_music.py CHANGED
@@ -22,13 +22,11 @@ from typing import AsyncIterator
22
22
 
23
23
  from . import _api_module
24
24
  from . import _common
25
+ from . import _live_converters as live_converters
25
26
  from . import _transformers as t
26
27
  from . import types
27
28
  from ._api_client import BaseApiClient
28
29
  from ._common import set_value_by_path as setv
29
- from . import _live_converters as live_converters
30
- from .models import _Content_to_mldev
31
- from .models import _Content_to_vertex
32
30
 
33
31
 
34
32
  try:
@@ -44,46 +42,47 @@ logger = logging.getLogger('google_genai.live_music')
44
42
  class AsyncMusicSession:
45
43
  """[Experimental] AsyncMusicSession."""
46
44
 
47
- def __init__(
48
- self, api_client: BaseApiClient, websocket: ClientConnection
49
- ):
45
+ def __init__(self, api_client: BaseApiClient, websocket: ClientConnection):
50
46
  self._api_client = api_client
51
47
  self._ws = websocket
52
48
 
53
49
  async def set_weighted_prompts(
54
- self,
55
- prompts: list[types.WeightedPrompt]
50
+ self, prompts: list[types.WeightedPrompt]
56
51
  ) -> None:
57
52
  if self._api_client.vertexai:
58
- raise NotImplementedError('Live music generation is not supported in Vertex AI.')
59
- else:
60
- client_content_dict = live_converters._LiveMusicClientContent_to_mldev(
61
- from_object={'weighted_prompts': prompts}
53
+ raise NotImplementedError(
54
+ 'Live music generation is not supported in Vertex AI.'
62
55
  )
56
+ else:
57
+ client_content_dict = {
58
+ 'weightedPrompts': [
59
+ _common.convert_to_dict(prompt, convert_keys=True)
60
+ for prompt in prompts
61
+ ]
62
+ }
63
+
63
64
  await self._ws.send(json.dumps({'clientContent': client_content_dict}))
64
65
 
65
66
  async def set_music_generation_config(
66
- self,
67
- config: types.LiveMusicGenerationConfig
67
+ self, config: types.LiveMusicGenerationConfig
68
68
  ) -> None:
69
69
  if self._api_client.vertexai:
70
- raise NotImplementedError('Live music generation is not supported in Vertex AI.')
71
- else:
72
- config_dict = live_converters._LiveMusicGenerationConfig_to_mldev(
73
- from_object=config
70
+ raise NotImplementedError(
71
+ 'Live music generation is not supported in Vertex AI.'
74
72
  )
73
+ else:
74
+ config_dict = _common.convert_to_dict(config, convert_keys=True)
75
75
  await self._ws.send(json.dumps({'musicGenerationConfig': config_dict}))
76
76
 
77
77
  async def _send_control_signal(
78
- self,
79
- playback_control: types.LiveMusicPlaybackControl
78
+ self, playback_control: types.LiveMusicPlaybackControl
80
79
  ) -> None:
81
80
  if self._api_client.vertexai:
82
- raise NotImplementedError('Live music generation is not supported in Vertex AI.')
83
- else:
84
- playback_control_dict = live_converters._LiveMusicClientMessage_to_mldev(
85
- from_object={'playback_control': playback_control}
81
+ raise NotImplementedError(
82
+ 'Live music generation is not supported in Vertex AI.'
86
83
  )
84
+ else:
85
+ playback_control_dict = {'playbackControl': playback_control.value}
87
86
  await self._ws.send(json.dumps(playback_control_dict))
88
87
 
89
88
  async def play(self) -> None:
@@ -134,9 +133,7 @@ class AsyncMusicSession:
134
133
  if self._api_client.vertexai:
135
134
  raise NotImplementedError('Live music generation is not supported in Vertex AI.')
136
135
  else:
137
- response_dict = live_converters._LiveMusicServerMessage_from_mldev(
138
- response
139
- )
136
+ response_dict = response
140
137
 
141
138
  return types.LiveMusicServerMessage._from_response(
142
139
  response=response_dict, kwargs=parameter_model.model_dump()
google/genai/local_tokenizer.py CHANGED
@@ -25,11 +25,16 @@ from . import _common
25
25
  from . import _local_tokenizer_loader as loader
26
26
  from . import _transformers as t
27
27
  from . import types
28
- from . import types
29
- from ._transformers import t_contents
30
28
 
31
29
  logger = logging.getLogger("google_genai.local_tokenizer")
32
30
 
31
+ __all__ = [
32
+ "_parse_hex_byte",
33
+ "_token_str_to_bytes",
34
+ "LocalTokenizer",
35
+ "_TextsAccumulator",
36
+ ]
37
+
33
38
 
34
39
  class _TextsAccumulator:
35
40
  """Accumulates countable texts from `Content` and `Tool` objects.
@@ -303,9 +308,20 @@ class LocalTokenizer:
303
308
 
304
309
  Args:
305
310
  contents: The contents to tokenize.
311
+ config: The configuration for counting tokens.
306
312
 
307
313
  Returns:
308
314
  A `CountTokensResult` containing the total number of tokens.
315
+
316
+ Usage:
317
+
318
+ .. code-block:: python
319
+
320
+ from google import genai
321
+ tokenizer = genai.LocalTokenizer(model_name='gemini-2.0-flash-001')
322
+ result = tokenizer.count_tokens("What is your name?")
323
+ print(result)
324
+ # total_tokens=5
309
325
  """
310
326
  processed_contents = t.t_contents(contents)
311
327
  text_accumulator = _TextsAccumulator()
@@ -330,7 +346,24 @@ class LocalTokenizer:
330
346
  self,
331
347
  contents: Union[types.ContentListUnion, types.ContentListUnionDict],
332
348
  ) -> types.ComputeTokensResult:
333
- """Computes the tokens ids and string pieces in the input."""
349
+ """Computes the tokens ids and string pieces in the input.
350
+
351
+ Args:
352
+ contents: The contents to tokenize.
353
+
354
+ Returns:
355
+ A `ComputeTokensResult` containing the token information.
356
+
357
+ Usage:
358
+
359
+ .. code-block:: python
360
+
361
+ from google import genai
362
+ tokenizer = genai.LocalTokenizer(model_name='gemini-2.0-flash-001')
363
+ result = tokenizer.compute_tokens("What is your name?")
364
+ print(result)
365
+ # tokens_info=[TokensInfo(token_ids=[279, 329, 1313, 2508, 13], tokens=[b' What', b' is', b' your', b' name', b'?'], role='user')]
366
+ """
334
367
  processed_contents = t.t_contents(contents)
335
368
  text_accumulator = _TextsAccumulator()
336
369
  for content in processed_contents: