google-genai 1.56.0__py3-none-any.whl → 1.58.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (246)
  1. google/genai/_api_client.py +49 -26
  2. google/genai/_interactions/__init__.py +3 -0
  3. google/genai/_interactions/_base_client.py +1 -1
  4. google/genai/_interactions/_client.py +57 -3
  5. google/genai/_interactions/_client_adapter.py +48 -0
  6. google/genai/_interactions/types/__init__.py +6 -0
  7. google/genai/_interactions/types/audio_content.py +2 -0
  8. google/genai/_interactions/types/audio_content_param.py +2 -0
  9. google/genai/_interactions/types/content.py +65 -0
  10. google/genai/_interactions/types/content_delta.py +10 -2
  11. google/genai/_interactions/types/content_param.py +63 -0
  12. google/genai/_interactions/types/content_start.py +5 -46
  13. google/genai/_interactions/types/content_stop.py +1 -2
  14. google/genai/_interactions/types/document_content.py +2 -0
  15. google/genai/_interactions/types/document_content_param.py +2 -0
  16. google/genai/_interactions/types/error_event.py +1 -2
  17. google/genai/_interactions/types/file_search_call_content.py +32 -0
  18. google/genai/_interactions/types/file_search_call_content_param.py +31 -0
  19. google/genai/_interactions/types/generation_config.py +4 -0
  20. google/genai/_interactions/types/generation_config_param.py +4 -0
  21. google/genai/_interactions/types/image_config.py +31 -0
  22. google/genai/_interactions/types/image_config_param.py +30 -0
  23. google/genai/_interactions/types/image_content.py +2 -0
  24. google/genai/_interactions/types/image_content_param.py +2 -0
  25. google/genai/_interactions/types/interaction.py +6 -52
  26. google/genai/_interactions/types/interaction_create_params.py +4 -22
  27. google/genai/_interactions/types/interaction_event.py +1 -2
  28. google/genai/_interactions/types/interaction_sse_event.py +5 -3
  29. google/genai/_interactions/types/interaction_status_update.py +1 -2
  30. google/genai/_interactions/types/model.py +1 -0
  31. google/genai/_interactions/types/model_param.py +1 -0
  32. google/genai/_interactions/types/turn.py +3 -44
  33. google/genai/_interactions/types/turn_param.py +4 -40
  34. google/genai/_interactions/types/usage.py +1 -1
  35. google/genai/_interactions/types/usage_param.py +1 -1
  36. google/genai/_interactions/types/video_content.py +2 -0
  37. google/genai/_interactions/types/video_content_param.py +2 -0
  38. google/genai/_live_converters.py +118 -34
  39. google/genai/_local_tokenizer_loader.py +1 -0
  40. google/genai/_tokens_converters.py +14 -14
  41. google/genai/_transformers.py +15 -21
  42. google/genai/batches.py +27 -22
  43. google/genai/caches.py +42 -42
  44. google/genai/chats.py +0 -2
  45. google/genai/client.py +61 -55
  46. google/genai/files.py +224 -0
  47. google/genai/live.py +1 -1
  48. google/genai/models.py +56 -44
  49. google/genai/tests/__init__.py +21 -0
  50. google/genai/tests/afc/__init__.py +21 -0
  51. google/genai/tests/afc/test_convert_if_exist_pydantic_model.py +309 -0
  52. google/genai/tests/afc/test_convert_number_values_for_function_call_args.py +63 -0
  53. google/genai/tests/afc/test_find_afc_incompatible_tool_indexes.py +240 -0
  54. google/genai/tests/afc/test_generate_content_stream_afc.py +530 -0
  55. google/genai/tests/afc/test_generate_content_stream_afc_thoughts.py +77 -0
  56. google/genai/tests/afc/test_get_function_map.py +176 -0
  57. google/genai/tests/afc/test_get_function_response_parts.py +277 -0
  58. google/genai/tests/afc/test_get_max_remote_calls_for_afc.py +130 -0
  59. google/genai/tests/afc/test_invoke_function_from_dict_args.py +241 -0
  60. google/genai/tests/afc/test_raise_error_for_afc_incompatible_config.py +159 -0
  61. google/genai/tests/afc/test_should_append_afc_history.py +53 -0
  62. google/genai/tests/afc/test_should_disable_afc.py +214 -0
  63. google/genai/tests/batches/__init__.py +17 -0
  64. google/genai/tests/batches/test_cancel.py +77 -0
  65. google/genai/tests/batches/test_create.py +78 -0
  66. google/genai/tests/batches/test_create_with_bigquery.py +113 -0
  67. google/genai/tests/batches/test_create_with_file.py +82 -0
  68. google/genai/tests/batches/test_create_with_gcs.py +125 -0
  69. google/genai/tests/batches/test_create_with_inlined_requests.py +255 -0
  70. google/genai/tests/batches/test_delete.py +86 -0
  71. google/genai/tests/batches/test_embedding.py +157 -0
  72. google/genai/tests/batches/test_get.py +78 -0
  73. google/genai/tests/batches/test_list.py +79 -0
  74. google/genai/tests/caches/__init__.py +17 -0
  75. google/genai/tests/caches/constants.py +29 -0
  76. google/genai/tests/caches/test_create.py +210 -0
  77. google/genai/tests/caches/test_create_custom_url.py +105 -0
  78. google/genai/tests/caches/test_delete.py +54 -0
  79. google/genai/tests/caches/test_delete_custom_url.py +52 -0
  80. google/genai/tests/caches/test_get.py +94 -0
  81. google/genai/tests/caches/test_get_custom_url.py +52 -0
  82. google/genai/tests/caches/test_list.py +68 -0
  83. google/genai/tests/caches/test_update.py +70 -0
  84. google/genai/tests/caches/test_update_custom_url.py +58 -0
  85. google/genai/tests/chats/__init__.py +1 -0
  86. google/genai/tests/chats/test_get_history.py +598 -0
  87. google/genai/tests/chats/test_send_message.py +844 -0
  88. google/genai/tests/chats/test_validate_response.py +90 -0
  89. google/genai/tests/client/__init__.py +17 -0
  90. google/genai/tests/client/test_async_stream.py +427 -0
  91. google/genai/tests/client/test_client_close.py +197 -0
  92. google/genai/tests/client/test_client_initialization.py +1687 -0
  93. google/genai/tests/client/test_client_requests.py +221 -0
  94. google/genai/tests/client/test_custom_client.py +104 -0
  95. google/genai/tests/client/test_http_options.py +178 -0
  96. google/genai/tests/client/test_replay_client_equality.py +168 -0
  97. google/genai/tests/client/test_retries.py +846 -0
  98. google/genai/tests/client/test_upload_errors.py +136 -0
  99. google/genai/tests/common/__init__.py +17 -0
  100. google/genai/tests/common/test_common.py +954 -0
  101. google/genai/tests/conftest.py +162 -0
  102. google/genai/tests/documents/__init__.py +17 -0
  103. google/genai/tests/documents/test_delete.py +51 -0
  104. google/genai/tests/documents/test_get.py +85 -0
  105. google/genai/tests/documents/test_list.py +72 -0
  106. google/genai/tests/errors/__init__.py +1 -0
  107. google/genai/tests/errors/test_api_error.py +417 -0
  108. google/genai/tests/file_search_stores/__init__.py +17 -0
  109. google/genai/tests/file_search_stores/test_create.py +66 -0
  110. google/genai/tests/file_search_stores/test_delete.py +64 -0
  111. google/genai/tests/file_search_stores/test_get.py +94 -0
  112. google/genai/tests/file_search_stores/test_import_file.py +112 -0
  113. google/genai/tests/file_search_stores/test_list.py +57 -0
  114. google/genai/tests/file_search_stores/test_upload_to_file_search_store.py +141 -0
  115. google/genai/tests/files/__init__.py +17 -0
  116. google/genai/tests/files/test_delete.py +46 -0
  117. google/genai/tests/files/test_download.py +85 -0
  118. google/genai/tests/files/test_get.py +46 -0
  119. google/genai/tests/files/test_list.py +72 -0
  120. google/genai/tests/files/test_register.py +272 -0
  121. google/genai/tests/files/test_register_table.py +70 -0
  122. google/genai/tests/files/test_upload.py +255 -0
  123. google/genai/tests/imports/test_no_optional_imports.py +28 -0
  124. google/genai/tests/interactions/test_auth.py +476 -0
  125. google/genai/tests/interactions/test_integration.py +84 -0
  126. google/genai/tests/interactions/test_paths.py +105 -0
  127. google/genai/tests/live/__init__.py +16 -0
  128. google/genai/tests/live/test_live.py +2143 -0
  129. google/genai/tests/live/test_live_music.py +362 -0
  130. google/genai/tests/live/test_live_response.py +163 -0
  131. google/genai/tests/live/test_send_client_content.py +147 -0
  132. google/genai/tests/live/test_send_realtime_input.py +268 -0
  133. google/genai/tests/live/test_send_tool_response.py +222 -0
  134. google/genai/tests/local_tokenizer/__init__.py +17 -0
  135. google/genai/tests/local_tokenizer/test_local_tokenizer.py +343 -0
  136. google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py +235 -0
  137. google/genai/tests/mcp/__init__.py +17 -0
  138. google/genai/tests/mcp/test_has_mcp_tool_usage.py +89 -0
  139. google/genai/tests/mcp/test_mcp_to_gemini_tools.py +191 -0
  140. google/genai/tests/mcp/test_parse_config_for_mcp_sessions.py +201 -0
  141. google/genai/tests/mcp/test_parse_config_for_mcp_usage.py +130 -0
  142. google/genai/tests/mcp/test_set_mcp_usage_header.py +72 -0
  143. google/genai/tests/models/__init__.py +17 -0
  144. google/genai/tests/models/constants.py +8 -0
  145. google/genai/tests/models/test_compute_tokens.py +120 -0
  146. google/genai/tests/models/test_count_tokens.py +159 -0
  147. google/genai/tests/models/test_delete.py +107 -0
  148. google/genai/tests/models/test_edit_image.py +264 -0
  149. google/genai/tests/models/test_embed_content.py +94 -0
  150. google/genai/tests/models/test_function_call_streaming.py +442 -0
  151. google/genai/tests/models/test_generate_content.py +2501 -0
  152. google/genai/tests/models/test_generate_content_cached_content.py +132 -0
  153. google/genai/tests/models/test_generate_content_config_zero_value.py +103 -0
  154. google/genai/tests/models/test_generate_content_from_apikey.py +44 -0
  155. google/genai/tests/models/test_generate_content_http_options.py +40 -0
  156. google/genai/tests/models/test_generate_content_image_generation.py +143 -0
  157. google/genai/tests/models/test_generate_content_mcp.py +343 -0
  158. google/genai/tests/models/test_generate_content_media_resolution.py +97 -0
  159. google/genai/tests/models/test_generate_content_model.py +139 -0
  160. google/genai/tests/models/test_generate_content_part.py +821 -0
  161. google/genai/tests/models/test_generate_content_thought.py +76 -0
  162. google/genai/tests/models/test_generate_content_tools.py +1761 -0
  163. google/genai/tests/models/test_generate_images.py +191 -0
  164. google/genai/tests/models/test_generate_videos.py +759 -0
  165. google/genai/tests/models/test_get.py +104 -0
  166. google/genai/tests/models/test_list.py +233 -0
  167. google/genai/tests/models/test_recontext_image.py +189 -0
  168. google/genai/tests/models/test_segment_image.py +148 -0
  169. google/genai/tests/models/test_update.py +95 -0
  170. google/genai/tests/models/test_upscale_image.py +157 -0
  171. google/genai/tests/operations/__init__.py +17 -0
  172. google/genai/tests/operations/test_get.py +38 -0
  173. google/genai/tests/public_samples/__init__.py +17 -0
  174. google/genai/tests/public_samples/test_gemini_text_only.py +34 -0
  175. google/genai/tests/pytest_helper.py +246 -0
  176. google/genai/tests/shared/__init__.py +16 -0
  177. google/genai/tests/shared/batches/__init__.py +14 -0
  178. google/genai/tests/shared/batches/test_create_delete.py +57 -0
  179. google/genai/tests/shared/batches/test_create_get_cancel.py +56 -0
  180. google/genai/tests/shared/batches/test_list.py +40 -0
  181. google/genai/tests/shared/caches/__init__.py +14 -0
  182. google/genai/tests/shared/caches/test_create_get_delete.py +67 -0
  183. google/genai/tests/shared/caches/test_create_update_get.py +71 -0
  184. google/genai/tests/shared/caches/test_list.py +40 -0
  185. google/genai/tests/shared/chats/__init__.py +14 -0
  186. google/genai/tests/shared/chats/test_send_message.py +48 -0
  187. google/genai/tests/shared/chats/test_send_message_stream.py +50 -0
  188. google/genai/tests/shared/files/__init__.py +14 -0
  189. google/genai/tests/shared/files/test_list.py +41 -0
  190. google/genai/tests/shared/files/test_upload_get_delete.py +54 -0
  191. google/genai/tests/shared/models/__init__.py +14 -0
  192. google/genai/tests/shared/models/test_compute_tokens.py +41 -0
  193. google/genai/tests/shared/models/test_count_tokens.py +40 -0
  194. google/genai/tests/shared/models/test_edit_image.py +67 -0
  195. google/genai/tests/shared/models/test_embed.py +40 -0
  196. google/genai/tests/shared/models/test_generate_content.py +39 -0
  197. google/genai/tests/shared/models/test_generate_content_stream.py +54 -0
  198. google/genai/tests/shared/models/test_generate_images.py +40 -0
  199. google/genai/tests/shared/models/test_generate_videos.py +38 -0
  200. google/genai/tests/shared/models/test_list.py +37 -0
  201. google/genai/tests/shared/models/test_recontext_image.py +55 -0
  202. google/genai/tests/shared/models/test_segment_image.py +52 -0
  203. google/genai/tests/shared/models/test_upscale_image.py +52 -0
  204. google/genai/tests/shared/tunings/__init__.py +16 -0
  205. google/genai/tests/shared/tunings/test_create.py +46 -0
  206. google/genai/tests/shared/tunings/test_create_get_cancel.py +56 -0
  207. google/genai/tests/shared/tunings/test_list.py +39 -0
  208. google/genai/tests/tokens/__init__.py +16 -0
  209. google/genai/tests/tokens/test_create.py +154 -0
  210. google/genai/tests/transformers/__init__.py +17 -0
  211. google/genai/tests/transformers/test_blobs.py +84 -0
  212. google/genai/tests/transformers/test_bytes.py +15 -0
  213. google/genai/tests/transformers/test_duck_type.py +96 -0
  214. google/genai/tests/transformers/test_function_responses.py +72 -0
  215. google/genai/tests/transformers/test_schema.py +653 -0
  216. google/genai/tests/transformers/test_t_batch.py +286 -0
  217. google/genai/tests/transformers/test_t_content.py +160 -0
  218. google/genai/tests/transformers/test_t_contents.py +398 -0
  219. google/genai/tests/transformers/test_t_part.py +85 -0
  220. google/genai/tests/transformers/test_t_parts.py +87 -0
  221. google/genai/tests/transformers/test_t_tool.py +157 -0
  222. google/genai/tests/transformers/test_t_tools.py +195 -0
  223. google/genai/tests/tunings/__init__.py +16 -0
  224. google/genai/tests/tunings/test_cancel.py +39 -0
  225. google/genai/tests/tunings/test_end_to_end.py +106 -0
  226. google/genai/tests/tunings/test_get.py +67 -0
  227. google/genai/tests/tunings/test_list.py +75 -0
  228. google/genai/tests/tunings/test_tune.py +268 -0
  229. google/genai/tests/types/__init__.py +16 -0
  230. google/genai/tests/types/test_bytes_internal.py +271 -0
  231. google/genai/tests/types/test_bytes_type.py +152 -0
  232. google/genai/tests/types/test_future.py +101 -0
  233. google/genai/tests/types/test_optional_types.py +36 -0
  234. google/genai/tests/types/test_part_type.py +616 -0
  235. google/genai/tests/types/test_schema_from_json_schema.py +417 -0
  236. google/genai/tests/types/test_schema_json_schema.py +468 -0
  237. google/genai/tests/types/test_types.py +2903 -0
  238. google/genai/types.py +631 -488
  239. google/genai/version.py +1 -1
  240. {google_genai-1.56.0.dist-info → google_genai-1.58.0.dist-info}/METADATA +6 -11
  241. google_genai-1.58.0.dist-info/RECORD +358 -0
  242. google_genai-1.56.0.dist-info/RECORD +0 -162
  243. /google/genai/{_interactions/py.typed → tests/interactions/__init__.py} +0 -0
  244. {google_genai-1.56.0.dist-info → google_genai-1.58.0.dist-info}/WHEEL +0 -0
  245. {google_genai-1.56.0.dist-info → google_genai-1.58.0.dist-info}/licenses/LICENSE +0 -0
  246. {google_genai-1.56.0.dist-info → google_genai-1.58.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2143 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+
17
+ """Tests for live.py."""
18
+
19
+ import contextlib
20
+ import json
21
+ import os
22
+ import ssl
23
+ import typing
24
+ from typing import Any, AsyncIterator
25
+ from unittest import mock
26
+ from unittest.mock import AsyncMock
27
+ from unittest.mock import Mock
28
+ from unittest.mock import patch
29
+ import warnings
30
+
31
+ import certifi
32
+ from google.oauth2.credentials import Credentials
33
+ import pytest
34
+ from websockets import client
35
+
36
+ from .. import pytest_helper
37
+ from ... import _api_client as api_client
38
+ from ... import _common
39
+ from ... import Client
40
+ from ... import client as gl_client
41
+ from ... import live
42
+ from ... import types
43
+ try:
44
+ import aiohttp
45
+ AIOHTTP_NOT_INSTALLED = False
46
+ except ImportError:
47
+ AIOHTTP_NOT_INSTALLED = True
48
+ aiohttp = mock.MagicMock()
49
+
50
+
51
+ if typing.TYPE_CHECKING:
52
+ from mcp import types as mcp_types
53
+ from mcp import ClientSession as McpClientSession
54
+ else:
55
+ mcp_types: typing.Type = Any
56
+ McpClientSession: typing.Type = Any
57
+ try:
58
+ from mcp import types as mcp_types
59
+ from mcp import ClientSession as McpClientSession
60
+ except ImportError:
61
+ mcp_types = None
62
+ McpClientSession = None
63
+
64
+
65
# Decorator-style marker: skip the decorated test whenever the optional
# aiohttp dependency failed to import at module load time.
requires_aiohttp = pytest.mark.skipif(
    AIOHTTP_NOT_INSTALLED,
    reason="aiohttp is not installed, skipping test.",
)
68
+
69
# Plain-dict tool declaration for `get_current_weather`: an OBJECT whose
# properties are a free-text STRING `location` and a STRING `unit` limited
# to 'C' or 'F'.
function_declarations = [
    dict(
        name='get_current_weather',
        description='Get the current weather in a city',
        parameters=dict(
            type='OBJECT',
            properties=dict(
                location=dict(
                    type='STRING',
                    description='The location to get the weather for',
                ),
                unit=dict(type='STRING', enum=['C', 'F']),
            ),
        ),
    )
]
86
+
87
+
88
def get_current_weather(location: str, unit: str):
  """Get the current weather in a city."""
  # Canned stub: `location` is ignored. Celsius yields 15; any other unit
  # string yields 59.
  if unit == 'C':
    return 15
  return 59
91
+
92
+
93
def mock_api_client(vertexai=False, credentials=None, http_options=None):
  """Return a MagicMock standing in for `BaseApiClient` in live tests.

  Args:
    vertexai: When True, mimic a Vertex AI client (no API key). Otherwise
      mimic a Gemini Developer API client keyed with 'TEST_API_KEY'.
    credentials: Stored verbatim on the mock as `_credentials`.
    http_options: Only consulted when `vertexai` is True. Either a dict or a
      `types.HttpOptions`. Its `base_url` becomes `custom_base_url`, and
      project/location are left unset (None).

  Returns:
    The configured mock. `_http_options` is always an `HttpOptions` with an
    empty headers dict, regardless of `http_options`.
  """
  fake = mock.MagicMock(spec=gl_client.BaseApiClient)

  if vertexai:
    fake.api_key = None
    if http_options:
      if isinstance(http_options, dict):
        http_options = types.HttpOptions(**http_options)
      fake.custom_base_url = http_options.base_url
      fake.project = None
      fake.location = None
    else:
      fake.project = 'test_project'
      fake.location = 'us-central1'
      fake.custom_base_url = None
  else:
    fake.api_key = 'TEST_API_KEY'
    fake.project = None
    fake.location = None
    fake.custom_base_url = None

  fake.vertexai = vertexai
  fake._host = lambda: 'test_host'
  fake._credentials = credentials
  # Ensure a (empty) headers dict exists on the HTTP options.
  fake._http_options = types.HttpOptions.model_validate({'headers': {}})
  # The mock doubles as its own inner `_api_client`.
  fake._api_client = fake

  ssl_ctx = ssl.create_default_context(
      cafile=os.environ.get("SSL_CERT_FILE", certifi.where()),
      capath=os.environ.get("SSL_CERT_DIR"),
  )
  fake._websocket_ssl_ctx = {'ssl': ssl_ctx}
  return fake
129
+
130
+
131
@pytest.fixture
def mock_websocket():
  """Provide an AsyncMock shaped like a websockets `ClientConnection`.

  `send` and `close` are bare AsyncMocks. By default, every `recv()` call
  returns a server message flagging the turn as complete. Tests override
  `recv` when they need a different script.
  """
  ws = AsyncMock(spec=client.ClientConnection)
  ws.send = AsyncMock()
  ws.close = AsyncMock()
  ws.recv = AsyncMock(return_value='{"serverContent": {"turnComplete": true}}')
  return ws
140
+
141
+
142
async def get_connect_message(api_client, model, config=None):
  """Open a mocked live connection and return the decoded setup message.

  `ws_connect` is patched to yield a fake websocket, and `google.auth.default`
  is patched to return credentials with token 'test_token'. `AsyncLive.connect`
  is then entered and exited once.

  Args:
    api_client: The (usually mocked) API client handed to `live.AsyncLive`.
    model: Model name forwarded to `connect`.
    config: Optional connect config. Defaults to an empty dict.

  Returns:
    The JSON-decoded payload of the single `send()` made during setup.
  """
  config = {} if config is None else config

  fake_ws = AsyncMock()
  fake_ws.send = AsyncMock()
  fake_ws.recv = AsyncMock(
      return_value=(
          b'{\n "setupComplete": {"sessionId": "test_session_id"}\n}\n'
      )
  )

  fake_creds = Mock(token='test_token')
  fake_auth_default = Mock(return_value=(fake_creds, None))

  @contextlib.asynccontextmanager
  async def fake_connect(uri, additional_headers=None, **kwargs):
    yield fake_ws

  @patch('google.auth.default', new=fake_auth_default)
  @patch.object(live, 'ws_connect', new=fake_connect)
  async def _run_connect():
    live_module = live.AsyncLive(api_client)
    async with live_module.connect(model=model, config=config):
      pass

    fake_ws.send.assert_called_once()
    setup_frame = fake_ws.send.call_args[0][0]
    return json.loads(setup_frame)

  return await _run_connect()
175
+
176
+
177
async def _async_iterator_to_list(async_iter):
  """Drain an async iterable and return its items as a list, in order."""
  collected = []
  async for item in async_iter:
    collected.append(item)
  return collected
179
+
180
+
181
def test_mldev_from_env(monkeypatch):
  """A client configured only via GOOGLE_API_KEY targets the Developer API."""
  api_key = 'google_api_key'
  monkeypatch.setenv('GOOGLE_API_KEY', api_key)
  # Keep the test hermetic. An ambient GOOGLE_GENAI_USE_VERTEXAI=true in the
  # runner's environment would put the client into Vertex AI mode (see
  # test_vertex_from_env) and break the `not vertexai` assertion below.
  monkeypatch.delenv('GOOGLE_GENAI_USE_VERTEXAI', raising=False)

  client = Client()

  live_api_client = client.aio.live._api_client
  assert not live_api_client.vertexai
  assert live_api_client.api_key == api_key
  assert isinstance(live_api_client, api_client.BaseApiClient)
  assert live_api_client._http_options.headers['x-goog-api-key'] == api_key
191
+
192
+
193
@requires_aiohttp
def test_vertex_from_env(monkeypatch):
  """Vertex AI mode, project and location are all picked up from env vars."""
  project_id = 'fake_project_id'
  location = 'fake-location'
  for env_name, env_value in (
      ('GOOGLE_GENAI_USE_VERTEXAI', 'true'),
      ('GOOGLE_CLOUD_PROJECT', project_id),
      ('GOOGLE_CLOUD_LOCATION', location),
  ):
    monkeypatch.setenv(env_name, env_value)

  client = Client()

  live_api_client = client.aio.live._api_client
  assert live_api_client.custom_base_url is None
  assert live_api_client.vertexai
  assert live_api_client.project == project_id
  assert isinstance(live_api_client, api_client.BaseApiClient)
  assert 'x-goog-api-key' not in live_api_client._http_options.headers
208
+
209
+
210
def test_vertex_api_key_from_env(monkeypatch):
  """Vertex AI mode can be authenticated with an API key from the env."""
  api_key = 'google_api_key'
  monkeypatch.setenv('GOOGLE_GENAI_USE_VERTEXAI', 'true')
  monkeypatch.setenv('GOOGLE_API_KEY', api_key)

  # Project/location take precedence over the API key, so blank them out.
  # Precedence rules themselves are covered exhaustively in
  # client/test_client_initialization.py.
  for env_name in ("GOOGLE_CLOUD_LOCATION", "GOOGLE_CLOUD_PROJECT"):
    monkeypatch.setenv(env_name, "")

  client = Client()

  live_api_client = client.aio.live._api_client
  assert live_api_client.vertexai
  assert live_api_client.api_key == api_key
  assert isinstance(live_api_client, api_client.BaseApiClient)
  assert live_api_client._http_options.headers['x-goog-api-key'] == api_key
227
+
228
+
229
def test_websocket_base_url():
  """An https:// base URL is rewritten to the matching wss:// URL."""
  keyed_client = gl_client.BaseApiClient(
      api_key='google_api_key',
      http_options={'base_url': 'https://test.com'},
  )
  assert keyed_client._websocket_base_url() == 'wss://test.com'
236
+
237
+
238
def test_websocket_base_url_no_auth_with_custom_base_url():
  """With no API key and no project/location, the base URL passes through."""
  base_url = 'https://test-api-gateway-proxy.com'
  proxy_client = gl_client.BaseApiClient(
      vertexai=True,
      http_options={
          'base_url': base_url,
          'headers': {'Authorization': 'Bearer test_token'},
      },
  )
  # The test environment exports project/location, so clear both explicitly.
  proxy_client.project = proxy_client.location = None

  assert proxy_client._websocket_base_url() == base_url
254
+
255
+
256
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_text(mock_websocket, vertexai):
  """A bare string input is sent as a `client_content` message."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )

  await live_session.send(input='test')

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'client_content' in payload
266
+
267
+
268
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_content_dict(mock_websocket, vertexai):
  """A dict carrying content + turn_complete becomes `client_content`."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )

  await live_session.send(
      input={
          'content': [{'parts': [{'text': 'test'}]}],
          'turn_complete': True,
      }
  )

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'client_content' in payload
284
+
285
+
286
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_content(mock_websocket, vertexai):
  """A typed `LiveClientContent` is sent as a `client_content` message."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )
  single_turn = types.Content(parts=[types.Part(text='test')])

  await live_session.send(
      input=types.LiveClientContent(turns=[single_turn], turn_complete=True)
  )

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'client_content' in payload
301
+
302
+
303
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_bytes(mock_websocket, vertexai):
  """A raw data/mime_type dict is sent as a `realtime_input` message."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )

  await live_session.send(input={'data': b'000000', 'mime_type': 'audio/pcm'})

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'realtime_input' in payload
317
+
318
+
319
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_blob(mock_websocket, vertexai):
  """A typed `Blob` is sent as a `realtime_input` message."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )

  await live_session.send(
      input=types.Blob(data=b'000000', mime_type='audio/pcm')
  )

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'realtime_input' in payload
333
+
334
+
335
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_realtime_input(mock_websocket, vertexai):
  """A typed `LiveClientRealtimeInput` is sent as `realtime_input`."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )
  # 'MDAwMDAw' is base64 for b'000000'.
  chunk = types.Blob(data='MDAwMDAw', mime_type='audio/pcm')

  await live_session.send(
      input=types.LiveClientRealtimeInput(media_chunks=[chunk])
  )

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'realtime_input' in payload
350
+
351
+
352
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_tool_response(mock_websocket, vertexai):
  """A `LiveClientToolResponse` is sent as a `tool_response` message.

  Only the non-Vertex variant attaches a function-call `id`.
  """
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )
  id_kwargs = {} if vertexai else {'id': 'some-id'}
  weather_response = types.FunctionResponse(
      name='get_current_weather',
      response={'temperature': 14.5, 'unit': 'C'},
      **id_kwargs,
  )

  await live_session.send(
      input=types.LiveClientToolResponse(function_responses=[weather_response])
  )

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'tool_response' in payload
384
+
385
+
386
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_input_none(mock_websocket, vertexai):
  """`input=None` sends a `client_content` message that completes the turn."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )

  await live_session.send(input=None)

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'client_content' in payload
  client_content = payload['client_content']
  assert client_content['turn_complete']
399
+
400
+
401
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_error(mock_websocket, vertexai):
  """Unrecognized list or dict inputs are rejected with ValueError."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )
  bad_inputs = (
      [{'invalid_key': 'invalid_value'}],
      {'invalid_key': 'invalid_value'},
  )
  for bad_input in bad_inputs:
    with pytest.raises(ValueError):
      await live_session.send(input=bad_input)
