google-genai 1.14.0__py3-none-any.whl → 1.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/client.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Google LLC
1
+ # Copyright 2025 Google LLC
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -29,6 +29,7 @@ from .files import AsyncFiles, Files
29
29
  from .live import AsyncLive
30
30
  from .models import AsyncModels, Models
31
31
  from .operations import AsyncOperations, Operations
32
+ from .tokens import AsyncTokens, Tokens
32
33
  from .tunings import AsyncTunings, Tunings
33
34
  from .types import HttpOptions, HttpOptionsDict
34
35
 
@@ -45,6 +46,7 @@ class AsyncClient:
45
46
  self._batches = AsyncBatches(self._api_client)
46
47
  self._files = AsyncFiles(self._api_client)
47
48
  self._live = AsyncLive(self._api_client)
49
+ self._tokens = AsyncTokens(self._api_client)
48
50
  self._operations = AsyncOperations(self._api_client)
49
51
 
50
52
  @property
@@ -75,6 +77,10 @@ class AsyncClient:
75
77
  def live(self) -> AsyncLive:
76
78
  return self._live
77
79
 
80
+ @property
81
+ def auth_tokens(self) -> AsyncTokens:
82
+ return self._tokens
83
+
78
84
  @property
79
85
  def operations(self) -> AsyncOperations:
80
86
  return self._operations
@@ -226,6 +232,7 @@ class Client:
226
232
  self._caches = Caches(self._api_client)
227
233
  self._batches = Batches(self._api_client)
228
234
  self._files = Files(self._api_client)
235
+ self._tokens = Tokens(self._api_client)
229
236
  self._operations = Operations(self._api_client)
230
237
 
231
238
  @staticmethod
@@ -292,6 +299,10 @@ class Client:
292
299
  def files(self) -> Files:
293
300
  return self._files
294
301
 
302
+ @property
303
+ def auth_tokens(self) -> Tokens:
304
+ return self._tokens
305
+
295
306
  @property
296
307
  def operations(self) -> Operations:
297
308
  return self._operations
google/genai/errors.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Google LLC
1
+ # Copyright 2025 Google LLC
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
google/genai/live.py CHANGED
@@ -20,7 +20,8 @@ import base64
20
20
  import contextlib
21
21
  import json
22
22
  import logging
23
- from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, cast, get_args
23
+ import typing
24
+ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, get_args
24
25
  import warnings
25
26
 
26
27
  import google.auth
@@ -29,24 +30,46 @@ from websockets import ConnectionClosed
29
30
 
30
31
  from . import _api_module
31
32
  from . import _common
33
+ from . import _live_converters as live_converters
34
+ from . import _mcp_utils
32
35
  from . import _transformers as t
33
36
  from . import client
34
37
  from . import types
35
38
  from ._api_client import BaseApiClient
36
39
  from ._common import get_value_by_path as getv
37
40
  from ._common import set_value_by_path as setv
38
- from . import _live_converters as live_converters
41
+ from .live_music import AsyncLiveMusic
39
42
  from .models import _Content_to_mldev
40
43
  from .models import _Content_to_vertex
41
44
 
42
45
 
43
46
  try:
44
47
  from websockets.asyncio.client import ClientConnection
45
- from websockets.asyncio.client import connect
48
+ from websockets.asyncio.client import connect as ws_connect
46
49
  except ModuleNotFoundError:
47
50
  # This try/except is for TAP, mypy complains about it which is why we have the type: ignore
48
51
  from websockets.client import ClientConnection # type: ignore
