google-genai 1.40.0__py3-none-any.whl → 1.42.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/errors.py CHANGED
@@ -134,6 +134,13 @@ class APIError(Exception):
134
134
  'status': response.reason_phrase,
135
135
  }
136
136
  status_code = response.status_code
137
+ elif hasattr(response, 'body_segments') and hasattr(
138
+ response, 'status_code'
139
+ ):
140
+ if response.status_code == 200:
141
+ return
142
+ response_json = response.body_segments[0].get('error', {})
143
+ status_code = response.status_code
137
144
  else:
138
145
  try:
139
146
  import aiohttp # pylint: disable=g-import-not-at-top
@@ -151,9 +158,9 @@ class APIError(Exception):
151
158
  }
152
159
  status_code = response.status
153
160
  else:
154
- response_json = response.body_segments[0].get('error', {})
161
+ raise ValueError(f'Unsupported response type: {type(response)}')
155
162
  except ImportError:
156
- response_json = response.body_segments[0].get('error', {})
163
+ raise ValueError(f'Unsupported response type: {type(response)}')
157
164
 
158
165
  if 400 <= status_code < 500:
159
166
  raise ClientError(status_code, response_json, response)
google/genai/files.py CHANGED
@@ -18,13 +18,13 @@
18
18
  import io
19
19
  import json
20
20
  import logging
21
- import mimetypes
22
21
  import os
23
22
  from typing import Any, Optional, Union
24
23
  from urllib.parse import urlencode
25
24
 
26
25
  from . import _api_module
27
26
  from . import _common
27
+ from . import _extra_utils
28
28
  from . import _transformers as t
29
29
  from . import types
30
30
  from ._common import get_value_by_path as getv
@@ -41,11 +41,7 @@ def _CreateFileParameters_to_mldev(
41
41
  ) -> dict[str, Any]:
42
42
  to_object: dict[str, Any] = {}
43
43
  if getv(from_object, ['file']) is not None:
44
- setv(
45
- to_object,
46
- ['file'],
47
- _File_to_mldev(getv(from_object, ['file']), to_object),
48
- )
44
+ setv(to_object, ['file'], getv(from_object, ['file']))
49
45
 
50
46
  return to_object
51
47
 