414
+
415
+
416
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive(mock_websocket, vertexai):
  """`receive()` yields parsed `LiveServerMessage` objects."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )

  received = await _async_iterator_to_list(live_session.receive())

  first_message = received[0]
  assert isinstance(first_message, types.LiveServerMessage)
425
+
426
+
427
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_error(mock_websocket, vertexai):
  """A non-JSON frame from the socket surfaces as ValueError."""
  mock_websocket.recv = AsyncMock(return_value='invalid json')
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )
  # Creating the async generator runs none of its body; the first
  # __anext__() call does the recv() and parse.
  message_stream = live_session.receive()

  with pytest.raises(ValueError):
    await message_stream.__anext__()
438
+
439
+
440
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_text(mock_websocket, vertexai):
  """A text model turn followed by turnComplete is parsed frame by frame."""
  mock_websocket.recv = AsyncMock(
      side_effect=[
          '{"serverContent": {"modelTurn": {"parts":[{"text": "test"}]}}}',
          '{"serverContent": {"turnComplete": true}}',
      ]
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )

  messages = await _async_iterator_to_list(session.receive())

  # Exactly the two scripted frames come back, and receive() stops at
  # turnComplete. Any further recv() would hit the exhausted side_effect.
  assert len(messages) == 2
  assert all(isinstance(m, types.LiveServerMessage) for m in messages)
  assert messages[0].server_content.model_turn.parts[0].text == 'test'
  # Identity check instead of `== True` (E712); JSON `true` parses to bool.
  assert messages[1].server_content.turn_complete is True
459
+
460
+
461
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_audio(mock_websocket, vertexai):
  """Inline audio data in a model turn is base64-decoded into bytes."""
  audio_frame = (
      '{"serverContent": {"modelTurn": {"parts":[{"inlineData":'
      ' {"data": "MDAwMDAw", "mimeType": "audio/pcm" }}]}}}'
  )
  mock_websocket.recv = AsyncMock(
      side_effect=[audio_frame, '{"serverContent": {"turnComplete": true}}']
  )
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )

  received = await _async_iterator_to_list(live_session.receive())

  first_message = received[0]
  assert isinstance(first_message, types.LiveServerMessage)
  inline_blob = first_message.server_content.model_turn.parts[0].inline_data
  assert inline_blob.mime_type == 'audio/pcm'
  # "MDAwMDAw" is the base64 encoding of b'000000'.
  assert inline_blob.data == b'000000'

  # Both scripted recv() results are already consumed. Per the AsyncMock docs,
  # an exhausted side_effect raises StopAsyncIteration, which an async
  # generator re-raises as RuntimeError.
  with pytest.raises(RuntimeError):
    await _async_iterator_to_list(live_session.receive())
492
+
493
+
494
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_tool_call(mock_websocket, vertexai):
  """A toolCall frame is parsed into function calls with their arguments."""
  tool_call_frame = (
      '{"toolCall": {"functionCalls": [{"name":'
      ' "get_current_weather", "args": {"location": "San Francisco",'
      ' "unit": "C"}}]}}'
  )
  turn_complete_frame = '{"serverContent": {"turnComplete": true}}'
  mock_websocket.recv = AsyncMock(
      side_effect=[tool_call_frame, turn_complete_frame]
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  messages = await _async_iterator_to_list(session.receive())

  assert isinstance(messages[0], types.LiveServerMessage)
  call = messages[0].tool_call.function_calls[0]
  assert call.name == 'get_current_weather'
  assert call.args['location'] == 'San Francisco'
  assert call.args['unit'] == 'C'

  with pytest.raises(RuntimeError):
    await _async_iterator_to_list(session.receive())
524
+
525
+
526
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_transcription(mock_websocket, vertexai):
  """Input and output transcription frames are parsed, including `finished`."""
  mock_websocket.recv = AsyncMock(
      side_effect=[
          '{"serverContent": {"inputTranscription": {"text": "test_input", "finished": true}}}',
          '{"serverContent": {"outputTranscription": {"text": "test_output", "finished": false}}}',
          '{"serverContent": {"turnComplete": true}}',
      ]
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  messages = await _async_iterator_to_list(session.receive())

  assert isinstance(messages[0], types.LiveServerMessage)
  input_transcription = messages[0].server_content.input_transcription
  assert input_transcription.text == 'test_input'
  # Identity checks: `== True` / `== False` also accept 1 / 0.
  assert input_transcription.finished is True

  assert isinstance(messages[1], types.LiveServerMessage)
  output_transcription = messages[1].server_content.output_transcription
  assert output_transcription.text == 'test_output'
  assert output_transcription.finished is False
550
+
551
+
552
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_go_away(mock_websocket, vertexai):
  """A goAway frame is parsed into LiveServerGoAway with its time_left."""
  mock_websocket.recv = AsyncMock(
      side_effect=[
          '{"goAway": {"timeLeft": "10s"}}',
          '{"serverContent": {"turnComplete": true}}',
      ]
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  first, *_ = await _async_iterator_to_list(session.receive())

  assert isinstance(first, types.LiveServerMessage)
  assert first == types.LiveServerMessage(
      go_away=types.LiveServerGoAway(time_left='10s'),
  )
575
+
576
+
577
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_resumption_update(mock_websocket, vertexai):
  """A sessionResumptionUpdate frame is parsed into the typed message.

  The frame carries `resumable` and `lastConsumedClientMessageIndex` as JSON
  strings; the expected message uses bool / int values for them.
  """
  resumption_frame = """{
      "sessionResumptionUpdate": {
        "newHandle": "test_handle",
        "resumable": "true",
        "lastConsumedClientMessageIndex": "123456789"
      }
    }"""
  mock_websocket.recv = AsyncMock(
      side_effect=[
          resumption_frame,
          '{"serverContent": {"turnComplete": true}}',
      ]
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  received = await _async_iterator_to_list(session.receive())
  message = received[0]

  assert isinstance(message, types.LiveServerMessage)
  assert message == types.LiveServerMessage(
      session_resumption_update=types.LiveServerSessionResumptionUpdate(
          new_handle='test_handle',
          resumable=True,
          last_consumed_client_message_index=123456789,
      ),
  )
612
+
613
+
614
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_start_stream(mock_websocket, vertexai):
  """Every message yielded by start_stream is a LiveServerMessage.

  NOTE(review): if start_stream yields nothing with this mock websocket, the
  loop body never runs and the test passes vacuously; confirm against the
  `mock_websocket` fixture.
  """
  session = live.AsyncSession(
      mock_api_client(vertexai=vertexai), mock_websocket
  )

  async def audio_chunks():
    for chunk in (b'data1', b'data2'):
      yield chunk

  stream = session.start_stream(stream=audio_chunks(), mime_type='audio/pcm')
  async for message in stream:
    assert isinstance(message, types.LiveServerMessage)
632
+
633
+
634
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_vad_signal(mock_websocket, vertexai):
  """A voiceActivityDetectionSignal frame is parsed into its enum value."""
  # The trailing turnComplete frame ends the receive loop.
  mock_websocket.recv = mock.AsyncMock(
      side_effect=[
          '{"voiceActivityDetectionSignal": {"vadSignalType": "VAD_SIGNAL_TYPE_SOS"}}',
          '{"serverContent": {"turnComplete": true}}',
      ]
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  messages = await _async_iterator_to_list(session.receive())

  assert len(messages) > 0
  vad_message, last_message = messages[0], messages[-1]
  assert isinstance(vad_message, types.LiveServerMessage)
  signal = vad_message.voice_activity_detection_signal
  assert signal is not None
  assert signal.vad_signal_type == types.VadSignalType.VAD_SIGNAL_TYPE_SOS

  # The final message is the turnComplete frame that closed the loop.
  assert last_message.server_content.turn_complete is True
663
+
664
+
665
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_close(mock_websocket, vertexai):
  """AsyncSession.close() closes the underlying websocket exactly once."""
  api_client = mock_api_client(vertexai=vertexai)
  await live.AsyncSession(api_client, mock_websocket).close()
  mock_websocket.close.assert_called_once()
673
+
674
+
675
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_no_config(vertexai):
  """Without a config, setup carries only the model name and no warnings fire.

  On Vertex AI the expected message also contains an AUDIO response modality.
  """
  with warnings.catch_warnings():
    # Escalate any warning (e.g. one caused by default values) to an error.
    warnings.simplefilter('error')
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai), model='test_model'
    )

  if vertexai:
    setup = {
        'model': 'projects/test_project/locations/us-central1/publishers/google/models/test_model',
        'generationConfig': {'responseModalities': ['AUDIO']},
    }
  else:
    setup = {'model': 'models/test_model'}
  assert result == {'setup': setup}
693
+
694
+
695
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_speech_config(vertexai):
  """Speech, affective-dialog, proactivity and sampling options are mapped.

  The same expected setup must result whether the config is given as a plain
  dict or as a LiveConnectConfig. Results are compared after parsing both
  sides with LiveClientMessage._from_response.
  """
  if vertexai:
    model_name = (
        'projects/test_project/locations/us-central1/'
        'publishers/google/models/test_model'
    )
    speech_config = {
        'voiceConfig': {
            'prebuilt_voice_config': {'voice_name': 'en-default'}
        },
        'languageCode': 'en-US',
    }
  else:
    model_name = 'models/test_model'
    speech_config = {
        'voice_config': {
            'prebuilt_voice_config': {'voice_name': 'en-default'}
        },
        'language_code': 'en-US',
    }

  generation_config = {
      'speechConfig': speech_config,
      'enableAffectiveDialog': True,
      'temperature': 0.7,
      'topP': 0.8,
      'topK': 9.0,
      'maxOutputTokens': 10,
      'mediaResolution': 'MEDIA_RESOLUTION_MEDIUM',
      'seed': 13,
  }
  if vertexai:
    generation_config['responseModalities'] = ['AUDIO']

  expected_result = {
      'setup': {
          'model': model_name,
          'generationConfig': generation_config,
          'proactivity': {'proactive_audio': True},
          'systemInstruction': {
              'parts': [{'text': 'test instruction'}],
              'role': 'user',
          },
      }
  }
  expected_message = types.LiveClientMessage._from_response(
      response=expected_result, kwargs=None
  )

  config_as_dict = {
      'speech_config': {
          'voice_config': {
              'prebuilt_voice_config': {'voice_name': 'en-default'}
          },
          'language_code': 'en-US',
      },
      'enable_affective_dialog': True,
      'proactivity': {'proactive_audio': True},
      'temperature': 0.7,
      'top_p': 0.8,
      'top_k': 9,
      'max_output_tokens': 10,
      'seed': 13,
      'system_instruction': 'test instruction',
      'media_resolution': 'MEDIA_RESOLUTION_MEDIUM',
  }
  config_as_model = types.LiveConnectConfig(
      speech_config=types.SpeechConfig(
          voice_config=types.VoiceConfig(
              prebuilt_voice_config=types.PrebuiltVoiceConfig(
                  voice_name='en-default'
              )
          ),
          language_code='en-US',
      ),
      enable_affective_dialog=True,
      proactivity=types.ProactivityConfig(proactive_audio=True),
      temperature=0.7,
      top_p=0.8,
      top_k=9,
      max_output_tokens=10,
      media_resolution=types.MediaResolution.MEDIA_RESOLUTION_MEDIUM,
      seed=13,
      system_instruction='test instruction',
  )

  for config in (config_as_dict, config_as_model):
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai), model='test_model', config=config
    )
    actual_message = types.LiveClientMessage._from_response(
        response=result, kwargs=None
    )
    assert actual_message == expected_message
800
+
801
+
802
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_error_if_multispeaker_voice_config(vertexai):
  """A multi-speaker voice config is rejected with a ValueError naming it."""
  # Config is a dict
  config_dict = {
      'speech_config': {
          'multi_speaker_voice_config': {
              'speaker_voice_configs': [
                  {
                      'speaker': 'Alice',
                      'voice_config': {
                          'prebuilt_voice_config': {'voice_name': 'leda'}
                      },
                  },
                  {
                      'speaker': 'Bob',
                      'voice_config': {
                          'prebuilt_voice_config': {'voice_name': 'kore'}
                      },
                  },
              ],
          },
      },
      'temperature': 0.7,
      'top_p': 0.8,
      'top_k': 9,
      'max_output_tokens': 10,
      'seed': 13,
      'system_instruction': 'test instruction',
      'media_resolution': 'MEDIA_RESOLUTION_MEDIUM',
  }
  with pytest.raises(ValueError, match='.*multi_speaker_voice_config.*'):
    # The call is expected to raise, so its return value is not bound.
    await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=config_dict,
    )
840
+
841
+
842
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_explicit_vad(vertexai):
  """`explicit_vad_signal` is sent as `explicitVadSignal` on Vertex AI.

  For the non-Vertex client, `pytest_helper.exception_if_mldev` expects a
  ValueError, so there is no result to inspect.
  """
  # Config is a dict
  config_dict = {'explicit_vad_signal': True}
  with pytest_helper.exception_if_mldev(
      mock_api_client(vertexai=vertexai), ValueError
  ):
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai), model='test_model', config=config_dict
    )
  if not vertexai:
    return
  # Identity check: `== True` would also accept truthy non-bool values.
  assert result['setup']['explicitVadSignal'] is True
856
+
857
+
858
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_explicit_vad_config(vertexai):
  """`explicit_vad_signal` from a dict config becomes `explicitVadSignal`.

  For the non-Vertex client, `pytest_helper.exception_if_mldev` expects a
  ValueError, so there is no result to inspect.
  """
  api_client = mock_api_client(vertexai=vertexai)

  # Config is a dict
  config_dict = {'explicit_vad_signal': True}
  with pytest_helper.exception_if_mldev(api_client, ValueError):
    # Reuse the same client that decides whether an exception is expected.
    result = await get_connect_message(
        api_client,
        model='test_model',
        config=config_dict,
    )
  if not vertexai:
    return
  assert result['setup']['explicitVadSignal'] is True
874
+
875
+
876
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_system_instruction_as_content_type(
    vertexai,
):
  """A Content-shaped system_instruction is forwarded as systemInstruction."""
  config = types.LiveConnectConfig(
      system_instruction={
          'parts': [{'text': 'test instruction'}],
          'role': 'user',
      },
  )

  if vertexai:
    setup = {
        'model': 'projects/test_project/locations/us-central1/publishers/google/models/test_model',
        'generationConfig': {'responseModalities': ['AUDIO']},
    }
  else:
    setup = {'model': 'models/test_model'}
  setup['systemInstruction'] = {
      'parts': [{'text': 'test instruction'}],
      'role': 'user',
  }

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai), model='test_model', config=config
  )
  assert result == {'setup': setup}
913
+
914
+
915
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_google_search(vertexai):
  """A google_search tool serializes identically from a dict config and from
  the equivalent LiveConnectConfig."""
  config_dict = {
      'response_modalities': ['TEXT'],
      'system_instruction': 'test instruction',
      'generation_config': {'temperature': 0.7},
      'tools': [{'google_search': {}}],
  }
  config = types.LiveConnectConfig(**config_dict)

  model = (
      'projects/test_project/locations/us-central1/publishers/google/models/test_model'
      if vertexai
      else 'models/test_model'
  )
  expected_result = {
      'setup': {
          'model': model,
          'generationConfig': {
              'temperature': 0.7,
              'responseModalities': ['TEXT'],
          },
          'systemInstruction': {
              'parts': [{'text': 'test instruction'}],
              'role': 'user',
          },
          'tools': [{'googleSearch': {}}],
      }
  }

  # The plain dict first, then the LiveConnectConfig built from it.
  for candidate in (config_dict, config):
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai), model='test_model', config=candidate
    )
    assert result == expected_result
958
+
959
+
960
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_with_no_mcp(vertexai):
  """Tools serialize correctly when the optional MCP types are unavailable.

  `live.McpClientSession` and `live.McpTool` are patched to None. Both a dict
  config and a LiveConnectConfig are exercised.
  """
  config_dict = {
      'response_modalities': ['TEXT'],
      'system_instruction': 'test instruction',
      'generation_config': {'temperature': 0.7},
      'tools': [{'google_search': {}}],
  }

  config = types.LiveConnectConfig(**config_dict)
  expected_result = {
      'setup': {
          'generationConfig': {
              'temperature': 0.7,
              'responseModalities': ['TEXT'],
          },
          'systemInstruction': {
              'parts': [{'text': 'test instruction'}],
              'role': 'user',
          },
          'tools': [{'googleSearch': {}}],
      }
  }
  if vertexai:
    expected_result['setup']['model'] = 'projects/test_project/locations/us-central1/publishers/google/models/test_model'
  else:
    expected_result['setup']['model'] = 'models/test_model'

  @patch.object(live, "McpClientSession", new=None)
  @patch.object(live, "McpTool", new=None)
  async def get_connect_message_no_mcp(config):
    return await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model', config=config
    )

  # Config is a dict.
  result = await get_connect_message_no_mcp(config_dict)
  assert result == expected_result

  # Config is a LiveConnectConfig. Previously the dict was passed twice, so
  # `config` was never exercised.
  result = await get_connect_message_no_mcp(config)
  assert result == expected_result
1002
+
1003
+
1004
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_context_window_compression(vertexai):
  """Context-window compression settings are forwarded in the setup message."""
  compression = types.ContextWindowCompressionConfig(
      trigger_tokens=1000,
      sliding_window=types.SlidingWindow(target_tokens=10),
  )
  config = types.LiveConnectConfig(
      generation_config=types.GenerationConfig(temperature=0.7),
      response_modalities=['TEXT'],
      system_instruction=types.Content(
          parts=[types.Part(text='test instruction')], role='user'
      ),
      context_window_compression=compression,
  )

  if vertexai:
    model = 'projects/test_project/locations/us-central1/publishers/google/models/test_model'
  else:
    model = 'models/test_model'

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai), model='test_model', config=config
  )
  assert result == {
      'setup': {
          'model': model,
          'generationConfig': {
              'temperature': 0.7,
              'responseModalities': ['TEXT'],
          },
          'systemInstruction': {
              'parts': [{'text': 'test instruction'}],
              'role': 'user',
          },
          'contextWindowCompression': {
              'trigger_tokens': 1000,
              'sliding_window': {'target_tokens': 10},
          },
      }
  }
1046
+
1047
+
1048
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_function_declaration(
    vertexai,
):
  """Function declarations passed via `tools` keep their description.

  Only the `description` of the first declaration is checked.
  """
  config = types.LiveConnectConfig(
      generation_config={'temperature': 0.7},
      tools=[{'function_declarations': function_declarations}],
  )
  expected_description = 'Get the current weather in a city'

  # Converted twice with the same config, presumably to check that conversion
  # does not mutate `config`. TODO: confirm intent.
  for _ in range(2):
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai), model='test_model', config=config
    )
    declaration = result['setup']['tools'][0]['functionDeclarations'][0]
    assert declaration['description'] == expected_description
1106
+
1107
+
1108
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_function_directly(vertexai):
  """A Python callable in `tools` becomes a declaration from its docstring.

  Only the `description` of the first declaration is checked.
  """
  config = types.LiveConnectConfig(
      generation_config={'temperature': 0.7},
      tools=[get_current_weather],
  )
  expected_description = 'Get the current weather in a city.'

  # Converted twice with the same config, presumably to check that conversion
  # does not mutate `config`. TODO: confirm intent.
  for _ in range(2):
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai), model='test_model', config=config
    )
    declaration = result['setup']['tools'][0]['functionDeclarations'][0]
    assert declaration['description'] == expected_description
1166
+
1167
+
1168
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_tools_function_behavior(vertexai):
  """`behavior=NON_BLOCKING` is serialized on the non-Vertex client.

  For Vertex AI, `pytest_helper.exception_if_vertex` expects a ValueError.
  """
  api_client = mock_api_client(vertexai=vertexai)
  declaration = types.FunctionDeclaration.from_callable(
      client=api_client, callable=get_current_weather
  )
  declaration.behavior = types.Behavior.NON_BLOCKING
  config = types.LiveConnectConfig(
      generation_config={'temperature': 0.7},
      tools=[{'function_declarations': [declaration]}],
  )

  with pytest_helper.exception_if_vertex(api_client, ValueError):
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai), model='test_model', config=config
    )

  if not vertexai:
    serialized = result['setup']['tools'][0]['functionDeclarations'][0]
    assert serialized['behavior'] == 'NON_BLOCKING'
1194
+
1195
+
1196
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_mcp_tools(
    vertexai,
):
  """An `mcp.types.Tool` in `tools` is converted to a function declaration."""
  if mcp_types is None:
    # Report the missing optional dependency instead of silently passing.
    pytest.skip('mcp is not installed')

  expected_tools = [{
      'functionDeclarations': [{
          'parameters': {
              'type': 'OBJECT',
              'properties': {
                  'location': {
                      'type': 'STRING',
                  },
              },
          },
          'name': 'get_weather',
          'description': 'Get the weather in a city.',
      }],
  }]
  if vertexai:
    expected_result = {
        'setup': {
            'generationConfig': {
                'responseModalities': [
                    'AUDIO',
                ],
            },
            'model': (
                'projects/test_project/locations/us-central1/publishers/google/models/test_model'
            ),
            'tools': expected_tools,
        }
    }
  else:
    expected_result = {
        'setup': {
            'model': 'models/test_model',
            'tools': expected_tools,
        }
    }

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='test_model',
      config={
          'tools': [
              mcp_types.Tool(
                  name='get_weather',
                  description='Get the weather in a city.',
                  inputSchema={
                      'type': 'object',
                      'properties': {'location': {'type': 'string'}},
                  },
              )
          ],
      },
  )

  # The expected value is picked before comparing. `result == a if c else b`
  # parses as `(result == a) if c else b`, so the old non-Vertex branch
  # asserted a non-empty dict and could never fail.
  assert result == expected_result
1271
+
1272
+
1273
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_mcp_session(
    vertexai,
):
  """Tools listed by an MCP client session become function declarations."""
  if mcp_types is None:
    # Report the missing optional dependency instead of silently passing.
    pytest.skip('mcp is not installed')

  class MockMcpClientSession(McpClientSession):
    # Bypasses the real session setup; only list_tools() is exercised here.

    def __init__(self):
      self._read_stream = None
      self._write_stream = None

    async def list_tools(self):
      return mcp_types.ListToolsResult(
          tools=[
              mcp_types.Tool(
                  name='get_weather',
                  description='Get the weather in a city.',
                  inputSchema={
                      'type': 'object',
                      'properties': {'location': {'type': 'string'}},
                  },
              ),
          ]
      )

  expected_tools = [{
      'functionDeclarations': [{
          'parameters': {
              'type': 'OBJECT',
              'properties': {
                  'location': {
                      'type': 'STRING',
                  },
              },
          },
          'name': 'get_weather',
          'description': 'Get the weather in a city.',
      }],
  }]
  if vertexai:
    expected_result = {
        'setup': {
            'generationConfig': {
                'responseModalities': [
                    'AUDIO',
                ],
            },
            'model': (
                'projects/test_project/locations/us-central1/publishers/google/models/test_model'
            ),
            'tools': expected_tools,
        }
    }
  else:
    expected_result = {
        'setup': {
            'model': 'models/test_model',
            'tools': expected_tools,
        }
    }

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='test_model',
      config={
          'tools': [MockMcpClientSession()],
      },
  )

  # The expected value is picked before comparing. `result == a if c else b`
  # parses as `(result == a) if c else b`, so the old non-Vertex branch
  # asserted a non-empty dict and could never fail.
  assert result == expected_result
1359
+
1360
+
1361
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_code_execution(vertexai):
  """A code_execution tool is serialized as `codeExecution`."""
  config = types.LiveConnectConfig(tools=[{'code_execution': {}}])
  result = await get_connect_message(
      mock_api_client(vertexai=vertexai), model='test_model', config=config
  )
  first_tool = result['setup']['tools'][0]
  assert first_tool == {'codeExecution': {}}
1384
+
1385
+
1386
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_realtime_input_config(vertexai):
  """The realtime input config appears under `realtimeInputConfig`.

  It is compared against the original snake_case input dict.
  """
  realtime_input_config = {
      'automatic_activity_detection': {
          'disabled': True,
          'start_of_speech_sensitivity': 'START_SENSITIVITY_HIGH',
          'end_of_speech_sensitivity': 'END_SENSITIVITY_HIGH',
          'prefix_padding_ms': 20,
          'silence_duration_ms': 100,
      },
      'activity_handling': 'NO_INTERRUPTION',
      'turn_coverage': 'TURN_INCLUDES_ALL_INPUT',
  }
  config = types.LiveConnectConfig(realtime_input_config=realtime_input_config)

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai), model='test_model', config=config
  )
  assert result['setup']['realtimeInputConfig'] == realtime_input_config
1420
+
1421
+
1422
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_input_transcription(vertexai):
  """An empty input_audio_transcription config is sent as an empty object."""
  config = types.LiveConnectConfig(input_audio_transcription={})
  result = await get_connect_message(
      mock_api_client(vertexai=vertexai), model='test_model', config=config
  )
  setup = result['setup']
  assert setup['inputAudioTranscription'] == {}
1444
+
1445
+
1446
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_output_transcription(vertexai):
  """An empty output_audio_transcription config is sent as an empty object."""
  config = types.LiveConnectConfig(output_audio_transcription={})
  result = await get_connect_message(
      mock_api_client(vertexai=vertexai), model='test_model', config=config
  )
  setup = result['setup']
  assert setup['outputAudioTranscription'] == {}
1469
+
1470
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_media_resolution(vertexai):
  """A top-level media_resolution is placed inside `generationConfig`."""
  config = types.LiveConnectConfig(media_resolution='MEDIA_RESOLUTION_LOW')
  result = await get_connect_message(
      mock_api_client(vertexai=vertexai), model='test_model', config=config
  )
  generation_config = result['setup']['generationConfig']
  assert generation_config['mediaResolution'] == 'MEDIA_RESOLUTION_LOW'
1493
+
1494
+
1495
@pytest.mark.parametrize('vertexai', [True])
@pytest.mark.asyncio
async def test_bidi_setup_publishers(vertexai):
  """A `publishers/...` model name becomes a full Vertex AI resource path."""
  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='publishers/google/models/test_model',
  )
  expected_setup = {
      'model': 'projects/test_project/locations/us-central1/publishers/google/models/test_model',
      'generationConfig': {'responseModalities': ['AUDIO']},
  }
  assert result == {'setup': expected_setup}
1515
+
1516
+
1517
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_generation_config_warning(vertexai):
  """The deprecated `generation_config` warns but its values are still used."""
  deprecated_config = {'generation_config': {'temperature': 0.7}}
  expected_message = (
      'Setting `LiveConnectConfig.generation_config` is deprecated'
  )
  with pytest.warns(DeprecationWarning, match=expected_message):
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='models/test_model',
        config=deprecated_config,
    )

  assert result['setup']['generationConfig']['temperature'] == 0.7
1532
+
1533
+
1534
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_session_resumption(vertexai):
  """A session resumption handle is forwarded under `sessionResumption`."""
  config = types.LiveConnectConfig(
      session_resumption={'handle': 'test_handle'},
  )
  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='test_model',
      config=config,
  )

  vertex_extras = {
      'model': 'projects/test_project/locations/us-central1/publishers/google/models/test_model',
      'generationConfig': {'responseModalities': ['AUDIO']},
  }
  mldev_extras = {'model': 'models/test_model'}
  expected_setup = {
      'sessionResumption': {'handle': 'test_handle'},
      **(vertex_extras if vertexai else mldev_extras),
  }
  assert result == {'setup': expected_setup}
1564
+
1565
+
1566
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_transparent_session_resumption(vertexai):
  """Transparent session resumption is forwarded on Vertex AI.

  On the Gemini API (mldev) backend the call is expected to raise ValueError,
  as enforced by `pytest_helper.exception_if_mldev`.
  """
  api_client = mock_api_client(vertexai=vertexai)
  config = types.LiveConnectConfig(
      session_resumption={'handle': 'test_handle', 'transparent': True}
  )

  # Use the same client for the backend check and for the call under test so
  # the expectation and the exercised code path cannot drift apart.
  with pytest_helper.exception_if_mldev(api_client, ValueError):
    result = await get_connect_message(
        api_client, model='test_model', config=config
    )

  if not vertexai:
    # The mldev path raised inside the context manager above, so there is no
    # setup message to compare.
    return

  expected_result = {
      'setup': {
          'sessionResumption': {
              'handle': 'test_handle',
              'transparent': True,
          },
          'generationConfig': {
              'responseModalities': [
                  'AUDIO',
              ],
          },
          'model': (
              'projects/test_project/locations/us-central1/'
              'publishers/google/models/test_model'
          ),
      }
  }
  assert result == expected_result
1601
+
1602
+
1603
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_str(mock_websocket, vertexai):
  """A bare string is wrapped as a single, non-final user turn."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )

  parsed = session._parse_client_message('test')

  assert 'client_content' in parsed
  expected_turn = {'role': 'user', 'parts': [{'text': 'test'}]}
  assert parsed == {
      'client_content': {'turn_complete': False, 'turns': [expected_turn]}
  }
  # The parsed value is a plain TypedDict; it must be accepted by the
  # LiveClientMessage model constructor.
  assert types.LiveClientMessage(**parsed)