49
- from websockets.client import connect # type: ignore
52
+ from websockets.client import connect as ws_connect # type: ignore
53
+
54
+ if typing.TYPE_CHECKING:
55
+ from mcp import ClientSession as McpClientSession
56
+ from mcp.types import Tool as McpTool
57
+ from ._adapters import McpToGenAiToolAdapter
58
+ from ._mcp_utils import mcp_to_gemini_tool
59
+ else:
60
+ McpClientSession: typing.Type = Any
61
+ McpTool: typing.Type = Any
62
+ McpToGenAiToolAdapter: typing.Type = Any
63
+ try:
64
+ from mcp import ClientSession as McpClientSession
65
+ from mcp.types import Tool as McpTool
66
+ from ._adapters import McpToGenAiToolAdapter
67
+ from ._mcp_utils import mcp_to_gemini_tool
68
+ except ImportError:
69
+ McpClientSession = None
70
+ McpTool = None
71
+ McpToGenAiToolAdapter = None
72
+ mcp_to_gemini_tool = None
50
73
 
51
74
  logger = logging.getLogger('google_genai.live')
52
75
 
@@ -56,12 +79,13 @@ _FUNCTION_RESPONSE_REQUIRES_ID = (
56
79
  )
57
80
 
58
81
 
82
+ _DUMMY_KEY = 'dummy_key'
83
+
84
+
59
85
  class AsyncSession:
60
86
  """[Preview] AsyncSession."""
61
87
 
62
- def __init__(
63
- self, api_client: BaseApiClient, websocket: ClientConnection
64
- ):
88
+ def __init__(self, api_client: BaseApiClient, websocket: ClientConnection):
65
89
  self._api_client = api_client
66
90
  self._ws = websocket
67
91
 
@@ -123,7 +147,7 @@ class AsyncSession:
123
147
  Union[
124
148
  types.Content,
125
149
  types.ContentDict,
126
- list[Union[types.Content, types.ContentDict]]
150
+ list[Union[types.Content, types.ContentDict]],
127
151
  ]
128
152
  ] = None,
129
153
  turn_complete: bool = True,
@@ -264,7 +288,7 @@ class AsyncSession:
264
288
  print(f'{msg.text}')
