google-genai 1.15.0__py3-none-any.whl → 1.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +5 -3
- google/genai/_adapters.py +55 -0
- google/genai/_api_client.py +3 -3
- google/genai/_api_module.py +1 -1
- google/genai/_automatic_function_calling_util.py +1 -1
- google/genai/_common.py +1 -1
- google/genai/_extra_utils.py +114 -9
- google/genai/_live_converters.py +1295 -20
- google/genai/_mcp_utils.py +117 -0
- google/genai/_replay_api_client.py +1 -1
- google/genai/_test_api_client.py +1 -1
- google/genai/_tokens_converters.py +1701 -0
- google/genai/_transformers.py +66 -33
- google/genai/caches.py +223 -20
- google/genai/chats.py +1 -1
- google/genai/client.py +12 -1
- google/genai/errors.py +1 -1
- google/genai/live.py +218 -35
- google/genai/live_music.py +201 -0
- google/genai/models.py +505 -44
- google/genai/pagers.py +1 -1
- google/genai/tokens.py +357 -0
- google/genai/types.py +7887 -6765
- google/genai/version.py +2 -2
- {google_genai-1.15.0.dist-info → google_genai-1.16.1.dist-info}/METADATA +8 -4
- google_genai-1.16.1.dist-info/RECORD +35 -0
- {google_genai-1.15.0.dist-info → google_genai-1.16.1.dist-info}/WHEEL +1 -1
- google_genai-1.15.0.dist-info/RECORD +0 -30
- {google_genai-1.15.0.dist-info → google_genai-1.16.1.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.15.0.dist-info → google_genai-1.16.1.dist-info}/top_level.txt +0 -0
google/genai/live.py
CHANGED
@@ -20,7 +20,8 @@ import base64
 import contextlib
 import json
 import logging
-
+import typing
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, get_args
 import warnings
 
 import google.auth
@@ -29,24 +30,46 @@ from websockets import ConnectionClosed
 
 from . import _api_module
 from . import _common
+from . import _live_converters as live_converters
+from . import _mcp_utils
 from . import _transformers as t
 from . import client
 from . import types
 from ._api_client import BaseApiClient
 from ._common import get_value_by_path as getv
 from ._common import set_value_by_path as setv
-from . import
+from .live_music import AsyncLiveMusic
 from .models import _Content_to_mldev
 from .models import _Content_to_vertex
 
 
 try:
   from websockets.asyncio.client import ClientConnection
-  from websockets.asyncio.client import connect
+  from websockets.asyncio.client import connect as ws_connect
 except ModuleNotFoundError:
   # This try/except is for TAP, mypy complains about it which is why we have the type: ignore
   from websockets.client import ClientConnection  # type: ignore
-  from websockets.client import connect  # type: ignore
+  from websockets.client import connect as ws_connect  # type: ignore
+
+if typing.TYPE_CHECKING:
+  from mcp import ClientSession as McpClientSession
+  from mcp.types import Tool as McpTool
+  from ._adapters import McpToGenAiToolAdapter
+  from ._mcp_utils import mcp_to_gemini_tool
+else:
+  McpClientSession: typing.Type = Any
+  McpTool: typing.Type = Any
+  McpToGenAiToolAdapter: typing.Type = Any
+  try:
+    from mcp import ClientSession as McpClientSession
+    from mcp.types import Tool as McpTool
+    from ._adapters import McpToGenAiToolAdapter
+    from ._mcp_utils import mcp_to_gemini_tool
+  except ImportError:
+    McpClientSession = None
+    McpTool = None
+    McpToGenAiToolAdapter = None
+    mcp_to_gemini_tool = None
 
 logger = logging.getLogger('google_genai.live')
 
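The TYPE_CHECKING/ImportError dance above keeps `mcp` an optional dependency: type checkers see the real MCP types, while at runtime the names quietly fall back to `None` when the package is missing. Application code can use the same guard to decide whether it can pass MCP tools at all; the sketch below assumes only that the optional `mcp` package may or may not be installed.

```python
# Minimal sketch of the optional-dependency guard used above.
try:
    from mcp import ClientSession as McpClientSession  # optional extra
except ImportError:  # `mcp` is not installed
    McpClientSession = None


def mcp_available() -> bool:
    """Report whether MCP-based tools can be passed to the Live API."""
    return McpClientSession is not None
```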
@@ -56,12 +79,13 @@ _FUNCTION_RESPONSE_REQUIRES_ID = (
 )
 
 
+_DUMMY_KEY = 'dummy_key'
+
+
 class AsyncSession:
   """[Preview] AsyncSession."""
 
-  def __init__(
-      self, api_client: BaseApiClient, websocket: ClientConnection
-  ):
+  def __init__(self, api_client: BaseApiClient, websocket: ClientConnection):
     self._api_client = api_client
     self._ws = websocket
 
@@ -123,7 +147,7 @@ class AsyncSession:
           Union[
               types.Content,
               types.ContentDict,
-              list[Union[types.Content, types.ContentDict]]
+              list[Union[types.Content, types.ContentDict]],
           ]
       ] = None,
       turn_complete: bool = True,
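The widened `turns` annotation means `send_client_content` accepts a single `Content`, a plain dict, or a list mixing both. A hedged sketch of the three call shapes, assuming an open `session` from `client.aio.live.connect(...)`:

```python
from google.genai import types

# Inside an `async with client.aio.live.connect(...) as session:` block.

# 1. A single Content object.
await session.send_client_content(
    turns=types.Content(role='user', parts=[types.Part(text='hi')]),
    turn_complete=True,
)

# 2. The equivalent dict form.
await session.send_client_content(
    turns={'role': 'user', 'parts': [{'text': 'hi'}]},
)

# 3. A list mixing both forms.
await session.send_client_content(
    turns=[
        types.Content(role='user', parts=[types.Part(text='first turn')]),
        {'role': 'user', 'parts': [{'text': 'second turn'}]},
    ],
)
```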
@@ -264,7 +288,7 @@ class AsyncSession:
           print(f'{msg.text}')
     ```
     """
-    kwargs:dict[str, Any] = {}
+    kwargs: dict[str, Any] = {}
     if media is not None:
       kwargs['media'] = media
     if audio is not None:
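The `kwargs` handling above forwards `media` and `audio` to the realtime-input request. A hedged sketch of streaming raw PCM audio; the chunking and mime type are assumptions, not taken from this diff:

```python
from google.genai import types

# Inside an open live session. `pcm_chunks` is assumed to be an iterable of
# raw 16-bit little-endian PCM frames recorded at 16 kHz.
for chunk in pcm_chunks:
    await session.send_realtime_input(
        audio=types.Blob(data=chunk, mime_type='audio/pcm;rate=16000')
    )
```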
@@ -506,9 +530,13 @@ class AsyncSession:
       response = {}
 
     if self._api_client.vertexai:
-      response_dict = live_converters._LiveServerMessage_from_vertex(
+      response_dict = live_converters._LiveServerMessage_from_vertex(
+          self._api_client, response
+      )
     else:
-      response_dict = live_converters._LiveServerMessage_from_mldev(
+      response_dict = live_converters._LiveServerMessage_from_mldev(
+          self._api_client, response
+      )
 
     return types.LiveServerMessage._from_response(
         response=response_dict, kwargs=parameter_model.model_dump()
@@ -522,7 +550,7 @@ class AsyncSession:
   ) -> None:
     async for data in data_stream:
       model_input = types.LiveClientRealtimeInput(
-
+          media_chunks=[types.Blob(data=data, mime_type=mime_type)]
       )
       await self.send(input=model_input)
       # Give a chance for the receive loop to process responses.
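`_send_loop` wraps every chunk from `data_stream` in a `LiveClientRealtimeInput` blob and then yields control so the receive loop can drain responses. A sketch of driving it through the session's public streaming helper; the `start_stream` name and the mime type are assumptions based on this module, not confirmed by the hunk:

```python
async def audio_chunks():
    # Hypothetical source of raw audio bytes.
    with open('speech.pcm', 'rb') as f:
        while chunk := f.read(4096):
            yield chunk

# Inside an open live session:
async for message in session.start_stream(
    stream=audio_chunks(), mime_type='audio/pcm'
):
    if message.text:
        print(message.text)
```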
@@ -560,9 +588,8 @@ class AsyncSession:
       raise ValueError(
           f'Unsupported input type "{type(input)}" or input content "{input}"'
       )
-    if (
-        isinstance(blob_input, types.Blob)
-        and isinstance(blob_input.data, bytes)
+    if isinstance(blob_input, types.Blob) and isinstance(
+        blob_input.data, bytes
     ):
       formatted_input = [
           blob_input.model_dump(mode='json', exclude_none=True)
@@ -841,6 +868,13 @@ class AsyncSession:
 class AsyncLive(_api_module.BaseModule):
   """[Preview] AsyncLive."""
 
+  def __init__(self, api_client: BaseApiClient):
+    super().__init__(api_client)
+    self._music = AsyncLiveMusic(api_client)
+
+  @property
+  def music(self) -> AsyncLiveMusic:
+    return self._music
 
   @contextlib.asynccontextmanager
   async def connect(
@@ -860,30 +894,70 @@ class AsyncLive(_api_module.BaseModule):
       client = genai.Client(api_key=API_KEY)
       config = {}
       async with client.aio.live.connect(model='...', config=config) as session:
-        await session.
+        await session.send_client_content(
+            turns=types.Content(
+                role='user',
+                parts=[types.Part(text='hello!')]
+            ),
+            turn_complete=True
+        )
         async for message in session.receive():
           print(message)
+
+    Args:
+      model: The model to use for the live session.
+      config: The configuration for the live session.
+      **kwargs: additional keyword arguments.
+
+    Yields:
+      An AsyncSession object.
     """
+    async with self._connect(
+        model=model,
+        config=config,
+    ) as session:
+      yield session
+
+  @contextlib.asynccontextmanager
+  async def _connect(
+      self,
+      *,
+      model: Optional[str] = None,
+      config: Optional[types.LiveConnectConfigOrDict] = None,
+      uri: Optional[str] = None,
+  ) -> AsyncIterator[AsyncSession]:
+
+    # TODO(b/404946570): Support per request http options.
+    if isinstance(config, dict):
+      config = types.LiveConnectConfig(**config)
+    if config and config.http_options and uri is None:
+      raise ValueError(
+          'google.genai.client.aio.live.connect() does not support'
+          ' http_options at request-level in LiveConnectConfig yet. Please use'
+          ' the client-level http_options configuration instead.'
+      )
+
     base_url = self._api_client._websocket_base_url()
     if isinstance(base_url, bytes):
       base_url = base_url.decode('utf-8')
-    transformed_model = t.t_model(self._api_client, model)
+    transformed_model = t.t_model(self._api_client, model)  # type: ignore
 
-    parameter_model = _t_live_connect_config(self._api_client, config)
+    parameter_model = await _t_live_connect_config(self._api_client, config)
 
     if self._api_client.api_key:
       api_key = self._api_client.api_key
       version = self._api_client._http_options.api_version
-      uri
+      if uri is None:
+        uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
       headers = self._api_client._http_options.headers
 
       request_dict = _common.convert_to_dict(
           live_converters._LiveConnectParameters_to_mldev(
               api_client=self._api_client,
               from_object=types.LiveConnectParameters(
-
-
-              ).model_dump(exclude_none=True)
+                  model=transformed_model,
+                  config=parameter_model,
+              ).model_dump(exclude_none=True),
           )
       )
       del request_dict['config']
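For reference, an executable version of the docstring example above, wrapped for asyncio. The model name is a placeholder; `response_modalities` is set explicitly here:

```python
import asyncio

from google import genai
from google.genai import types


async def main() -> None:
    client = genai.Client(api_key='YOUR_API_KEY')
    # Placeholder model name; use a Live-capable model you have access to.
    config = {'response_modalities': ['TEXT']}
    async with client.aio.live.connect(
        model='gemini-2.0-flash-live-001', config=config
    ) as session:
        await session.send_client_content(
            turns=types.Content(role='user', parts=[types.Part(text='hello!')]),
            turn_complete=True,
        )
        async for message in session.receive():
            if message.text:
                print(message.text, end='')
            if message.server_content and message.server_content.turn_complete:
                break


if __name__ == '__main__':
    asyncio.run(main())
```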
@@ -894,15 +968,15 @@ class AsyncLive(_api_module.BaseModule):
     else:
       if not self._api_client._credentials:
         # Get bearer token through Application Default Credentials.
-        creds, _ = google.auth.default( # type: ignore
-
+        creds, _ = google.auth.default(  # type: ignore
+            scopes=['https://www.googleapis.com/auth/cloud-platform']
         )
       else:
         creds = self._api_client._credentials
       # creds.valid is False, and creds.token is None
       # Need to refresh credentials to populate those
       if not (creds.token and creds.valid):
-        auth_req = google.auth.transport.requests.Request() # type: ignore
+        auth_req = google.auth.transport.requests.Request()  # type: ignore
         creds.refresh(auth_req)
       bearer_token = creds.token
       headers = self._api_client._http_options.headers
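On Vertex AI there is no API key in the websocket URI; the hunk above refreshes Application Default Credentials and sends the resulting bearer token. A hedged sketch of the corresponding client setup (project and location are placeholders):

```python
# Assumes ADC is configured, e.g. via `gcloud auth application-default login`
# or a service-account key exported through GOOGLE_APPLICATION_CREDENTIALS.
from google import genai

client = genai.Client(
    vertexai=True,
    project='my-gcp-project',
    location='us-central1',
)
# client.aio.live.connect(...) will refresh the credentials and attach the
# bearer token, as shown in the hunk above.
```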
@@ -922,34 +996,126 @@ class AsyncLive(_api_module.BaseModule):
           live_converters._LiveConnectParameters_to_vertex(
               api_client=self._api_client,
               from_object=types.LiveConnectParameters(
-
-
-              ).model_dump(exclude_none=True)
+                  model=transformed_model,
+                  config=parameter_model,
+              ).model_dump(exclude_none=True),
           )
       )
       del request_dict['config']
 
-      if
-
+      if (
+          getv(
+              request_dict, ['setup', 'generationConfig', 'responseModalities']
+          )
+          is None
+      ):
+        setv(
+            request_dict,
+            ['setup', 'generationConfig', 'responseModalities'],
+            ['AUDIO'],
+        )
 
       request = json.dumps(request_dict)
 
+    if parameter_model.tools and _mcp_utils.has_mcp_tool_usage(
+        parameter_model.tools
+    ):
+      if headers is None:
+        headers = {}
+      _mcp_utils.set_mcp_usage_header(headers)
     try:
-      async with connect(uri, additional_headers=headers) as ws:
+      async with ws_connect(uri, additional_headers=headers) as ws:
         await ws.send(request)
         logger.info(await ws.recv(decode=False))
 
         yield AsyncSession(api_client=self._api_client, websocket=ws)
     except TypeError:
       # Try with the older websockets API
-      async with connect(uri, extra_headers=headers) as ws:
+      async with ws_connect(uri, extra_headers=headers) as ws:
         await ws.send(request)
         logger.info(await ws.recv())
 
         yield AsyncSession(api_client=self._api_client, websocket=ws)
 
 
-def _t_live_connect_config(
+  @_common.experimental_warning(
+      "The SDK's Live API connection with ephemeral token implementation is"
+      ' experimental, and may change in future versions.',
+  )
+  @contextlib.asynccontextmanager
+  async def live_ephemeral_connect(
+      access_token: str,
+      model: Optional[str] = None,
+      config: Optional[types.LiveConnectConfigOrDict] = None,
+  ) -> AsyncIterator[AsyncSession]:
+    """[Experimental] Connect to the live server using ephermeral token (Gemini Developer API only).
+
+    Note: the live API is currently in experimental.
+
+    Usage:
+
+    .. code-block:: python
+      from google import genai
+
+      config = {}
+      async with genai.live_ephemeral_connect(
+          access_token='auth_tokens/12345',
+          model='...',
+          config=config,
+          http_options=types.HttpOptions(api_version='v1beta'),
+      ) as session:
+        await session.send_client_content(
+            turns=types.Content(
+                role='user',
+                parts=[types.Part(text='hello!')]
+            ),
+            turn_complete=True
+        )
+
+        async for message in session.receive():
+          print(message)
+
+    Args:
+      access_token: The access token to use for the Live session. It can be
+        generated by the `client.tokens.create` method.
+      model: The model to use for the Live session.
+      config: The configuration for the Live session.
+
+    Yields:
+      An AsyncSession object.
+    """
+    if isinstance(config, dict):
+      config = types.LiveConnectConfig(**config)
+
+    http_options = config.http_options if config else None
+
+    base_url = (
+        http_options.base_url
+        if http_options and http_options.base_url
+        else 'https://generativelanguage.googleapis.com/'
+    )
+    api_version = (
+        http_options.api_version
+        if http_options and http_options.api_version
+        else 'v1beta'
+    )
+    internal_client = client.Client(
+        api_key=_DUMMY_KEY,  # Can't be None during initialization
+        http_options=types.HttpOptions(
+            base_url=base_url,
+            api_version=api_version,
+        ),
+    )
+    websocket_base_url = internal_client._api_client._websocket_base_url()
+    uri = f'{websocket_base_url}/ws/google.ai.generativelanguage.{api_version}.GenerativeService.BidiGenerateContentConstrained?access_token={access_token}'
+
+    async with internal_client.aio.live._connect(
+        model=model, config=config, uri=uri
+    ) as session:
+      yield session
+
+
+async def _t_live_connect_config(
     api_client: BaseApiClient,
     config: Optional[types.LiveConnectConfigOrDict],
 ) -> types.LiveConnectConfig:
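An asyncio-wrapped version of the `live_ephemeral_connect` docstring example. The entry point and token name mirror the docstring above; because the API is flagged experimental, treat the exact call path as an assumption and mint a real token with `client.tokens.create` first.

```python
import asyncio

from google import genai
from google.genai import types


async def main() -> None:
    # 'auth_tokens/12345' mirrors the docstring above; replace it with a token
    # created via client.tokens.create on a server you control.
    async with genai.live_ephemeral_connect(
        access_token='auth_tokens/12345',
        model='gemini-2.0-flash-live-001',  # placeholder model name
        config={'response_modalities': ['TEXT']},
    ) as session:
        await session.send_client_content(
            turns=types.Content(role='user', parts=[types.Part(text='hello!')]),
            turn_complete=True,
        )
        async for message in session.receive():
            print(message)


if __name__ == '__main__':
    asyncio.run(main())
```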
@@ -975,7 +1141,24 @@ def _t_live_connect_config(
     parameter_model = config
     parameter_model.system_instruction = system_instruction
 
-  if parameter_model.generation_config is not None:
+  # Create a copy of the config model with the tools field cleared as they will
+  # be replaced with the MCP tools converted to GenAI tools.
+  parameter_model_copy = parameter_model.model_copy(update={'tools': None})
+  if parameter_model.tools:
+    parameter_model_copy.tools = []
+    for tool in parameter_model.tools:
+      if McpClientSession is not None and isinstance(tool, McpClientSession):
+        mcp_to_genai_tool_adapter = McpToGenAiToolAdapter(
+            tool, await tool.list_tools()
+        )
+        # Extend the config with the MCP session tools converted to GenAI tools.
+        parameter_model_copy.tools.extend(mcp_to_genai_tool_adapter.tools)
+      elif McpTool is not None and isinstance(tool, McpTool):
+        parameter_model_copy.tools.append(mcp_to_gemini_tool(tool))
+      else:
+        parameter_model_copy.tools.append(tool)
+
+  if parameter_model_copy.generation_config is not None:
     warnings.warn(
         'Setting `LiveConnectConfig.generation_config` is deprecated, '
         'please set the fields on `LiveConnectConfig` directly. This will '
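Because `_t_live_connect_config` is now a coroutine, it can call `await tool.list_tools()` on any MCP `ClientSession` found in `config.tools` and splice the converted declarations into the setup message. At the call site that reduces to passing the session object directly; a hedged sketch (the MCP session setup itself is elided and assumed to be initialized already):

```python
from google import genai
from google.genai import types


async def run_with_mcp(mcp_session) -> None:
    # `mcp_session` is assumed to be an mcp.ClientSession that has already
    # completed its initialize() handshake; the converter above lists its
    # tools and turns them into GenAI tool declarations for the setup message.
    client = genai.Client(api_key='YOUR_API_KEY')
    config = types.LiveConnectConfig(
        response_modalities=['TEXT'],
        tools=[mcp_session],
    )
    async with client.aio.live.connect(
        model='gemini-2.0-flash-live-001',  # placeholder model name
        config=config,
    ) as session:
        await session.send_client_content(
            turns=types.Content(role='user', parts=[types.Part(text='hi')]),
            turn_complete=True,
        )
        async for message in session.receive():
            print(message)
```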
@@ -984,4 +1167,4 @@ def _t_live_connect_config(
         stacklevel=4,
     )
 
-  return parameter_model
+  return parameter_model_copy
google/genai/live_music.py
ADDED
@@ -0,0 +1,201 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""[Experimental] Live Music API client."""
+
+import contextlib
+import json
+import logging
+from typing import AsyncIterator
+
+from . import _api_module
+from . import _common
+from . import _transformers as t
+from . import types
+from ._api_client import BaseApiClient
+from ._common import set_value_by_path as setv
+from . import _live_converters as live_converters
+from .models import _Content_to_mldev
+from .models import _Content_to_vertex
+
+
+try:
+  from websockets.asyncio.client import ClientConnection
+  from websockets.asyncio.client import connect
+except ModuleNotFoundError:
+  from websockets.client import ClientConnection  # type: ignore
+  from websockets.client import connect  # type: ignore
+
+logger = logging.getLogger('google_genai.live_music')
+
+
+class AsyncMusicSession:
+  """[Experimental] AsyncMusicSession."""
+
+  def __init__(
+      self, api_client: BaseApiClient, websocket: ClientConnection
+  ):
+    self._api_client = api_client
+    self._ws = websocket
+
+  async def set_weighted_prompts(
+      self,
+      prompts: list[types.WeightedPrompt]
+  ) -> None:
+    if self._api_client.vertexai:
+      raise NotImplementedError('Live music generation is not supported in Vertex AI.')
+    else:
+      client_content_dict = live_converters._LiveMusicClientContent_to_mldev(
+          api_client=self._api_client, from_object={'weighted_prompts': prompts}
+      )
+      await self._ws.send(json.dumps({'clientContent': client_content_dict}))
+
+  async def set_music_generation_config(
+      self,
+      config: types.LiveMusicGenerationConfig
+  ) -> None:
+    if self._api_client.vertexai:
+      raise NotImplementedError('Live music generation is not supported in Vertex AI.')
+    else:
+      config_dict = live_converters._LiveMusicGenerationConfig_to_mldev(
+          api_client=self._api_client, from_object=config
+      )
+      await self._ws.send(json.dumps({'musicGenerationConfig': config_dict}))
+
+  async def _send_control_signal(
+      self,
+      playback_control: types.LiveMusicPlaybackControl
+  ) -> None:
+    if self._api_client.vertexai:
+      raise NotImplementedError('Live music generation is not supported in Vertex AI.')
+    else:
+      playback_control_dict = live_converters._LiveMusicClientMessage_to_mldev(
+          api_client=self._api_client, from_object={'playback_control': playback_control}
+      )
+      await self._ws.send(json.dumps(playback_control_dict))
+
+  async def play(self) -> None:
+    """Sends playback signal to start the music stream."""
+    return await self._send_control_signal(types.LiveMusicPlaybackControl.PLAY)
+
+  async def pause(self) -> None:
+    """Sends a playback signal to pause the music stream."""
+    return await self._send_control_signal(types.LiveMusicPlaybackControl.PAUSE)
+
+  async def stop(self) -> None:
+    """Sends a playback signal to stop the music stream.
+
+    Resets the music generation context while retaining the current config.
+    """
+    return await self._send_control_signal(types.LiveMusicPlaybackControl.STOP)
+
+  async def reset_context(self) -> None:
+    """Reset the context (prompts retained) without stopping the music generation."""
+    return await self._send_control_signal(
+        types.LiveMusicPlaybackControl.RESET_CONTEXT
+    )
+
+  async def receive(self) -> AsyncIterator[types.LiveMusicServerMessage]:
+    """Receive model responses from the server.
+
+    Yields:
+      The audio chunks from the server.
+    """
+    # TODO(b/365983264) Handle intermittent issues for the user.
+    while result := await self._receive():
+      yield result
+
+  async def _receive(self) -> types.LiveMusicServerMessage:
+    parameter_model = types.LiveMusicServerMessage()
+    try:
+      raw_response = await self._ws.recv(decode=False)
+    except TypeError:
+      raw_response = await self._ws.recv()  # type: ignore[assignment]
+    if raw_response:
+      try:
+        response = json.loads(raw_response)
+      except json.decoder.JSONDecodeError:
+        raise ValueError(f'Failed to parse response: {raw_response!r}')
+    else:
+      response = {}
+
+    if self._api_client.vertexai:
+      raise NotImplementedError('Live music generation is not supported in Vertex AI.')
+    else:
+      response_dict = live_converters._LiveMusicServerMessage_from_mldev(
+          self._api_client, response
+      )
+
+    return types.LiveMusicServerMessage._from_response(
+        response=response_dict, kwargs=parameter_model.model_dump()
+    )
+
+  async def close(self) -> None:
+    """Closes the bi-directional stream and terminates the session."""
+    await self._ws.close()
+
+
+class AsyncLiveMusic(_api_module.BaseModule):
+  """[Experimental] Live music module.
+
+  Live music can be accessed via `client.aio.live.music`.
+  """
+
+  @_common.experimental_warning(
+      'Realtime music generation is experimental and may change in future versions.'
+  )
+  @contextlib.asynccontextmanager
+  async def connect(self, *, model: str) -> AsyncIterator[AsyncMusicSession]:
+    """[Experimental] Connect to the live music server."""
+    base_url = self._api_client._websocket_base_url()
+    if isinstance(base_url, bytes):
+      base_url = base_url.decode('utf-8')
+    transformed_model = t.t_model(self._api_client, model)
+
+    if self._api_client.api_key:
+      api_key = self._api_client.api_key
+      version = self._api_client._http_options.api_version
+      uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic?key={api_key}'
+      headers = self._api_client._http_options.headers
+
+      # Only mldev supported
+      request_dict = _common.convert_to_dict(
+          live_converters._LiveMusicConnectParameters_to_mldev(
+              api_client=self._api_client,
+              from_object=types.LiveMusicConnectParameters(
+                  model=transformed_model,
+              ).model_dump(exclude_none=True)
+          )
+      )
+
+      setv(request_dict, ['setup', 'model'], transformed_model)
+
+      request = json.dumps(request_dict)
+    else:
+      raise NotImplementedError('Live music generation is not supported in Vertex AI.')
+
+    try:
+      async with connect(uri, additional_headers=headers) as ws:
+        await ws.send(request)
+        logger.info(await ws.recv(decode=False))
+
+        yield AsyncMusicSession(api_client=self._api_client, websocket=ws)
+    except TypeError:
+      # Try with the older websockets API
+      async with connect(uri, extra_headers=headers) as ws:
+        await ws.send(request)
+        logger.info(await ws.recv())
+
+        yield AsyncMusicSession(api_client=self._api_client, websocket=ws)