1619
+
1620
+
1621
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_thinking_config(vertexai):
  """`thinking_config` is nested under the setup `generationConfig`."""
  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='test_model',
      config={
          'thinking_config': {
              'include_thoughts': True,
              'thinking_budget': 1024,
          }
      },
  )

  generation_config = {
      'thinkingConfig': {'include_thoughts': True, 'thinking_budget': 1024}
  }
  if vertexai:
    generation_config['responseModalities'] = ['AUDIO']
    model = (
        'projects/test_project/locations/us-central1/'
        'publishers/google/models/test_model'
    )
  else:
    model = 'models/test_model'
  expected = {'setup': {'generationConfig': generation_config, 'model': model}}

  # Normalize both sides to snake_case keys so the comparison does not depend
  # on key casing.
  normalized_result = pytest_helper.camel_to_snake_all_keys(result)
  normalized_expected = pytest_helper.camel_to_snake_all_keys(expected)
  assert normalized_result == normalized_expected
1660
+
1661
+
1662
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_blob(mock_websocket, vertexai):
  """A Blob becomes a realtime media chunk with base64-encoded data."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  blob = types.Blob(data=bytes([0, 0, 0]), mime_type='text/plain')

  parsed = session._parse_client_message(blob)

  assert 'realtime_input' in parsed
  # Three zero bytes base64-encode to 'AAAA'.
  expected_chunk = {'mime_type': 'text/plain', 'data': 'AAAA'}
  assert parsed == {'realtime_input': {'media_chunks': [expected_chunk]}}
1676
+
1677
+
1678
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_blob_dict(mock_websocket, vertexai):
  """A Blob passed as a plain dict is treated like a Blob model."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # model_dump() (default python mode) keeps `data` as raw bytes.
  raw_blob = types.Blob(
      data=bytes([0, 0, 0]), mime_type='text/plain'
  ).model_dump()

  parsed = session._parse_client_message(raw_blob)

  assert 'realtime_input' in parsed
  expected_chunk = {'mime_type': 'text/plain', 'data': 'AAAA'}
  assert parsed == {'realtime_input': {'media_chunks': [expected_chunk]}}