265
289
  ```
266
290
  """
267
- kwargs:dict[str, Any] = {}
291
+ kwargs: dict[str, Any] = {}
268
292
  if media is not None:
269
293
  kwargs['media'] = media
270
294
  if audio is not None:
@@ -506,9 +530,13 @@ class AsyncSession:
506
530
  response = {}
507
531
 
508
532
  if self._api_client.vertexai:
509
- response_dict = live_converters._LiveServerMessage_from_vertex(self._api_client, response)
533
+ response_dict = live_converters._LiveServerMessage_from_vertex(
534
+ self._api_client, response
535
+ )
510
536
  else:
511
- response_dict = live_converters._LiveServerMessage_from_mldev(self._api_client, response)
537
+ response_dict = live_converters._LiveServerMessage_from_mldev(
538
+ self._api_client, response
539
+ )
512
540
 
513
541
  return types.LiveServerMessage._from_response(
514
542
  response=response_dict, kwargs=parameter_model.model_dump()
@@ -522,7 +550,7 @@ class AsyncSession:
522
550
  ) -> None:
523
551
  async for data in data_stream:
524
552
  model_input = types.LiveClientRealtimeInput(
525
- media_chunks=[types.Blob(data=data, mime_type=mime_type)]
553
+ media_chunks=[types.Blob(data=data, mime_type=mime_type)]
526
554
  )
527
555
  await self.send(input=model_input)
528
556
  # Give a chance for the receive loop to process responses.
@@ -560,9 +588,8 @@ class AsyncSession:
560
588
  raise ValueError(
561
589
  f'Unsupported input type "{type(input)}" or input content "{input}"'
562
590
  )
563
- if (
564
- isinstance(blob_input, types.Blob)
565
- and isinstance(blob_input.data, bytes)
591
+ if isinstance(blob_input, types.Blob) and isinstance(
592
+ blob_input.data, bytes
566
593
  ):
567
594
  formatted_input = [
568
595
  blob_input.model_dump(mode='json', exclude_none=True)
@@ -841,6 +868,13 @@ class AsyncSession:
841
868
  class AsyncLive(_api_module.BaseModule):
842
869
  """[Preview] AsyncLive."""
843
870
 
871
+ def __init__(self, api_client: BaseApiClient):
872
+ super().__init__(api_client)
873
+ self._music = AsyncLiveMusic(api_client)
874
+
875
+ @property
876
+ def music(self) -> AsyncLiveMusic:
877
+ return self._music
844
878
 
845
879
  @contextlib.asynccontextmanager
846
880
  async def connect(
@@ -860,30 +894,70 @@ class AsyncLive(_api_module.BaseModule):
860
894
  client = genai.Client(api_key=API_KEY)
861
895
  config = {}
862
896
  async with client.aio.live.connect(model='...', config=config) as session:
863
- await session.send(input='Hello world!', end_of_turn=True)
897
+ await session.send_client_content(
898
+ turns=types.Content(
899
+ role='user',
900
+ parts=[types.Part(text='hello!')]
901
+ ),
902
+ turn_complete=True
903
+ )
864
904
  async for message in session.receive():
865
905
  print(message)
906
+
907
+ Args:
908
+ model: The model to use for the live session.
909
+ config: The configuration for the live session.
910
+ **kwargs: additional keyword arguments.
911
+
912
+ Yields:
913
+ An AsyncSession object.
866
914
  """
915
+ async with self._connect(
916
+ model=model,
917
+ config=config,
918
+ ) as session:
919
+ yield session
920
+
921
+ @contextlib.asynccontextmanager
922
+ async def _connect(
923
+ self,
924
+ *,
925
+ model: Optional[str] = None,
926
+ config: Optional[types.LiveConnectConfigOrDict] = None,
927
+ uri: Optional[str] = None,
928
+ ) -> AsyncIterator[AsyncSession]:
929
+
930
+ # TODO(b/404946570): Support per request http options.
931
+ if isinstance(config, dict):
932
+ config = types.LiveConnectConfig(**config)
933
+ if config and config.http_options and uri is None:
934
+ raise ValueError(
935
+ 'google.genai.client.aio.live.connect() does not support'
936
+ ' http_options at request-level in LiveConnectConfig yet. Please use'
937
+ ' the client-level http_options configuration instead.'
938
+ )
939
+
867
940
  base_url = self._api_client._websocket_base_url()
868
941
  if isinstance(base_url, bytes):
869
942
  base_url = base_url.decode('utf-8')
870
- transformed_model = t.t_model(self._api_client, model)
943
+ transformed_model = t.t_model(self._api_client, model) # type: ignore
871
944
 
872
- parameter_model = _t_live_connect_config(self._api_client, config)
945
+ parameter_model = await _t_live_connect_config(self._api_client, config)
873
946
 
874
947
  if self._api_client.api_key:
875
948
  api_key = self._api_client.api_key
876
949
  version = self._api_client._http_options.api_version
877
- uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
950
+ if uri is None:
951
+ uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
878
952
  headers = self._api_client._http_options.headers
879
953
 
880
954
  request_dict = _common.convert_to_dict(
881
955
  live_converters._LiveConnectParameters_to_mldev(
882
956
  api_client=self._api_client,
883
957
  from_object=types.LiveConnectParameters(
884
- model=transformed_model,
885
- config=parameter_model,
886
- ).model_dump(exclude_none=True)
958
+ model=transformed_model,
959
+ config=parameter_model,
960
+ ).model_dump(exclude_none=True),
887
961
  )
888
962
  )
889
963
  del request_dict['config']
@@ -894,15 +968,15 @@ class AsyncLive(_api_module.BaseModule):
894
968
  else:
895
969
  if not self._api_client._credentials:
896
970
  # Get bearer token through Application Default Credentials.
897
- creds, _ = google.auth.default( # type: ignore[no-untyped-call]
898
- scopes=['https://www.googleapis.com/auth/cloud-platform']
971
+ creds, _ = google.auth.default( # type: ignore
972
+ scopes=['https://www.googleapis.com/auth/cloud-platform']
899
973
  )
900
974
  else:
901
975
  creds = self._api_client._credentials
902
976
  # creds.valid is False, and creds.token is None
903
977
  # Need to refresh credentials to populate those
904
978
  if not (creds.token and creds.valid):
905
- auth_req = google.auth.transport.requests.Request() # type: ignore[no-untyped-call]
979
+ auth_req = google.auth.transport.requests.Request() # type: ignore
906
980
  creds.refresh(auth_req)
907
981
  bearer_token = creds.token
908
982
  headers = self._api_client._http_options.headers
@@ -922,34 +996,126 @@ class AsyncLive(_api_module.BaseModule):
922
996
  live_converters._LiveConnectParameters_to_vertex(
923
997
  api_client=self._api_client,
924
998
  from_object=types.LiveConnectParameters(
925
- model=transformed_model,
926
- config=parameter_model,
927
- ).model_dump(exclude_none=True)
999
+ model=transformed_model,
1000
+ config=parameter_model,
1001
+ ).model_dump(exclude_none=True),
928
1002
  )
929
1003
  )
930
1004
  del request_dict['config']
931
1005
 
932
- if getv(request_dict, ['setup', 'generationConfig', 'responseModalities']) is None:
933
- setv(request_dict, ['setup', 'generationConfig', 'responseModalities'], ['AUDIO'])
1006
+ if (
1007
+ getv(
1008
+ request_dict, ['setup', 'generationConfig', 'responseModalities']
1009
+ )
1010
+ is None
1011
+ ):
1012
+ setv(
1013
+ request_dict,
1014
+ ['setup', 'generationConfig', 'responseModalities'],
1015
+ ['AUDIO'],
1016
+ )
934
1017
 
935
1018
  request = json.dumps(request_dict)
936
1019
 
1020
+ if parameter_model.tools and _mcp_utils.has_mcp_tool_usage(
1021
+ parameter_model.tools
1022
+ ):
1023
+ if headers is None:
1024
+ headers = {}
1025
+ _mcp_utils.set_mcp_usage_header(headers)
937
1026
  try:
938
- async with connect(uri, additional_headers=headers) as ws:
1027
+ async with ws_connect(uri, additional_headers=headers) as ws:
939
1028
  await ws.send(request)
940
1029
  logger.info(await ws.recv(decode=False))
941
1030
 
942
1031
  yield AsyncSession(api_client=self._api_client, websocket=ws)
943
1032
  except TypeError:
944
1033
  # Try with the older websockets API
945
- async with connect(uri, extra_headers=headers) as ws:
1034
+ async with ws_connect(uri, extra_headers=headers) as ws:
946
1035
  await ws.send(request)
947
1036
  logger.info(await ws.recv())
948
1037
 
949
1038
  yield AsyncSession(api_client=self._api_client, websocket=ws)
950
1039
 
951
1040
 
952
- def _t_live_connect_config(
1041
@_common.experimental_warning(
    "The SDK's Live API connection with ephemeral token implementation is"
    ' experimental, and may change in future versions.',
)
@contextlib.asynccontextmanager
async def live_ephemeral_connect(
    access_token: str,
    model: Optional[str] = None,
    config: Optional[types.LiveConnectConfigOrDict] = None,
) -> AsyncIterator[AsyncSession]:
  """[Experimental] Connect to the live server using an ephemeral token (Gemini Developer API only).

  Note: the live API is currently experimental.

  Usage:

  .. code-block:: python
    from google.genai import live
    from google.genai import types

    config = {
        'http_options': types.HttpOptions(api_version='v1beta'),
    }
    async with live.live_ephemeral_connect(
        access_token='auth_tokens/12345',
        model='...',
        config=config,
    ) as session:
      await session.send_client_content(
          turns=types.Content(
              role='user',
              parts=[types.Part(text='hello!')]
          ),
          turn_complete=True
      )

      async for message in session.receive():
        print(message)

  Args:
    access_token: The ephemeral access token to use for the Live session. It
      can be generated through `client.auth_tokens.create`.
    model: The model to use for the Live session.
    config: The configuration for the Live session. Of
      `config.http_options`, only `base_url`, `api_version` and `headers`
      are used; `base_url` defaults to
      'https://generativelanguage.googleapis.com/' and `api_version` to
      'v1beta'.

  Yields:
    An AsyncSession object.
  """
  if isinstance(config, dict):
    config = types.LiveConnectConfig(**config)

  http_options = config.http_options if config else None

  base_url = (
      http_options.base_url
      if http_options and http_options.base_url
      else 'https://generativelanguage.googleapis.com/'
  )
  api_version = (
      http_options.api_version
      if http_options and http_options.api_version
      else 'v1beta'
  )
  # `_connect` sends `self._api_client._http_options.headers` on the
  # websocket handshake, so forward the caller's headers to the internal
  # client instead of silently dropping them (assumes the client keeps the
  # headers it is constructed with in its http options — TODO confirm).
  headers = http_options.headers if http_options else None
  internal_client = client.Client(
      api_key=_DUMMY_KEY,  # Can't be None during initialization
      http_options=types.HttpOptions(
          base_url=base_url,
          api_version=api_version,
          headers=headers,
      ),
  )
  websocket_base_url = internal_client._api_client._websocket_base_url()
  # Mirror `_connect`: the websocket base URL may come back as bytes.
  if isinstance(websocket_base_url, bytes):
    websocket_base_url = websocket_base_url.decode('utf-8')
  # The constrained endpoint authenticates with the ephemeral token in the
  # query string; an explicit `uri` also makes `_connect` skip its
  # request-level http_options rejection and its API-key URI.
  uri = f'{websocket_base_url}/ws/google.ai.generativelanguage.{api_version}.GenerativeService.BidiGenerateContentConstrained?access_token={access_token}'

  async with internal_client.aio.live._connect(
      model=model, config=config, uri=uri
  ) as session:
    yield session
1116
+
1117
+
1118
+ async def _t_live_connect_config(
953
1119
  api_client: BaseApiClient,
954
1120
  config: Optional[types.LiveConnectConfigOrDict],
955
1121
  ) -> types.LiveConnectConfig:
@@ -975,7 +1141,24 @@ def _t_live_connect_config(
975
1141
  parameter_model = config
976
1142
  parameter_model.system_instruction = system_instruction
977
1143
 
978
- if parameter_model.generation_config is not None:
1144
+ # Create a copy of the config model with the tools field cleared as they will
1145
+ # be replaced with the MCP tools converted to GenAI tools.
1146
+ parameter_model_copy = parameter_model.model_copy(update={'tools': None})
1147
+ if parameter_model.tools:
1148
+ parameter_model_copy.tools = []
1149
+ for tool in parameter_model.tools:
1150
+ if McpClientSession is not None and isinstance(tool, McpClientSession):
1151
+ mcp_to_genai_tool_adapter = McpToGenAiToolAdapter(
1152
+ tool, await tool.list_tools()
1153
+ )
1154
+ # Extend the config with the MCP session tools converted to GenAI tools.
1155
+ parameter_model_copy.tools.extend(mcp_to_genai_tool_adapter.tools)
1156
+ elif McpTool is not None and isinstance(tool, McpTool):
1157
+ parameter_model_copy.tools.append(mcp_to_gemini_tool(tool))
1158
+ else:
1159
+ parameter_model_copy.tools.append(tool)
1160
+
1161
+ if parameter_model_copy.generation_config is not None:
979
1162
  warnings.warn(
980
1163
  'Setting `LiveConnectConfig.generation_config` is deprecated, '
981
1164
  'please set the fields on `LiveConnectConfig` directly. This will '
@@ -984,4 +1167,4 @@ def _t_live_connect_config(
984
1167
  stacklevel=4,
985
1168
  )
986
1169
 
987
- return parameter_model
1170
+ return parameter_model_copy
@@ -0,0 +1,201 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+ """[Experimental] Live Music API client."""
17
+
18
+ import contextlib
19
+ import json
20
+ import logging
21
+ from typing import AsyncIterator
22
+
23
+ from . import _api_module
24
+ from . import _common
25
+ from . import _transformers as t
26
+ from . import types
27
+ from ._api_client import BaseApiClient
28
+ from ._common import set_value_by_path as setv
29
+ from . import _live_converters as live_converters
30
+ from .models import _Content_to_mldev
31
+ from .models import _Content_to_vertex
32
+
33
+
34
+ try:
35
+ from websockets.asyncio.client import ClientConnection
36
+ from websockets.asyncio.client import connect
37
+ except ModuleNotFoundError:
38
+ from websockets.client import ClientConnection # type: ignore
39
+ from websockets.client import connect # type: ignore
40
+
41
+ logger = logging.getLogger('google_genai.live_music')
42
+
43
+
44
class AsyncMusicSession:
  """[Experimental] AsyncMusicSession.

  Wraps an already-open websocket connection to the live music server. Only
  the Gemini Developer API is supported: every method that talks to the
  server raises NotImplementedError when the client targets Vertex AI.
  """

  def __init__(
      self, api_client: BaseApiClient, websocket: ClientConnection
  ):
    self._api_client = api_client
    self._ws = websocket

  def _raise_if_vertexai(self) -> None:
    """Raises NotImplementedError if the session's client targets Vertex AI."""
    if self._api_client.vertexai:
      raise NotImplementedError(
          'Live music generation is not supported in Vertex AI.'
      )

  async def set_weighted_prompts(
      self,
      prompts: list[types.WeightedPrompt]
  ) -> None:
    """Sends weighted prompts that steer the music generation.

    Args:
      prompts: The weighted prompts, sent as a `clientContent` message.

    Raises:
      NotImplementedError: If the client targets Vertex AI.
    """
    self._raise_if_vertexai()
    client_content_dict = live_converters._LiveMusicClientContent_to_mldev(
        api_client=self._api_client, from_object={'weighted_prompts': prompts}
    )
    await self._ws.send(json.dumps({'clientContent': client_content_dict}))

  async def set_music_generation_config(
      self,
      config: types.LiveMusicGenerationConfig
  ) -> None:
    """Sends an updated music generation config.

    Args:
      config: The generation config, sent as a `musicGenerationConfig`
        message.

    Raises:
      NotImplementedError: If the client targets Vertex AI.
    """
    self._raise_if_vertexai()
    config_dict = live_converters._LiveMusicGenerationConfig_to_mldev(
        api_client=self._api_client, from_object=config
    )
    await self._ws.send(json.dumps({'musicGenerationConfig': config_dict}))

  async def _send_control_signal(
      self,
      playback_control: types.LiveMusicPlaybackControl
  ) -> None:
    """Sends one playback-control message over the websocket."""
    self._raise_if_vertexai()
    playback_control_dict = live_converters._LiveMusicClientMessage_to_mldev(
        api_client=self._api_client,
        from_object={'playback_control': playback_control},
    )
    await self._ws.send(json.dumps(playback_control_dict))

  async def play(self) -> None:
    """Sends playback signal to start the music stream."""
    return await self._send_control_signal(types.LiveMusicPlaybackControl.PLAY)

  async def pause(self) -> None:
    """Sends a playback signal to pause the music stream."""
    return await self._send_control_signal(types.LiveMusicPlaybackControl.PAUSE)

  async def stop(self) -> None:
    """Sends a playback signal to stop the music stream.

    Resets the music generation context while retaining the current config.
    """
    return await self._send_control_signal(types.LiveMusicPlaybackControl.STOP)

  async def reset_context(self) -> None:
    """Reset the context (prompts retained) without stopping the music generation."""
    return await self._send_control_signal(
        types.LiveMusicPlaybackControl.RESET_CONTEXT
    )

  async def receive(self) -> AsyncIterator[types.LiveMusicServerMessage]:
    """Receive model responses from the server.

    Yields:
      The audio chunks from the server.
    """
    # TODO(b/365983264) Handle intermittent issues for the user.
    # NOTE(review): `_receive` always returns a model instance, which is
    # presumably always truthy, so this loop likely ends only when the
    # websocket read raises (e.g. on close) — confirm.
    while result := await self._receive():
      yield result

  async def _receive(self) -> types.LiveMusicServerMessage:
    """Reads and converts one server message.

    Raises:
      NotImplementedError: If the client targets Vertex AI.
      ValueError: If the server frame is not valid JSON.
    """
    # Check before reading so a Vertex client does not consume a frame from
    # the socket only to fail afterwards.
    self._raise_if_vertexai()
    parameter_model = types.LiveMusicServerMessage()
    try:
      raw_response = await self._ws.recv(decode=False)
    except TypeError:
      # Older websockets APIs do not accept `decode`.
      raw_response = await self._ws.recv()  # type: ignore[assignment]
    if raw_response:
      try:
        response = json.loads(raw_response)
      except json.decoder.JSONDecodeError as e:
        raise ValueError(f'Failed to parse response: {raw_response!r}') from e
    else:
      # An empty frame is converted as an empty message.
      response = {}

    response_dict = live_converters._LiveMusicServerMessage_from_mldev(
        self._api_client, response
    )

    return types.LiveMusicServerMessage._from_response(
        response=response_dict, kwargs=parameter_model.model_dump()
    )

  async def close(self) -> None:
    """Closes the bi-directional stream and terminates the session."""
    await self._ws.close()
148
+
149
+
150
class AsyncLiveMusic(_api_module.BaseModule):
  """[Experimental] Live music module.

  Live music can be accessed via `client.aio.live.music`.
  """

  @_common.experimental_warning(
      'Realtime music generation is experimental and may change in future versions.'
  )
  @contextlib.asynccontextmanager
  async def connect(self, *, model: str) -> AsyncIterator[AsyncMusicSession]:
    """[Experimental] Connect to the live music server.

    Only the Gemini Developer API (API-key authentication) is supported.

    Args:
      model: The model to use for the live music session.

    Yields:
      An AsyncMusicSession bound to the open websocket.

    Raises:
      NotImplementedError: If the client has no API key (e.g. Vertex AI).
    """
    base_url = self._api_client._websocket_base_url()
    if isinstance(base_url, bytes):
      base_url = base_url.decode('utf-8')
    transformed_model = t.t_model(self._api_client, model)

    if self._api_client.api_key:
      api_key = self._api_client.api_key
      version = self._api_client._http_options.api_version
      uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic?key={api_key}'
      headers = self._api_client._http_options.headers

      # Only mldev supported
      request_dict = _common.convert_to_dict(
          live_converters._LiveMusicConnectParameters_to_mldev(
              api_client=self._api_client,
              from_object=types.LiveMusicConnectParameters(
                  model=transformed_model,
              ).model_dump(exclude_none=True),
          )
      )

      setv(request_dict, ['setup', 'model'], transformed_model)

      request = json.dumps(request_dict)
    else:
      raise NotImplementedError('Live music generation is not supported in Vertex AI.')

    # Open the websocket *before* yielding. If the `except TypeError`
    # fallback wrapped the `yield`, a TypeError raised in the caller's
    # `async with` body would be caught here, reconnect and yield a second
    # time, masking the real error.
    async with contextlib.AsyncExitStack() as stack:
      try:
        ws = await stack.enter_async_context(
            connect(uri, additional_headers=headers)
        )
        legacy_api = False
      except TypeError:
        # Try with the older websockets API, which takes `extra_headers`.
        ws = await stack.enter_async_context(
            connect(uri, extra_headers=headers)
        )
        legacy_api = True

      await ws.send(request)
      if legacy_api:
        # The legacy `recv()` does not accept `decode`.
        logger.info(await ws.recv())
      else:
        logger.info(await ws.recv(decode=False))

      yield AsyncMusicSession(api_client=self._api_client, websocket=ws)