@@ -89,148 +85,6 @@ def _DeleteFileResponse_from_mldev(
89
85
  return to_object
90
86
 
91
87
 
92
- def _FileStatus_from_mldev(
93
- from_object: Union[dict[str, Any], object],
94
- parent_object: Optional[dict[str, Any]] = None,
95
- ) -> dict[str, Any]:
96
- to_object: dict[str, Any] = {}
97
- if getv(from_object, ['details']) is not None:
98
- setv(to_object, ['details'], getv(from_object, ['details']))
99
-
100
- if getv(from_object, ['message']) is not None:
101
- setv(to_object, ['message'], getv(from_object, ['message']))
102
-
103
- if getv(from_object, ['code']) is not None:
104
- setv(to_object, ['code'], getv(from_object, ['code']))
105
-
106
- return to_object
107
-
108
-
109
- def _FileStatus_to_mldev(
110
- from_object: Union[dict[str, Any], object],
111
- parent_object: Optional[dict[str, Any]] = None,
112
- ) -> dict[str, Any]:
113
- to_object: dict[str, Any] = {}
114
- if getv(from_object, ['details']) is not None:
115
- setv(to_object, ['details'], getv(from_object, ['details']))
116
-
117
- if getv(from_object, ['message']) is not None:
118
- setv(to_object, ['message'], getv(from_object, ['message']))
119
-
120
- if getv(from_object, ['code']) is not None:
121
- setv(to_object, ['code'], getv(from_object, ['code']))
122
-
123
- return to_object
124
-
125
-
126
- def _File_from_mldev(
127
- from_object: Union[dict[str, Any], object],
128
- parent_object: Optional[dict[str, Any]] = None,
129
- ) -> dict[str, Any]:
130
- to_object: dict[str, Any] = {}
131
- if getv(from_object, ['name']) is not None:
132
- setv(to_object, ['name'], getv(from_object, ['name']))
133
-
134
- if getv(from_object, ['displayName']) is not None:
135
- setv(to_object, ['display_name'], getv(from_object, ['displayName']))
136
-
137
- if getv(from_object, ['mimeType']) is not None:
138
- setv(to_object, ['mime_type'], getv(from_object, ['mimeType']))
139
-
140
- if getv(from_object, ['sizeBytes']) is not None:
141
- setv(to_object, ['size_bytes'], getv(from_object, ['sizeBytes']))
142
-
143
- if getv(from_object, ['createTime']) is not None:
144
- setv(to_object, ['create_time'], getv(from_object, ['createTime']))
145
-
146
- if getv(from_object, ['expirationTime']) is not None:
147
- setv(to_object, ['expiration_time'], getv(from_object, ['expirationTime']))
148
-
149
- if getv(from_object, ['updateTime']) is not None:
150
- setv(to_object, ['update_time'], getv(from_object, ['updateTime']))
151
-
152
- if getv(from_object, ['sha256Hash']) is not None:
153
- setv(to_object, ['sha256_hash'], getv(from_object, ['sha256Hash']))
154
-
155
- if getv(from_object, ['uri']) is not None:
156
- setv(to_object, ['uri'], getv(from_object, ['uri']))
157
-
158
- if getv(from_object, ['downloadUri']) is not None:
159
- setv(to_object, ['download_uri'], getv(from_object, ['downloadUri']))
160
-
161
- if getv(from_object, ['state']) is not None:
162
- setv(to_object, ['state'], getv(from_object, ['state']))
163
-
164
- if getv(from_object, ['source']) is not None:
165
- setv(to_object, ['source'], getv(from_object, ['source']))
166
-
167
- if getv(from_object, ['videoMetadata']) is not None:
168
- setv(to_object, ['video_metadata'], getv(from_object, ['videoMetadata']))
169
-
170
- if getv(from_object, ['error']) is not None:
171
- setv(
172
- to_object,
173
- ['error'],
174
- _FileStatus_from_mldev(getv(from_object, ['error']), to_object),
175
- )
176
-
177
- return to_object
178
-
179
-
180
- def _File_to_mldev(
181
- from_object: Union[dict[str, Any], object],
182
- parent_object: Optional[dict[str, Any]] = None,
183
- ) -> dict[str, Any]:
184
- to_object: dict[str, Any] = {}
185
- if getv(from_object, ['name']) is not None:
186
- setv(to_object, ['name'], getv(from_object, ['name']))
187
-
188
- if getv(from_object, ['display_name']) is not None:
189
- setv(to_object, ['displayName'], getv(from_object, ['display_name']))
190
-
191
- if getv(from_object, ['mime_type']) is not None:
192
- setv(to_object, ['mimeType'], getv(from_object, ['mime_type']))
193
-
194
- if getv(from_object, ['size_bytes']) is not None:
195
- setv(to_object, ['sizeBytes'], getv(from_object, ['size_bytes']))
196
-
197
- if getv(from_object, ['create_time']) is not None:
198
- setv(to_object, ['createTime'], getv(from_object, ['create_time']))
199
-
200
- if getv(from_object, ['expiration_time']) is not None:
201
- setv(to_object, ['expirationTime'], getv(from_object, ['expiration_time']))
202
-
203
- if getv(from_object, ['update_time']) is not None:
204
- setv(to_object, ['updateTime'], getv(from_object, ['update_time']))
205
-
206
- if getv(from_object, ['sha256_hash']) is not None:
207
- setv(to_object, ['sha256Hash'], getv(from_object, ['sha256_hash']))
208
-
209
- if getv(from_object, ['uri']) is not None:
210
- setv(to_object, ['uri'], getv(from_object, ['uri']))
211
-
212
- if getv(from_object, ['download_uri']) is not None:
213
- setv(to_object, ['downloadUri'], getv(from_object, ['download_uri']))
214
-
215
- if getv(from_object, ['state']) is not None:
216
- setv(to_object, ['state'], getv(from_object, ['state']))
217
-
218
- if getv(from_object, ['source']) is not None:
219
- setv(to_object, ['source'], getv(from_object, ['source']))
220
-
221
- if getv(from_object, ['video_metadata']) is not None:
222
- setv(to_object, ['videoMetadata'], getv(from_object, ['video_metadata']))
223
-
224
- if getv(from_object, ['error']) is not None:
225
- setv(
226
- to_object,
227
- ['error'],
228
- _FileStatus_to_mldev(getv(from_object, ['error']), to_object),
229
- )
230
-
231
- return to_object
232
-
233
-
234
88
  def _GetFileParameters_to_mldev(
235
89
  from_object: Union[dict[str, Any], object],
236
90
  parent_object: Optional[dict[str, Any]] = None,
@@ -290,14 +144,7 @@ def _ListFilesResponse_from_mldev(
290
144
  setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken']))
291
145
 
292
146
  if getv(from_object, ['files']) is not None:
293
- setv(
294
- to_object,
295
- ['files'],
296
- [
297
- _File_from_mldev(item, to_object)
298
- for item in getv(from_object, ['files'])
299
- ],
300
- )
147
+ setv(to_object, ['files'], [item for item in getv(from_object, ['files'])])
301
148
 
302
149
  return to_object
303
150
 
@@ -359,7 +206,7 @@ class Files(_api_module.BaseModule):
359
206
 
360
207
  response = self._api_client.request('get', path, request_dict, http_options)
361
208
 
362
- response_dict = '' if not response.body else json.loads(response.body)
209
+ response_dict = {} if not response.body else json.loads(response.body)
363
210
 
364
211
  if not self._api_client.vertexai:
365
212
  response_dict = _ListFilesResponse_from_mldev(response_dict)
@@ -424,7 +271,7 @@ class Files(_api_module.BaseModule):
424
271
  self._api_client._verify_response(return_value)
425
272
  return return_value
426
273
 
427
- response_dict = '' if not response.body else json.loads(response.body)
274
+ response_dict = {} if not response.body else json.loads(response.body)
428
275
 
429
276
  if not self._api_client.vertexai:
430
277
  response_dict = _CreateFileResponse_from_mldev(response_dict)
@@ -492,10 +339,7 @@ class Files(_api_module.BaseModule):
492
339
 
493
340
  response = self._api_client.request('get', path, request_dict, http_options)
494
341
 
495
- response_dict = '' if not response.body else json.loads(response.body)
496
-
497
- if not self._api_client.vertexai:
498
- response_dict = _File_from_mldev(response_dict)
342
+ response_dict = {} if not response.body else json.loads(response.body)
499
343
 
500
344
  return_value = types.File._from_response(
501
345
  response=response_dict, kwargs=parameter_model.model_dump()
@@ -561,7 +405,7 @@ class Files(_api_module.BaseModule):
561
405
  'delete', path, request_dict, http_options
562
406
  )
563
407
 
564
- response_dict = '' if not response.body else json.loads(response.body)
408
+ response_dict = {} if not response.body else json.loads(response.body)
565
409
 
566
410
  if not self._api_client.vertexai:
567
411
  response_dict = _DeleteFileResponse_from_mldev(response_dict)
@@ -611,54 +455,13 @@ class Files(_api_module.BaseModule):
611
455
  if file_obj.name is not None and not file_obj.name.startswith('files/'):
612
456
  file_obj.name = f'files/{file_obj.name}'
613
457
 
614
- if isinstance(file, io.IOBase):
615
- if file_obj.mime_type is None:
616
- raise ValueError(
617
- 'Unknown mime type: Could not determine the mimetype for your'
618
- ' file\n please set the `mime_type` argument'
619
- )
620
- if hasattr(file, 'mode'):
621
- if 'b' not in file.mode:
622
- raise ValueError('The file must be opened in binary mode.')
623
- offset = file.tell()
624
- file.seek(0, os.SEEK_END)
625
- file_obj.size_bytes = file.tell() - offset
626
- file.seek(offset, os.SEEK_SET)
627
- else:
628
- fs_path = os.fspath(file)
629
- if not fs_path or not os.path.isfile(fs_path):
630
- raise FileNotFoundError(f'{file} is not a valid file path.')
631
- file_obj.size_bytes = os.path.getsize(fs_path)
632
- if file_obj.mime_type is None:
633
- file_obj.mime_type, _ = mimetypes.guess_type(fs_path)
634
- if file_obj.mime_type is None:
635
- raise ValueError(
636
- 'Unknown mime type: Could not determine the mimetype for your'
637
- ' file\n please set the `mime_type` argument'
638
- )
639
-
640
- http_options: types.HttpOptions
641
- if config_model and config_model.http_options:
642
- http_options = config_model.http_options
643
- http_options.api_version = ''
644
- http_options.headers = {
645
- 'Content-Type': 'application/json',
646
- 'X-Goog-Upload-Protocol': 'resumable',
647
- 'X-Goog-Upload-Command': 'start',
648
- 'X-Goog-Upload-Header-Content-Length': f'{file_obj.size_bytes}',
649
- 'X-Goog-Upload-Header-Content-Type': f'{file_obj.mime_type}',
650
- }
651
- else:
652
- http_options = types.HttpOptions(
653
- api_version='',
654
- headers={
655
- 'Content-Type': 'application/json',
656
- 'X-Goog-Upload-Protocol': 'resumable',
657
- 'X-Goog-Upload-Command': 'start',
658
- 'X-Goog-Upload-Header-Content-Length': f'{file_obj.size_bytes}',
659
- 'X-Goog-Upload-Header-Content-Type': f'{file_obj.mime_type}',
660
- },
661
- )
458
+ http_options, size_bytes, mime_type = _extra_utils.prepare_resumable_upload(
459
+ file,
460
+ user_http_options=config_model.http_options,
461
+ user_mime_type=config_model.mime_type,
462
+ )
463
+ file_obj.size_bytes = size_bytes
464
+ file_obj.mime_type = mime_type
662
465
  response = self._create(
663
466
  file=file_obj,
664
467
  config=types.CreateFileConfig(
@@ -682,12 +485,13 @@ class Files(_api_module.BaseModule):
682
485
  file, upload_url, file_obj.size_bytes, http_options=http_options
683
486
  )
684
487
  else:
488
+ fs_path = os.fspath(file)
685
489
  return_file = self._api_client.upload_file(
686
490
  fs_path, upload_url, file_obj.size_bytes, http_options=http_options
687
491
  )
688
492
 
689
493
  return types.File._from_response(
690
- response=_File_from_mldev(return_file.json['file']),
494
+ response=return_file.json['file'],
691
495
  kwargs=config_model.model_dump() if config else {},
692
496
  )
693
497
 
@@ -842,7 +646,7 @@ class AsyncFiles(_api_module.BaseModule):
842
646
  'get', path, request_dict, http_options
843
647
  )
844
648
 
845
- response_dict = '' if not response.body else json.loads(response.body)
649
+ response_dict = {} if not response.body else json.loads(response.body)
846
650
 
847
651
  if not self._api_client.vertexai:
848
652
  response_dict = _ListFilesResponse_from_mldev(response_dict)
@@ -907,7 +711,7 @@ class AsyncFiles(_api_module.BaseModule):
907
711
  self._api_client._verify_response(return_value)
908
712
  return return_value
909
713
 
910
- response_dict = '' if not response.body else json.loads(response.body)
714
+ response_dict = {} if not response.body else json.loads(response.body)
911
715
 
912
716
  if not self._api_client.vertexai:
913
717
  response_dict = _CreateFileResponse_from_mldev(response_dict)
@@ -977,10 +781,7 @@ class AsyncFiles(_api_module.BaseModule):
977
781
  'get', path, request_dict, http_options
978
782
  )
979
783
 
980
- response_dict = '' if not response.body else json.loads(response.body)
981
-
982
- if not self._api_client.vertexai:
983
- response_dict = _File_from_mldev(response_dict)
784
+ response_dict = {} if not response.body else json.loads(response.body)
984
785
 
985
786
  return_value = types.File._from_response(
986
787
  response=response_dict, kwargs=parameter_model.model_dump()
@@ -1046,7 +847,7 @@ class AsyncFiles(_api_module.BaseModule):
1046
847
  'delete', path, request_dict, http_options
1047
848
  )
1048
849
 
1049
- response_dict = '' if not response.body else json.loads(response.body)
850
+ response_dict = {} if not response.body else json.loads(response.body)
1050
851
 
1051
852
  if not self._api_client.vertexai:
1052
853
  response_dict = _DeleteFileResponse_from_mldev(response_dict)
@@ -1096,54 +897,13 @@ class AsyncFiles(_api_module.BaseModule):
1096
897
  if file_obj.name is not None and not file_obj.name.startswith('files/'):
1097
898
  file_obj.name = f'files/{file_obj.name}'
1098
899
 
1099
- if isinstance(file, io.IOBase):
1100
- if file_obj.mime_type is None:
1101
- raise ValueError(
1102
- 'Unknown mime type: Could not determine the mimetype for your'
1103
- ' file\n please set the `mime_type` argument'
1104
- )
1105
- if hasattr(file, 'mode'):
1106
- if 'b' not in file.mode:
1107
- raise ValueError('The file must be opened in binary mode.')
1108
- offset = file.tell()
1109
- file.seek(0, os.SEEK_END)
1110
- file_obj.size_bytes = file.tell() - offset
1111
- file.seek(offset, os.SEEK_SET)
1112
- else:
1113
- fs_path = os.fspath(file)
1114
- if not fs_path or not os.path.isfile(fs_path):
1115
- raise FileNotFoundError(f'{file} is not a valid file path.')
1116
- file_obj.size_bytes = os.path.getsize(fs_path)
1117
- if file_obj.mime_type is None:
1118
- file_obj.mime_type, _ = mimetypes.guess_type(fs_path)
1119
- if file_obj.mime_type is None:
1120
- raise ValueError(
1121
- 'Unknown mime type: Could not determine the mimetype for your'
1122
- ' file\n please set the `mime_type` argument'
1123
- )
1124
-
1125
- http_options: types.HttpOptions
1126
- if config_model and config_model.http_options:
1127
- http_options = config_model.http_options
1128
- http_options.api_version = ''
1129
- http_options.headers = {
1130
- 'Content-Type': 'application/json',
1131
- 'X-Goog-Upload-Protocol': 'resumable',
1132
- 'X-Goog-Upload-Command': 'start',
1133
- 'X-Goog-Upload-Header-Content-Length': f'{file_obj.size_bytes}',
1134
- 'X-Goog-Upload-Header-Content-Type': f'{file_obj.mime_type}',
1135
- }
1136
- else:
1137
- http_options = types.HttpOptions(
1138
- api_version='',
1139
- headers={
1140
- 'Content-Type': 'application/json',
1141
- 'X-Goog-Upload-Protocol': 'resumable',
1142
- 'X-Goog-Upload-Command': 'start',
1143
- 'X-Goog-Upload-Header-Content-Length': f'{file_obj.size_bytes}',
1144
- 'X-Goog-Upload-Header-Content-Type': f'{file_obj.mime_type}',
1145
- },
1146
- )
900
+ http_options, size_bytes, mime_type = _extra_utils.prepare_resumable_upload(
901
+ file,
902
+ user_http_options=config_model.http_options,
903
+ user_mime_type=config_model.mime_type,
904
+ )
905
+ file_obj.size_bytes = size_bytes
906
+ file_obj.mime_type = mime_type
1147
907
  response = await self._create(
1148
908
  file=file_obj,
1149
909
  config=types.CreateFileConfig(
@@ -1172,12 +932,13 @@ class AsyncFiles(_api_module.BaseModule):
1172
932
  file, upload_url, file_obj.size_bytes, http_options=http_options
1173
933
  )
1174
934
  else:
935
+ fs_path = os.fspath(file)
1175
936
  return_file = await self._api_client.async_upload_file(
1176
937
  fs_path, upload_url, file_obj.size_bytes, http_options=http_options
1177
938
  )
1178
939
 
1179
940
  return types.File._from_response(
1180
- response=_File_from_mldev(return_file.json['file']),
941
+ response=return_file.json['file'],
1181
942
  kwargs=config_model.model_dump() if config else {},
1182
943
  )
1183
944
 
google/genai/live.py CHANGED
@@ -21,7 +21,7 @@ import contextlib
21
21
  import json
22
22
  import logging
23
23
  import typing
24
- from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, get_args
24
+ from typing import Any, AsyncIterator, Optional, Sequence, Union, get_args
25
25
  import warnings
26
26
 
27
27
  import google.auth
@@ -40,7 +40,6 @@ from ._common import get_value_by_path as getv
40
40
  from ._common import set_value_by_path as setv
41
41
  from .live_music import AsyncLiveMusic
42
42
  from .models import _Content_to_mldev
43
- from .models import _Content_to_vertex
44
43
 
45
44
 
46
45
  try:
@@ -223,8 +222,8 @@ class AsyncSession:
223
222
  )
224
223
 
225
224
  if self._api_client.vertexai:
226
- client_content_dict = live_converters._LiveClientContent_to_vertex(
227
- from_object=client_content
225
+ client_content_dict = _common.convert_to_dict(
226
+ client_content, convert_keys=True
228
227
  )
229
228
  else:
230
229
  client_content_dict = live_converters._LiveClientContent_to_mldev(
@@ -410,12 +409,12 @@ class AsyncSession:
410
409
  """
411
410
  tool_response = t.t_tool_response(function_responses)
412
411
  if self._api_client.vertexai:
413
- tool_response_dict = live_converters._LiveClientToolResponse_to_vertex(
414
- from_object=tool_response
412
+ tool_response_dict = _common.convert_to_dict(
413
+ tool_response, convert_keys=True
415
414
  )
416
415
  else:
417
- tool_response_dict = live_converters._LiveClientToolResponse_to_mldev(
418
- from_object=tool_response
416
+ tool_response_dict = _common.convert_to_dict(
417
+ tool_response, convert_keys=True
419
418
  )
420
419
  for response in tool_response_dict.get('functionResponses', []):
421
420
  if response.get('id') is None:
@@ -541,7 +540,7 @@ class AsyncSession:
541
540
  if self._api_client.vertexai:
542
541
  response_dict = live_converters._LiveServerMessage_from_vertex(response)
543
542
  else:
544
- response_dict = live_converters._LiveServerMessage_from_mldev(response)
543
+ response_dict = response
545
544
 
546
545
  return types.LiveServerMessage._from_response(
547
546
  response=response_dict, kwargs=parameter_model.model_dump()
@@ -655,7 +654,7 @@ class AsyncSession:
655
654
  content_input_parts.append(item)
656
655
  if self._api_client.vertexai:
657
656
  contents = [
658
- _Content_to_vertex(item, to_object)
657
+ _common.convert_to_dict(item, convert_keys=True)
659
658
  for item in t.t_contents(content_input_parts)
660
659
  ]
661
660
  else:
@@ -1074,7 +1073,7 @@ class AsyncLive(_api_module.BaseModule):
1074
1073
  if self._api_client.vertexai:
1075
1074
  response_dict = live_converters._LiveServerMessage_from_vertex(response)
1076
1075
  else:
1077
- response_dict = live_converters._LiveServerMessage_from_mldev(response)
1076
+ response_dict = response
1078
1077
 
1079
1078
  setup_response = types.LiveServerMessage._from_response(
1080
1079
  response=response_dict, kwargs=parameter_model.model_dump()
google/genai/live_music.py CHANGED
@@ -22,13 +22,11 @@ from typing import AsyncIterator
22
22
 
23
23
  from . import _api_module
24
24
  from . import _common
25
+ from . import _live_converters as live_converters
25
26
  from . import _transformers as t
26
27
  from . import types
27
28
  from ._api_client import BaseApiClient
28
29
  from ._common import set_value_by_path as setv
29
- from . import _live_converters as live_converters
30
- from .models import _Content_to_mldev
31
- from .models import _Content_to_vertex
32
30
 
33
31
 
34
32
  try:
@@ -44,46 +42,47 @@ logger = logging.getLogger('google_genai.live_music')
44
42
  class AsyncMusicSession:
45
43
  """[Experimental] AsyncMusicSession."""
46
44
 
47
- def __init__(
48
- self, api_client: BaseApiClient, websocket: ClientConnection
49
- ):
45
+ def __init__(self, api_client: BaseApiClient, websocket: ClientConnection):
50
46
  self._api_client = api_client
51
47
  self._ws = websocket
52
48
 
53
49
  async def set_weighted_prompts(
54
- self,
55
- prompts: list[types.WeightedPrompt]
50
+ self, prompts: list[types.WeightedPrompt]
56
51
  ) -> None:
57
52
  if self._api_client.vertexai:
58
- raise NotImplementedError('Live music generation is not supported in Vertex AI.')
59
- else:
60
- client_content_dict = live_converters._LiveMusicClientContent_to_mldev(
61
- from_object={'weighted_prompts': prompts}
53
+ raise NotImplementedError(
54
+ 'Live music generation is not supported in Vertex AI.'
62
55
  )
56
+ else:
57
+ client_content_dict = {
58
+ 'weightedPrompts': [
59
+ _common.convert_to_dict(prompt, convert_keys=True)
60
+ for prompt in prompts
61
+ ]
62
+ }
63
+
63
64
  await self._ws.send(json.dumps({'clientContent': client_content_dict}))
64
65
 
65
66
  async def set_music_generation_config(
66
- self,
67
- config: types.LiveMusicGenerationConfig
67
+ self, config: types.LiveMusicGenerationConfig
68
68
  ) -> None:
69
69
  if self._api_client.vertexai:
70
- raise NotImplementedError('Live music generation is not supported in Vertex AI.')
71
- else:
72
- config_dict = live_converters._LiveMusicGenerationConfig_to_mldev(
73
- from_object=config
70
+ raise NotImplementedError(
71
+ 'Live music generation is not supported in Vertex AI.'
74
72
  )
73
+ else:
74
+ config_dict = _common.convert_to_dict(config, convert_keys=True)
75
75
  await self._ws.send(json.dumps({'musicGenerationConfig': config_dict}))
76
76
 
77
77
  async def _send_control_signal(
78
- self,
79
- playback_control: types.LiveMusicPlaybackControl
78
+ self, playback_control: types.LiveMusicPlaybackControl
80
79
  ) -> None:
81
80
  if self._api_client.vertexai:
82
- raise NotImplementedError('Live music generation is not supported in Vertex AI.')
83
- else:
84
- playback_control_dict = live_converters._LiveMusicClientMessage_to_mldev(
85
- from_object={'playback_control': playback_control}
81
+ raise NotImplementedError(
82
+ 'Live music generation is not supported in Vertex AI.'
86
83
  )
84
+ else:
85
+ playback_control_dict = {'playbackControl': playback_control.value}
87
86
  await self._ws.send(json.dumps(playback_control_dict))
88
87
 
89
88
  async def play(self) -> None:
@@ -134,9 +133,7 @@ class AsyncMusicSession:
134
133
  if self._api_client.vertexai:
135
134
  raise NotImplementedError('Live music generation is not supported in Vertex AI.')
136
135
  else:
137
- response_dict = live_converters._LiveMusicServerMessage_from_mldev(
138
- response
139
- )
136
+ response_dict = response
140
137
 
141
138
  return types.LiveMusicServerMessage._from_response(
142
139
  response=response_dict, kwargs=parameter_model.model_dump()