1695
+
1696
+
1697
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_client_content(mock_websocket, vertexai):
  """A LiveClientContent model is serialized under `client_content`."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  user_turn = types.Content(role='user', parts=[types.Part(text='test')])
  content = types.LiveClientContent(turn_complete=False, turns=[user_turn])

  parsed = session._parse_client_message(content)

  assert 'client_content' in parsed
  assert parsed == {
      'client_content': {
          'turn_complete': False,
          'turns': [{'role': 'user', 'parts': [{'text': 'test'}]}],
      }
  }
1717
+
1718
+
1719
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_client_content_blob(mock_websocket, vertexai):
  """Inline bytes inside a LiveClientContent model come back base64 `str`."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  client_content = types.LiveClientContent(
      turn_complete=False,
      turns=[
          types.Content(
              parts=[
                  types.Part(
                      inline_data=types.Blob(
                          data=bytes([0, 0, 0]), mime_type='text/plain'
                      )
                  )
              ],
              role='user',
          )
      ],
  )

  result = session._parse_client_message(client_content)

  assert 'client_content' in result
  inline_data = result['client_content']['turns'][0]['parts'][0]['inline_data']
  # The raw bytes must be converted to a (base64) string, not left as bytes.
  assert isinstance(inline_data['data'], str)
  assert result == {
      'client_content': {
          'turn_complete': False,
          'turns': [{
              'role': 'user',
              'parts': [
                  {'inline_data': {'mime_type': 'text/plain', 'data': 'AAAA'}}
              ],
          }],
      }
  }
1762
+
1763
+
1764
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_client_content_dict(mock_websocket, vertexai):
  """A JSON-mode dict of LiveClientContent keeps inline data as base64 `str`."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  client_content = types.LiveClientContent(
      turn_complete=False,
      turns=[
          types.Content(
              parts=[
                  types.Part(
                      inline_data=types.Blob(
                          data=bytes([0, 0, 0]), mime_type='text/plain'
                      )
                  )
              ],
              role='user',
          )
      ],
  )

  result = session._parse_client_message(
      client_content.model_dump(mode='json', exclude_none=True)
  )

  assert 'client_content' in result
  inline_data = result['client_content']['turns'][0]['parts'][0]['inline_data']
  # The data must stay a (base64) string after going through the dict path.
  assert isinstance(inline_data['data'], str)
  assert result == {
      'client_content': {
          'turn_complete': False,
          'turns': [{
              'role': 'user',
              'parts': [
                  {'inline_data': {'mime_type': 'text/plain', 'data': 'AAAA'}}
              ],
          }],
      }
  }
1809
+
1810
+
1811
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_realtime_input(mock_websocket, vertexai):
  """A LiveClientRealtimeInput model is serialized under `realtime_input`."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # Named to avoid shadowing the builtin `input`.
  realtime_input = types.LiveClientRealtimeInput(
      media_chunks=[types.Blob(data=bytes([0, 0, 0]), mime_type='text/plain')]
  )

  result = session._parse_client_message(realtime_input)

  assert 'realtime_input' in result
  assert result == {
      'realtime_input': {
          'media_chunks': [{'mime_type': 'text/plain', 'data': 'AAAA'}],
      }
  }
1828
+
1829
+
1830
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_realtime_input_dict(mock_websocket, vertexai):
  """A JSON-mode dict of LiveClientRealtimeInput parses like the model."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # Named to avoid shadowing the builtin `input`.
  realtime_input = types.LiveClientRealtimeInput(
      media_chunks=[types.Blob(data=bytes([0, 0, 0]), mime_type='text/plain')]
  )

  result = session._parse_client_message(
      realtime_input.model_dump(mode='json', exclude_none=True)
  )

  assert 'realtime_input' in result
  assert result == {
      'realtime_input': {
          'media_chunks': [{'mime_type': 'text/plain', 'data': 'AAAA'}],
      }
  }
1849
+
1850
+
1851
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_tool_response(mock_websocket, vertexai):
  """A LiveClientToolResponse model is serialized under `tool_response`."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # Named to avoid shadowing the builtin `input`.
  tool_response = types.LiveClientToolResponse(
      function_responses=[
          types.FunctionResponse(
              id='test_id',
              name='test_name',
              response={'result': 'test_response'},
          )
      ]
  )

  result = session._parse_client_message(tool_response)

  assert 'tool_response' in result
  assert result == {
      'tool_response': {
          'function_responses': [
              {
                  'id': 'test_id',
                  'name': 'test_name',
                  'response': {
                      'result': 'test_response',
                  },
              },
          ],
      }
  }
1882
+
1883
+
1884
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_function_response(mock_websocket, vertexai):
  """A lone FunctionResponse is wrapped into a `tool_response` message.

  Keys inside the user-supplied `response` payload (mixed snake_case and
  camelCase here) must be passed through without case conversion.
  """
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # Named to avoid shadowing the builtin `input`.
  function_response = types.FunctionResponse(
      id='test_id',
      name='test_name',
      response={
          'result': 'test_response',
          'user_name': 'test_user_name',
          'userEmail': 'test_user_email',
      },
  )

  result = session._parse_client_message(function_response)

  assert 'tool_response' in result
  assert result == {
      'tool_response': {
          'function_responses': [
              {
                  'id': 'test_id',
                  'name': 'test_name',
                  'response': {
                      'result': 'test_response',
                      'user_name': 'test_user_name',
                      'userEmail': 'test_user_email',
                  },
              },
          ],
      }
  }
1917
+
1918
+
1919
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_tool_response_dict_with_only_response(
    mock_websocket, vertexai
):
  """A plain dict shaped like one FunctionResponse becomes a tool_response."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # Named to avoid shadowing the builtin `input`.
  function_response_dict = {
      'id': 'test_id',
      'name': 'test_name',
      'response': {
          'result': 'test_response',
      },
  }

  result = session._parse_client_message(function_response_dict)

  assert 'tool_response' in result
  assert result == {
      'tool_response': {
          'function_responses': [
              {
                  'id': 'test_id',
                  'name': 'test_name',
                  'response': {
                      'result': 'test_response',
                  },
              },
          ],
      }
  }
1948
+
1949
+
1950
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_realtime_tool_response(mock_websocket, vertexai):
  """A JSON-mode dict of LiveClientToolResponse parses like the model."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # Named to avoid shadowing the builtin `input`.
  tool_response = types.LiveClientToolResponse(
      function_responses=[
          types.FunctionResponse(
              id='test_id',
              name='test_name',
              response={'result': 'test_response'},
          )
      ]
  )

  result = session._parse_client_message(
      tool_response.model_dump(mode='json', exclude_none=True)
  )

  assert 'tool_response' in result
  assert result == {
      'tool_response': {
          'function_responses': [
              {
                  'id': 'test_id',
                  'name': 'test_name',
                  'response': {
                      'result': 'test_response',
                  },
              },
          ],
      }
  }
1984
+
1985
+
1986
@pytest.mark.asyncio
async def test_connect_with_provided_credentials(mock_websocket):
  """Explicitly supplied OAuth2 credentials become a Bearer auth header."""
  credentials = Credentials(token='provided_fake_token')
  client = mock_api_client(vertexai=True, credentials=credentials)
  seen = {}

  @contextlib.asynccontextmanager
  async def fake_ws_connect(uri, additional_headers=None, **kwargs):
    # Record the headers live.connect() hands to the websocket layer.
    seen['headers'] = additional_headers
    yield mock_websocket

  with patch.object(live, 'ws_connect', new=fake_ws_connect):
    async with live.AsyncLive(client).connect(model='test-model'):
      pass

  headers = seen['headers']
  assert 'Authorization' in headers
  assert headers['Authorization'] == 'Bearer provided_fake_token'
2009
+
2010
+
2011
@pytest.mark.asyncio
async def test_connect_with_default_credentials(mock_websocket):
  """Without explicit credentials, the token from google.auth.default() is used."""
  client = mock_api_client(vertexai=True, credentials=None)
  # Stand-in for google.auth.default(), which returns a
  # (credentials, project_id) tuple. Built once; the previous placeholder
  # return_value of (None, None) was immediately overwritten.
  mock_creds = Mock(token='default_test_token')
  mock_google_auth_default = Mock(return_value=(mock_creds, None))
  capture = {}

  @contextlib.asynccontextmanager
  async def mock_connect(uri, additional_headers=None, **kwargs):
    capture['headers'] = additional_headers
    yield mock_websocket

  @patch('google.auth.default', new=mock_google_auth_default)
  @patch.object(live, 'ws_connect', new=mock_connect)
  async def _test_connect():
    live_module = live.AsyncLive(client)
    async with live_module.connect(model='test-model'):
      pass

    assert 'Authorization' in capture['headers']
    assert capture['headers']['Authorization'] == 'Bearer default_test_token'

  await _test_connect()
2037
+
2038
+
2039
@pytest.mark.asyncio
async def test_connect_with_custom_base_url(mock_websocket):
  """A custom base_url and a caller-supplied auth header are used verbatim."""
  http_options = {
      'base_url': 'https://custom-base-url.com',
      'headers': {'Authorization': 'Bearer custom_test_token'},
  }
  # A real BaseApiClient; no ADC credentials are involved, so auth comes
  # solely from the explicit header above.
  client = gl_client.BaseApiClient(vertexai=True, http_options=http_options)
  seen = {}

  @contextlib.asynccontextmanager
  async def fake_ws_connect(uri, additional_headers=None, **kwargs):
    seen['uri'] = uri
    seen['headers'] = additional_headers
    yield mock_websocket

  with patch.object(live, 'ws_connect', new=fake_ws_connect):
    async with live.AsyncLive(client).connect(model='test-model'):
      pass

  assert 'Authorization' in seen['headers']
  assert seen['headers']['Authorization'] == 'Bearer custom_test_token'
  assert seen['uri'] == 'https://custom-base-url.com'
2069
+
2070
+
2071
@pytest.mark.parametrize('vertexai', [False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_auth_tokens(mock_websocket, vertexai):
  """An `auth_tokens/` key is sent as a Token header to the constrained method."""
  api_client_mock = mock_api_client(vertexai=vertexai)
  api_client_mock.api_key = 'auth_tokens/TEST_AUTH_TOKEN'
  # Smoke check: building the setup message must succeed for an auth-token
  # client. The message itself is not inspected, so the result is discarded.
  await get_connect_message(api_client_mock, model='test_model')

  mock_ws = AsyncMock()
  mock_ws.send = AsyncMock()
  # Canned `setupComplete` server reply returned by every recv() call.
  mock_ws.recv = AsyncMock(
      return_value=(
          b'{\n "setupComplete": {"sessionId": "test_session_id"}\n}\n'
      )
  )
  capture = {}

  @contextlib.asynccontextmanager
  async def mock_connect(uri, additional_headers=None, **kwargs):
    capture['uri'] = uri
    capture['headers'] = additional_headers
    yield mock_ws

  with patch.object(live, 'ws_connect', new=mock_connect):
    live_module = live.AsyncLive(api_client_mock)
    async with live_module.connect(
        model='test_model',
    ):
      pass

  assert (
      'Authorization' in capture['headers']
  ), 'Authorization key is missing from headers'
  assert (
      capture['headers']['Authorization'] == 'Token auth_tokens/TEST_AUTH_TOKEN'
  )
  assert 'BidiGenerateContentConstrained' in capture['uri']
2107
+
2108
+
2109
@pytest.mark.parametrize('vertexai', [False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_api_key(mock_websocket, vertexai):
  """A plain API key header is forwarded and the unconstrained method is used."""
  api_client_mock = mock_api_client(vertexai=vertexai)
  api_client_mock._http_options = types.HttpOptions.model_validate(
      {'headers': {'x-goog-api-key': 'TEST_API_KEY'}}
  )
  # Smoke check: building the setup message must succeed with these HTTP
  # options. The result is not inspected, so it is discarded.
  await get_connect_message(api_client_mock, model='test_model')

  mock_ws = AsyncMock()
  mock_ws.send = AsyncMock()
  # Canned `setupComplete` server reply returned by every recv() call.
  mock_ws.recv = AsyncMock(
      return_value=(
          b'{\n "setupComplete": {"sessionId": "test_session_id"}\n}\n'
      )
  )
  capture = {}

  @contextlib.asynccontextmanager
  async def mock_connect(uri, additional_headers=None, **kwargs):
    capture['uri'] = uri
    capture['headers'] = additional_headers
    yield mock_ws

  with patch.object(live, 'ws_connect', new=mock_connect):
    live_module = live.AsyncLive(api_client_mock)
    async with live_module.connect(
        model='test_model',
    ):
      pass

  assert 'x-goog-api-key' in capture['headers'], (
      'x-goog-api-key is missing from headers'
  )
  assert capture['headers']['x-goog-api-key'] == 'TEST_API_KEY'
  assert 'BidiGenerateContent' in capture['uri']
  # 'BidiGenerateContent' is a prefix of 'BidiGenerateContentConstrained', so
  # the check above alone cannot tell the endpoints apart. A plain API key
  # must not be routed to the auth-token (constrained) endpoint.
  assert 'BidiGenerateContentConstrained' not in capture['uri']
2143
+