google-genai 1.53.0__py3-none-any.whl → 1.55.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +1 -0
- google/genai/_api_client.py +6 -6
- google/genai/_interactions/__init__.py +117 -0
- google/genai/_interactions/_base_client.py +2019 -0
- google/genai/_interactions/_client.py +511 -0
- google/genai/_interactions/_compat.py +234 -0
- google/genai/_interactions/_constants.py +29 -0
- google/genai/_interactions/_exceptions.py +122 -0
- google/genai/_interactions/_files.py +139 -0
- google/genai/_interactions/_models.py +873 -0
- google/genai/_interactions/_qs.py +165 -0
- google/genai/_interactions/_resource.py +58 -0
- google/genai/_interactions/_response.py +847 -0
- google/genai/_interactions/_streaming.py +354 -0
- google/genai/_interactions/_types.py +276 -0
- google/genai/_interactions/_utils/__init__.py +79 -0
- google/genai/_interactions/_utils/_compat.py +61 -0
- google/genai/_interactions/_utils/_datetime_parse.py +151 -0
- google/genai/_interactions/_utils/_logs.py +40 -0
- google/genai/_interactions/_utils/_proxy.py +80 -0
- google/genai/_interactions/_utils/_reflection.py +57 -0
- google/genai/_interactions/_utils/_resources_proxy.py +39 -0
- google/genai/_interactions/_utils/_streams.py +27 -0
- google/genai/_interactions/_utils/_sync.py +73 -0
- google/genai/_interactions/_utils/_transform.py +472 -0
- google/genai/_interactions/_utils/_typing.py +172 -0
- google/genai/_interactions/_utils/_utils.py +437 -0
- google/genai/_interactions/_version.py +18 -0
- google/genai/_interactions/resources/__init__.py +34 -0
- google/genai/_interactions/resources/interactions.py +1350 -0
- google/genai/_interactions/types/__init__.py +107 -0
- google/genai/_interactions/types/allowed_tools.py +33 -0
- google/genai/_interactions/types/allowed_tools_param.py +35 -0
- google/genai/_interactions/types/annotation.py +42 -0
- google/genai/_interactions/types/annotation_param.py +42 -0
- google/genai/_interactions/types/audio_content.py +38 -0
- google/genai/_interactions/types/audio_content_param.py +45 -0
- google/genai/_interactions/types/audio_mime_type.py +25 -0
- google/genai/_interactions/types/audio_mime_type_param.py +27 -0
- google/genai/_interactions/types/code_execution_call_arguments.py +33 -0
- google/genai/_interactions/types/code_execution_call_arguments_param.py +32 -0
- google/genai/_interactions/types/code_execution_call_content.py +37 -0
- google/genai/_interactions/types/code_execution_call_content_param.py +37 -0
- google/genai/_interactions/types/code_execution_result_content.py +42 -0
- google/genai/_interactions/types/code_execution_result_content_param.py +41 -0
- google/genai/_interactions/types/content_delta.py +358 -0
- google/genai/_interactions/types/content_start.py +79 -0
- google/genai/_interactions/types/content_stop.py +35 -0
- google/genai/_interactions/types/deep_research_agent_config.py +33 -0
- google/genai/_interactions/types/deep_research_agent_config_param.py +32 -0
- google/genai/_interactions/types/document_content.py +36 -0
- google/genai/_interactions/types/document_content_param.py +43 -0
- google/genai/_interactions/types/dynamic_agent_config.py +44 -0
- google/genai/_interactions/types/dynamic_agent_config_param.py +33 -0
- google/genai/_interactions/types/error_event.py +46 -0
- google/genai/_interactions/types/file_search_result_content.py +46 -0
- google/genai/_interactions/types/file_search_result_content_param.py +46 -0
- google/genai/_interactions/types/function.py +38 -0
- google/genai/_interactions/types/function_call_content.py +39 -0
- google/genai/_interactions/types/function_call_content_param.py +39 -0
- google/genai/_interactions/types/function_param.py +37 -0
- google/genai/_interactions/types/function_result_content.py +52 -0
- google/genai/_interactions/types/function_result_content_param.py +54 -0
- google/genai/_interactions/types/generation_config.py +57 -0
- google/genai/_interactions/types/generation_config_param.py +59 -0
- google/genai/_interactions/types/google_search_call_arguments.py +29 -0
- google/genai/_interactions/types/google_search_call_arguments_param.py +31 -0
- google/genai/_interactions/types/google_search_call_content.py +37 -0
- google/genai/_interactions/types/google_search_call_content_param.py +37 -0
- google/genai/_interactions/types/google_search_result.py +35 -0
- google/genai/_interactions/types/google_search_result_content.py +43 -0
- google/genai/_interactions/types/google_search_result_content_param.py +44 -0
- google/genai/_interactions/types/google_search_result_param.py +35 -0
- google/genai/_interactions/types/image_content.py +41 -0
- google/genai/_interactions/types/image_content_param.py +48 -0
- google/genai/_interactions/types/image_mime_type.py +23 -0
- google/genai/_interactions/types/image_mime_type_param.py +25 -0
- google/genai/_interactions/types/interaction.py +165 -0
- google/genai/_interactions/types/interaction_create_params.py +212 -0
- google/genai/_interactions/types/interaction_event.py +37 -0
- google/genai/_interactions/types/interaction_get_params.py +46 -0
- google/genai/_interactions/types/interaction_sse_event.py +32 -0
- google/genai/_interactions/types/interaction_status_update.py +37 -0
- google/genai/_interactions/types/mcp_server_tool_call_content.py +42 -0
- google/genai/_interactions/types/mcp_server_tool_call_content_param.py +42 -0
- google/genai/_interactions/types/mcp_server_tool_result_content.py +52 -0
- google/genai/_interactions/types/mcp_server_tool_result_content_param.py +54 -0
- google/genai/_interactions/types/model.py +36 -0
- google/genai/_interactions/types/model_param.py +38 -0
- google/genai/_interactions/types/speech_config.py +35 -0
- google/genai/_interactions/types/speech_config_param.py +35 -0
- google/genai/_interactions/types/text_content.py +37 -0
- google/genai/_interactions/types/text_content_param.py +38 -0
- google/genai/_interactions/types/thinking_level.py +22 -0
- google/genai/_interactions/types/thought_content.py +41 -0
- google/genai/_interactions/types/thought_content_param.py +47 -0
- google/genai/_interactions/types/tool.py +100 -0
- google/genai/_interactions/types/tool_choice.py +26 -0
- google/genai/_interactions/types/tool_choice_config.py +28 -0
- google/genai/_interactions/types/tool_choice_config_param.py +29 -0
- google/genai/_interactions/types/tool_choice_param.py +28 -0
- google/genai/_interactions/types/tool_choice_type.py +22 -0
- google/genai/_interactions/types/tool_param.py +97 -0
- google/genai/_interactions/types/turn.py +76 -0
- google/genai/_interactions/types/turn_param.py +73 -0
- google/genai/_interactions/types/url_context_call_arguments.py +29 -0
- google/genai/_interactions/types/url_context_call_arguments_param.py +31 -0
- google/genai/_interactions/types/url_context_call_content.py +37 -0
- google/genai/_interactions/types/url_context_call_content_param.py +37 -0
- google/genai/_interactions/types/url_context_result.py +33 -0
- google/genai/_interactions/types/url_context_result_content.py +43 -0
- google/genai/_interactions/types/url_context_result_content_param.py +44 -0
- google/genai/_interactions/types/url_context_result_param.py +32 -0
- google/genai/_interactions/types/usage.py +106 -0
- google/genai/_interactions/types/usage_param.py +106 -0
- google/genai/_interactions/types/video_content.py +41 -0
- google/genai/_interactions/types/video_content_param.py +48 -0
- google/genai/_interactions/types/video_mime_type.py +36 -0
- google/genai/_interactions/types/video_mime_type_param.py +38 -0
- google/genai/_live_converters.py +34 -3
- google/genai/_tokens_converters.py +5 -0
- google/genai/batches.py +62 -55
- google/genai/client.py +223 -0
- google/genai/errors.py +16 -1
- google/genai/file_search_stores.py +60 -60
- google/genai/files.py +56 -56
- google/genai/interactions.py +17 -0
- google/genai/live.py +4 -3
- google/genai/models.py +15 -3
- google/genai/tests/__init__.py +21 -0
- google/genai/tests/afc/__init__.py +21 -0
- google/genai/tests/afc/test_convert_if_exist_pydantic_model.py +309 -0
- google/genai/tests/afc/test_convert_number_values_for_function_call_args.py +63 -0
- google/genai/tests/afc/test_find_afc_incompatible_tool_indexes.py +240 -0
- google/genai/tests/afc/test_generate_content_stream_afc.py +530 -0
- google/genai/tests/afc/test_generate_content_stream_afc_thoughts.py +77 -0
- google/genai/tests/afc/test_get_function_map.py +176 -0
- google/genai/tests/afc/test_get_function_response_parts.py +277 -0
- google/genai/tests/afc/test_get_max_remote_calls_for_afc.py +130 -0
- google/genai/tests/afc/test_invoke_function_from_dict_args.py +241 -0
- google/genai/tests/afc/test_raise_error_for_afc_incompatible_config.py +159 -0
- google/genai/tests/afc/test_should_append_afc_history.py +53 -0
- google/genai/tests/afc/test_should_disable_afc.py +214 -0
- google/genai/tests/batches/__init__.py +17 -0
- google/genai/tests/batches/test_cancel.py +77 -0
- google/genai/tests/batches/test_create.py +78 -0
- google/genai/tests/batches/test_create_with_bigquery.py +113 -0
- google/genai/tests/batches/test_create_with_file.py +82 -0
- google/genai/tests/batches/test_create_with_gcs.py +125 -0
- google/genai/tests/batches/test_create_with_inlined_requests.py +255 -0
- google/genai/tests/batches/test_delete.py +86 -0
- google/genai/tests/batches/test_embedding.py +157 -0
- google/genai/tests/batches/test_get.py +78 -0
- google/genai/tests/batches/test_list.py +79 -0
- google/genai/tests/caches/__init__.py +17 -0
- google/genai/tests/caches/constants.py +29 -0
- google/genai/tests/caches/test_create.py +210 -0
- google/genai/tests/caches/test_create_custom_url.py +105 -0
- google/genai/tests/caches/test_delete.py +54 -0
- google/genai/tests/caches/test_delete_custom_url.py +52 -0
- google/genai/tests/caches/test_get.py +94 -0
- google/genai/tests/caches/test_get_custom_url.py +52 -0
- google/genai/tests/caches/test_list.py +68 -0
- google/genai/tests/caches/test_update.py +70 -0
- google/genai/tests/caches/test_update_custom_url.py +58 -0
- google/genai/tests/chats/__init__.py +1 -0
- google/genai/tests/chats/test_get_history.py +597 -0
- google/genai/tests/chats/test_send_message.py +844 -0
- google/genai/tests/chats/test_validate_response.py +90 -0
- google/genai/tests/client/__init__.py +17 -0
- google/genai/tests/client/test_async_stream.py +427 -0
- google/genai/tests/client/test_client_close.py +197 -0
- google/genai/tests/client/test_client_initialization.py +1687 -0
- google/genai/tests/client/test_client_requests.py +355 -0
- google/genai/tests/client/test_custom_client.py +77 -0
- google/genai/tests/client/test_http_options.py +178 -0
- google/genai/tests/client/test_replay_client_equality.py +168 -0
- google/genai/tests/client/test_retries.py +846 -0
- google/genai/tests/client/test_upload_errors.py +136 -0
- google/genai/tests/common/__init__.py +17 -0
- google/genai/tests/common/test_common.py +954 -0
- google/genai/tests/conftest.py +162 -0
- google/genai/tests/documents/__init__.py +17 -0
- google/genai/tests/documents/test_delete.py +51 -0
- google/genai/tests/documents/test_get.py +85 -0
- google/genai/tests/documents/test_list.py +72 -0
- google/genai/tests/errors/__init__.py +1 -0
- google/genai/tests/errors/test_api_error.py +417 -0
- google/genai/tests/file_search_stores/__init__.py +17 -0
- google/genai/tests/file_search_stores/test_create.py +66 -0
- google/genai/tests/file_search_stores/test_delete.py +64 -0
- google/genai/tests/file_search_stores/test_get.py +94 -0
- google/genai/tests/file_search_stores/test_import_file.py +112 -0
- google/genai/tests/file_search_stores/test_list.py +57 -0
- google/genai/tests/file_search_stores/test_upload_to_file_search_store.py +141 -0
- google/genai/tests/files/__init__.py +17 -0
- google/genai/tests/files/test_delete.py +46 -0
- google/genai/tests/files/test_download.py +85 -0
- google/genai/tests/files/test_get.py +46 -0
- google/genai/tests/files/test_list.py +72 -0
- google/genai/tests/files/test_upload.py +255 -0
- google/genai/tests/imports/test_no_optional_imports.py +28 -0
- google/genai/tests/interactions/__init__.py +0 -0
- google/genai/tests/interactions/test_integration.py +80 -0
- google/genai/tests/live/__init__.py +16 -0
- google/genai/tests/live/test_live.py +2177 -0
- google/genai/tests/live/test_live_music.py +362 -0
- google/genai/tests/live/test_live_response.py +163 -0
- google/genai/tests/live/test_send_client_content.py +147 -0
- google/genai/tests/live/test_send_realtime_input.py +268 -0
- google/genai/tests/live/test_send_tool_response.py +222 -0
- google/genai/tests/local_tokenizer/__init__.py +17 -0
- google/genai/tests/local_tokenizer/test_local_tokenizer.py +343 -0
- google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py +235 -0
- google/genai/tests/mcp/__init__.py +17 -0
- google/genai/tests/mcp/test_has_mcp_tool_usage.py +89 -0
- google/genai/tests/mcp/test_mcp_to_gemini_tools.py +191 -0
- google/genai/tests/mcp/test_parse_config_for_mcp_sessions.py +201 -0
- google/genai/tests/mcp/test_parse_config_for_mcp_usage.py +130 -0
- google/genai/tests/mcp/test_set_mcp_usage_header.py +72 -0
- google/genai/tests/models/__init__.py +17 -0
- google/genai/tests/models/constants.py +8 -0
- google/genai/tests/models/test_compute_tokens.py +120 -0
- google/genai/tests/models/test_count_tokens.py +159 -0
- google/genai/tests/models/test_delete.py +107 -0
- google/genai/tests/models/test_edit_image.py +264 -0
- google/genai/tests/models/test_embed_content.py +94 -0
- google/genai/tests/models/test_function_call_streaming.py +442 -0
- google/genai/tests/models/test_generate_content.py +2502 -0
- google/genai/tests/models/test_generate_content_cached_content.py +132 -0
- google/genai/tests/models/test_generate_content_config_zero_value.py +103 -0
- google/genai/tests/models/test_generate_content_from_apikey.py +44 -0
- google/genai/tests/models/test_generate_content_http_options.py +40 -0
- google/genai/tests/models/test_generate_content_image_generation.py +143 -0
- google/genai/tests/models/test_generate_content_mcp.py +343 -0
- google/genai/tests/models/test_generate_content_media_resolution.py +97 -0
- google/genai/tests/models/test_generate_content_model.py +139 -0
- google/genai/tests/models/test_generate_content_part.py +821 -0
- google/genai/tests/models/test_generate_content_thought.py +76 -0
- google/genai/tests/models/test_generate_content_tools.py +1761 -0
- google/genai/tests/models/test_generate_images.py +191 -0
- google/genai/tests/models/test_generate_videos.py +759 -0
- google/genai/tests/models/test_get.py +104 -0
- google/genai/tests/models/test_list.py +233 -0
- google/genai/tests/models/test_recontext_image.py +189 -0
- google/genai/tests/models/test_segment_image.py +148 -0
- google/genai/tests/models/test_update.py +95 -0
- google/genai/tests/models/test_upscale_image.py +157 -0
- google/genai/tests/operations/__init__.py +17 -0
- google/genai/tests/operations/test_get.py +38 -0
- google/genai/tests/public_samples/__init__.py +17 -0
- google/genai/tests/public_samples/test_gemini_text_only.py +34 -0
- google/genai/tests/pytest_helper.py +229 -0
- google/genai/tests/shared/__init__.py +16 -0
- google/genai/tests/shared/batches/__init__.py +14 -0
- google/genai/tests/shared/batches/test_create_delete.py +57 -0
- google/genai/tests/shared/batches/test_create_get_cancel.py +56 -0
- google/genai/tests/shared/batches/test_list.py +40 -0
- google/genai/tests/shared/caches/__init__.py +14 -0
- google/genai/tests/shared/caches/test_create_get_delete.py +67 -0
- google/genai/tests/shared/caches/test_create_update_get.py +71 -0
- google/genai/tests/shared/caches/test_list.py +40 -0
- google/genai/tests/shared/chats/__init__.py +14 -0
- google/genai/tests/shared/chats/test_send_message.py +48 -0
- google/genai/tests/shared/chats/test_send_message_stream.py +50 -0
- google/genai/tests/shared/files/__init__.py +14 -0
- google/genai/tests/shared/files/test_list.py +41 -0
- google/genai/tests/shared/files/test_upload_get_delete.py +54 -0
- google/genai/tests/shared/models/__init__.py +14 -0
- google/genai/tests/shared/models/test_compute_tokens.py +41 -0
- google/genai/tests/shared/models/test_count_tokens.py +40 -0
- google/genai/tests/shared/models/test_edit_image.py +67 -0
- google/genai/tests/shared/models/test_embed.py +40 -0
- google/genai/tests/shared/models/test_generate_content.py +39 -0
- google/genai/tests/shared/models/test_generate_content_stream.py +54 -0
- google/genai/tests/shared/models/test_generate_images.py +40 -0
- google/genai/tests/shared/models/test_generate_videos.py +38 -0
- google/genai/tests/shared/models/test_list.py +37 -0
- google/genai/tests/shared/models/test_recontext_image.py +55 -0
- google/genai/tests/shared/models/test_segment_image.py +52 -0
- google/genai/tests/shared/models/test_upscale_image.py +52 -0
- google/genai/tests/shared/tunings/__init__.py +16 -0
- google/genai/tests/shared/tunings/test_create.py +46 -0
- google/genai/tests/shared/tunings/test_create_get_cancel.py +56 -0
- google/genai/tests/shared/tunings/test_list.py +39 -0
- google/genai/tests/tokens/__init__.py +16 -0
- google/genai/tests/tokens/test_create.py +154 -0
- google/genai/tests/transformers/__init__.py +17 -0
- google/genai/tests/transformers/test_blobs.py +71 -0
- google/genai/tests/transformers/test_bytes.py +15 -0
- google/genai/tests/transformers/test_duck_type.py +96 -0
- google/genai/tests/transformers/test_function_responses.py +72 -0
- google/genai/tests/transformers/test_schema.py +653 -0
- google/genai/tests/transformers/test_t_batch.py +286 -0
- google/genai/tests/transformers/test_t_content.py +160 -0
- google/genai/tests/transformers/test_t_contents.py +398 -0
- google/genai/tests/transformers/test_t_part.py +85 -0
- google/genai/tests/transformers/test_t_parts.py +87 -0
- google/genai/tests/transformers/test_t_tool.py +157 -0
- google/genai/tests/transformers/test_t_tools.py +195 -0
- google/genai/tests/tunings/__init__.py +16 -0
- google/genai/tests/tunings/test_cancel.py +39 -0
- google/genai/tests/tunings/test_end_to_end.py +106 -0
- google/genai/tests/tunings/test_get.py +67 -0
- google/genai/tests/tunings/test_list.py +75 -0
- google/genai/tests/tunings/test_tune.py +268 -0
- google/genai/tests/types/__init__.py +16 -0
- google/genai/tests/types/test_bytes_internal.py +271 -0
- google/genai/tests/types/test_bytes_type.py +152 -0
- google/genai/tests/types/test_future.py +101 -0
- google/genai/tests/types/test_optional_types.py +36 -0
- google/genai/tests/types/test_part_type.py +616 -0
- google/genai/tests/types/test_schema_from_json_schema.py +417 -0
- google/genai/tests/types/test_schema_json_schema.py +468 -0
- google/genai/tests/types/test_types.py +2903 -0
- google/genai/tunings.py +57 -57
- google/genai/types.py +229 -121
- google/genai/version.py +1 -1
- {google_genai-1.53.0.dist-info → google_genai-1.55.0.dist-info}/METADATA +4 -2
- google_genai-1.55.0.dist-info/RECORD +345 -0
- google_genai-1.53.0.dist-info/RECORD +0 -41
- {google_genai-1.53.0.dist-info → google_genai-1.55.0.dist-info}/WHEEL +0 -0
- {google_genai-1.53.0.dist-info → google_genai-1.55.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.53.0.dist-info → google_genai-1.55.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,2177 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
"""Tests for live.py."""
|
|
18
|
+
|
|
19
|
+
import contextlib
|
|
20
|
+
import json
|
|
21
|
+
import os
|
|
22
|
+
import ssl
|
|
23
|
+
import typing
|
|
24
|
+
from typing import Any, AsyncIterator
|
|
25
|
+
from unittest import mock
|
|
26
|
+
from unittest.mock import AsyncMock
|
|
27
|
+
from unittest.mock import Mock
|
|
28
|
+
from unittest.mock import patch
|
|
29
|
+
import warnings
|
|
30
|
+
|
|
31
|
+
import certifi
|
|
32
|
+
from google.oauth2.credentials import Credentials
|
|
33
|
+
import pytest
|
|
34
|
+
from websockets import client
|
|
35
|
+
|
|
36
|
+
from .. import pytest_helper
|
|
37
|
+
from ... import _api_client as api_client
|
|
38
|
+
from ... import _common
|
|
39
|
+
from ... import Client
|
|
40
|
+
from ... import client as gl_client
|
|
41
|
+
from ... import live
|
|
42
|
+
from ... import types
|
|
43
|
+
try:
|
|
44
|
+
import aiohttp
|
|
45
|
+
AIOHTTP_NOT_INSTALLED = False
|
|
46
|
+
except ImportError:
|
|
47
|
+
AIOHTTP_NOT_INSTALLED = True
|
|
48
|
+
aiohttp = mock.MagicMock()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
if typing.TYPE_CHECKING:
|
|
52
|
+
from mcp import types as mcp_types
|
|
53
|
+
from mcp import ClientSession as McpClientSession
|
|
54
|
+
else:
|
|
55
|
+
mcp_types: typing.Type = Any
|
|
56
|
+
McpClientSession: typing.Type = Any
|
|
57
|
+
try:
|
|
58
|
+
from mcp import types as mcp_types
|
|
59
|
+
from mcp import ClientSession as McpClientSession
|
|
60
|
+
except ImportError:
|
|
61
|
+
mcp_types = None
|
|
62
|
+
McpClientSession = None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# Skip marker for tests that need the optional `aiohttp` package.
# AIOHTTP_NOT_INSTALLED is set by the guarded `import aiohttp` above.
requires_aiohttp = pytest.mark.skipif(
    AIOHTTP_NOT_INSTALLED, reason="aiohttp is not installed, skipping test."
)
|
|
68
|
+
|
|
69
|
+
# A single function declaration, in plain-dict form, describing the
# `get_current_weather` helper defined below. Its parameters are an OBJECT
# with a `location` STRING and a `unit` STRING limited to 'C' or 'F'.
function_declarations = [{
    'name': 'get_current_weather',
    'description': 'Get the current weather in a city',
    'parameters': {
        'type': 'OBJECT',
        'properties': {
            'location': {
                'type': 'STRING',
                'description': 'The location to get the weather for',
            },
            'unit': {
                'type': 'STRING',
                'enum': ['C', 'F'],
            },
        },
    },
}]
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def get_current_weather(location: str, unit: str):
  """Get the current weather in a city.

  Test stub: `location` is ignored. Returns 15 for unit 'C' and 59 for any
  other unit.
  """
  if unit == 'C':
    return 15
  return 59
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def mock_api_client(vertexai=False, credentials=None, http_options=None):
  """Build a MagicMock that stands in for BaseApiClient in live tests.

  Args:
    vertexai: When True the fake looks like a Vertex AI client (no API key);
      otherwise it looks like a Gemini Developer API client with a test key.
    credentials: Stored verbatim as the fake's `_credentials`.
    http_options: Only consulted when `vertexai` is True. An HttpOptions or a
      dict of HttpOptions fields. Its `base_url` becomes `custom_base_url`
      and project/location are left as None.

  Returns:
    A MagicMock spec'd on BaseApiClient with the attributes set below.
  """
  fake_client = mock.MagicMock(spec=gl_client.BaseApiClient)

  # Start from the Developer API shape, then override for Vertex AI.
  api_key, project, location, base_url = 'TEST_API_KEY', None, None, None
  if vertexai:
    api_key = None
    if http_options:
      if isinstance(http_options, dict):
        http_options = types.HttpOptions(**http_options)
      base_url = http_options.base_url
    else:
      project, location = 'test_project', 'us-central1'

  fake_client.api_key = api_key
  fake_client.location = location
  fake_client.project = project
  fake_client.custom_base_url = base_url

  fake_client._host = lambda: 'test_host'
  fake_client._credentials = credentials
  # Fresh options with an empty headers dict so header lookups succeed.
  fake_client._http_options = types.HttpOptions.model_validate({'headers': {}})
  fake_client.vertexai = vertexai
  # Self-reference; presumably for code that reaches through `._api_client`
  # -- TODO confirm against live.py.
  fake_client._api_client = fake_client

  ssl_context = ssl.create_default_context(
      cafile=os.environ.get("SSL_CERT_FILE", certifi.where()),
      capath=os.environ.get("SSL_CERT_DIR"),
  )
  fake_client._websocket_ssl_ctx = {'ssl': ssl_context}
  return fake_client
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@pytest.fixture
def mock_websocket():
  """Provide an AsyncMock websocket connection for AsyncSession tests.

  `send` and `close` are plain AsyncMocks. By default `recv` returns a
  serverContent message with turnComplete set.
  """
  ws = AsyncMock(spec=client.ClientConnection)
  ws.send = AsyncMock()
  ws.close = AsyncMock()
  ws.recv = AsyncMock(return_value='{"serverContent": {"turnComplete": true}}')
  return ws
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
async def get_connect_message(api_client, model, config=None):
  """Connect through AsyncLive against a fake websocket; return the sent setup.

  Patches `google.auth.default` (returning a credential with token
  'test_token') and `live.ws_connect` (yielding a fake websocket whose recv()
  returns a setupComplete message). It then enters and exits
  `AsyncLive.connect`.

  Args:
    api_client: The (usually mocked) API client passed to AsyncLive.
    model: Model name forwarded to `connect`.
    config: Live connect config; defaults to an empty dict.

  Returns:
    The single message sent on the websocket, decoded from JSON.
  """
  if config is None:
    config = {}

  fake_ws = AsyncMock()
  fake_ws.send = AsyncMock()
  fake_ws.recv = AsyncMock(
      return_value=(
          b'{\n "setupComplete": {"sessionId": "test_session_id"}\n}\n'
      )
  )

  fake_creds = Mock(token='test_token')
  fake_auth_default = Mock(return_value=(fake_creds, None))

  @contextlib.asynccontextmanager
  async def fake_connect(uri, additional_headers=None, **kwargs):
    yield fake_ws

  with patch('google.auth.default', new=fake_auth_default), patch.object(
      live, 'ws_connect', new=fake_connect
  ):
    live_module = live.AsyncLive(api_client)
    async with live_module.connect(model=model, config=config):
      pass

  fake_ws.send.assert_called_once()
  return json.loads(fake_ws.send.call_args[0][0])
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
async def _async_iterator_to_list(async_iter):
|
|
178
|
+
return [value async for value in async_iter]
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def test_mldev_from_env(monkeypatch):
  """Client() reads GOOGLE_API_KEY and targets the Gemini Developer API."""
  api_key = 'google_api_key'
  monkeypatch.setenv('GOOGLE_API_KEY', api_key)
  # If the surrounding environment sets GOOGLE_GENAI_USE_VERTEXAI=true, the
  # client would switch to Vertex AI mode and the assertions below would fail.
  # Remove it so the test only depends on what it sets itself.
  monkeypatch.delenv('GOOGLE_GENAI_USE_VERTEXAI', raising=False)

  # Named `genai_client` to avoid shadowing the imported `websockets.client`.
  genai_client = Client()
  live_api_client = genai_client.aio.live._api_client

  assert not live_api_client.vertexai
  assert live_api_client.api_key == api_key
  assert isinstance(live_api_client, api_client.BaseApiClient)
  assert live_api_client._http_options.headers['x-goog-api-key'] == api_key
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
@requires_aiohttp
def test_vertex_from_env(monkeypatch):
  """Client() reads Vertex AI mode, project and location from the environment."""
  project_id = 'fake_project_id'
  location = 'fake-location'
  monkeypatch.setenv('GOOGLE_GENAI_USE_VERTEXAI', 'true')
  monkeypatch.setenv('GOOGLE_CLOUD_PROJECT', project_id)
  monkeypatch.setenv('GOOGLE_CLOUD_LOCATION', location)

  # Named `genai_client` to avoid shadowing the imported `websockets.client`.
  genai_client = Client()
  live_api_client = genai_client.aio.live._api_client

  assert live_api_client.custom_base_url is None
  assert live_api_client.vertexai
  assert live_api_client.project == project_id
  # GOOGLE_CLOUD_LOCATION is set above, so it must be honored as well.
  assert live_api_client.location == location
  assert isinstance(live_api_client, api_client.BaseApiClient)
  assert 'x-goog-api-key' not in live_api_client._http_options.headers
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def test_vertex_api_key_from_env(monkeypatch):
  """In Vertex AI mode, GOOGLE_API_KEY is used when project/location are blank."""
  api_key = 'google_api_key'
  for env_name, env_value in (
      ('GOOGLE_GENAI_USE_VERTEXAI', 'true'),
      ('GOOGLE_API_KEY', api_key),
      # Project/location take precedence over the API key, so blank them out.
      # client/test_client_initialization.py covers the full precedence rules.
      ("GOOGLE_CLOUD_LOCATION", ""),
      ("GOOGLE_CLOUD_PROJECT", ""),
  ):
    monkeypatch.setenv(env_name, env_value)

  genai_client = Client()
  live_api_client = genai_client.aio.live._api_client

  assert live_api_client.vertexai
  assert live_api_client.api_key == api_key
  assert isinstance(live_api_client, api_client.BaseApiClient)
  assert live_api_client._http_options.headers['x-goog-api-key'] == api_key
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def test_websocket_base_url():
  """An https base_url is rewritten to the wss scheme for websocket use."""
  # Named `base_client` to avoid shadowing the module-level `api_client` alias.
  base_client = gl_client.BaseApiClient(
      api_key='google_api_key',
      http_options={'base_url': 'https://test.com'},
  )
  expected_url = 'wss://test.com'
  assert base_client._websocket_base_url() == expected_url
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def test_websocket_base_url_no_auth_with_custom_base_url():
  """With no API key and no project/location, base_url is passed through as-is."""
  base_url = 'https://test-api-gateway-proxy.com'
  http_options = {
      'base_url': base_url,
      'headers': {'Authorization': 'Bearer test_token'},
  }
  base_client = gl_client.BaseApiClient(vertexai=True, http_options=http_options)
  # The test environment may define project/location, so clear them
  # explicitly to exercise the custom-base-URL-only path.
  base_client.project = None
  base_client.location = None

  assert base_client._websocket_base_url() == base_url
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_text(mock_websocket, vertexai):
  """Sending a bare string produces a `client_content` message."""
  fake_client = mock_api_client(vertexai=vertexai)
  session = live.AsyncSession(api_client=fake_client, websocket=mock_websocket)

  await session.send(input='test')

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'client_content' in payload
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_content_dict(mock_websocket, vertexai):
  """A dict with `content` and `turn_complete` keys is sent as `client_content`."""
  fake_client = mock_api_client(vertexai=vertexai)
  session = live.AsyncSession(api_client=fake_client, websocket=mock_websocket)

  await session.send(
      input={
          'content': [{'parts': [{'text': 'test'}]}],
          'turn_complete': True,
      }
  )

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'client_content' in payload
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_content(mock_websocket, vertexai):
  """A LiveClientContent object is sent as a `client_content` message."""
  fake_client = mock_api_client(vertexai=vertexai)
  session = live.AsyncSession(api_client=fake_client, websocket=mock_websocket)
  text_turn = types.Content(parts=[types.Part(text='test')])

  await session.send(
      input=types.LiveClientContent(turns=[text_turn], turn_complete=True)
  )

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'client_content' in payload
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_bytes(mock_websocket, vertexai):
  """A dict with raw `data` bytes and a `mime_type` is sent as `realtime_input`."""
  fake_client = mock_api_client(vertexai=vertexai)
  session = live.AsyncSession(api_client=fake_client, websocket=mock_websocket)

  await session.send(input={'data': b'000000', 'mime_type': 'audio/pcm'})

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'realtime_input' in payload
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_blob(mock_websocket, vertexai):
  """A `types.Blob` input is sent as a `realtime_input` message."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  blob = types.Blob(data=b'000000', mime_type='audio/pcm')

  await session.send(input=blob)

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'realtime_input' in payload
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_realtime_input(mock_websocket, vertexai):
  """A `LiveClientRealtimeInput` is sent as a `realtime_input` message."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # 'MDAwMDAw' is the base64 encoding of b'000000'.
  chunk = types.Blob(data='MDAwMDAw', mime_type='audio/pcm')
  realtime_input = types.LiveClientRealtimeInput(media_chunks=[chunk])

  await session.send(input=realtime_input)

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'realtime_input' in payload
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_tool_response(mock_websocket, vertexai):
  """A `LiveClientToolResponse` is sent as a `tool_response` message."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )

  response_kwargs = {
      'name': 'get_current_weather',
      'response': {'temperature': 14.5, 'unit': 'C'},
  }
  if not vertexai:
    # Only the non-Vertex variant of this test attaches a call id; for Vertex
    # the `id` argument is omitted entirely (not passed as None).
    response_kwargs['id'] = 'some-id'
  tool_response = types.LiveClientToolResponse(
      function_responses=[types.FunctionResponse(**response_kwargs)]
  )

  await session.send(input=tool_response)

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'tool_response' in payload
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_input_none(mock_websocket, vertexai):
  """`send(input=None)` emits `client_content` with a truthy `turn_complete`."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )

  await session.send(input=None)

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'client_content' in payload
  assert payload['client_content']['turn_complete']
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_error(mock_websocket, vertexai):
  """Unrecognized dict inputs (bare or in a list) make `send` raise ValueError."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  invalid_inputs = (
      [{'invalid_key': 'invalid_value'}],
      {'invalid_key': 'invalid_value'},
  )
  for bad_input in invalid_inputs:
    with pytest.raises(ValueError):
      await session.send(input=bad_input)
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive(mock_websocket, vertexai):
  """`receive()` yields `LiveServerMessage` objects built from websocket frames."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  received = await _async_iterator_to_list(session.receive())
  assert isinstance(received[0], types.LiveServerMessage)
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_error(mock_websocket, vertexai):
  """A frame that is not valid JSON makes `receive()` raise ValueError."""
  mock_websocket.recv = AsyncMock(return_value='invalid json')
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )

  with pytest.raises(ValueError):
    # Pull just the first item; the malformed frame should fail immediately.
    await session.receive().__anext__()
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_text(mock_websocket, vertexai):
  """A text model turn followed by turnComplete is parsed into two messages."""
  mock_websocket.recv = AsyncMock(
      side_effect=[
          '{"serverContent": {"modelTurn": {"parts":[{"text": "test"}]}}}',
          '{"serverContent": {"turnComplete": true}}',
      ]
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  messages = await _async_iterator_to_list(session.receive())

  assert isinstance(messages[0], types.LiveServerMessage)
  assert messages[0].server_content.model_turn.parts[0].text == 'test'
  # Identity check (not `== True`, flake8 E712): the parsed field must be the
  # boolean True itself, not merely something equal to it.
  assert messages[1].server_content.turn_complete is True
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_audio(mock_websocket, vertexai):
  """Inline audio in a model turn is decoded from base64 into raw bytes."""
  audio_frame = (
      '{"serverContent": {"modelTurn": {"parts":[{"inlineData":'
      ' {"data": "MDAwMDAw", "mimeType": "audio/pcm" }}]}}}'
  )
  mock_websocket.recv = AsyncMock(
      side_effect=[audio_frame, '{"serverContent": {"turnComplete": true}}']
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  messages = await _async_iterator_to_list(session.receive())

  assert isinstance(messages[0], types.LiveServerMessage)
  inline_data = messages[0].server_content.model_turn.parts[0].inline_data
  assert inline_data.mime_type == 'audio/pcm'
  assert inline_data.data == b'000000'

  # Both scripted frames are now consumed. Another recv() on the AsyncMock
  # raises StopAsyncIteration. If receive() is an async generator (presumably
  # it is), that error is re-raised as RuntimeError (PEP 525), which is what
  # this assertion relies on.
  with pytest.raises(RuntimeError):
    await _async_iterator_to_list(session.receive())
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_tool_call(mock_websocket, vertexai):
  """A `toolCall` frame is parsed into function-call name and arguments."""
  tool_call_frame = (
      '{"toolCall": {"functionCalls": [{"name":'
      ' "get_current_weather", "args": {"location": "San Francisco",'
      ' "unit": "C"}}]}}'
  )
  mock_websocket.recv = AsyncMock(
      side_effect=[tool_call_frame, '{"serverContent": {"turnComplete": true}}']
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  messages = await _async_iterator_to_list(session.receive())

  assert isinstance(messages[0], types.LiveServerMessage)
  function_call = messages[0].tool_call.function_calls[0]
  assert function_call.name == 'get_current_weather'
  assert function_call.args['location'] == 'San Francisco'
  assert function_call.args['unit'] == 'C'

  # The scripted frames are exhausted, so a second receive() pass must fail
  # (AsyncMock's StopAsyncIteration surfaces as RuntimeError).
  with pytest.raises(RuntimeError):
    await _async_iterator_to_list(session.receive())
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_transcription(mock_websocket, vertexai):
  """Input/output transcription frames are parsed with their `finished` flags."""
  mock_websocket.recv = AsyncMock(
      side_effect=[
          '{"serverContent": {"inputTranscription": {"text": "test_input", "finished": true}}}',
          '{"serverContent": {"outputTranscription": {"text": "test_output", "finished": false}}}',
          '{"serverContent": {"turnComplete": true}}',
      ]
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  messages = await _async_iterator_to_list(session.receive())

  assert isinstance(messages[0], types.LiveServerMessage)
  input_transcription = messages[0].server_content.input_transcription
  assert input_transcription.text == 'test_input'
  # Identity checks (flake8 E712) pin the parsed values to real booleans.
  assert input_transcription.finished is True

  assert isinstance(messages[1], types.LiveServerMessage)
  output_transcription = messages[1].server_content.output_transcription
  assert output_transcription.text == 'test_output'
  assert output_transcription.finished is False

  # The third scripted frame should also come through.
  assert messages[2].server_content.turn_complete is True
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_go_away(mock_websocket, vertexai):
  """A `goAway` frame is parsed into `LiveServerMessage.go_away`."""
  mock_websocket.recv = AsyncMock(
      side_effect=[
          '{"goAway": {"timeLeft": "10s"}}',
          '{"serverContent": {"turnComplete": true}}',
      ]
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  first_message = (await _async_iterator_to_list(session.receive()))[0]

  assert isinstance(first_message, types.LiveServerMessage)
  assert first_message == types.LiveServerMessage(
      go_away=types.LiveServerGoAway(time_left='10s'),
  )
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_resumption_update(mock_websocket, vertexai):
  """A `sessionResumptionUpdate` frame is parsed and its string fields coerced."""
  # The server sends `resumable` and the message index as JSON strings; the
  # expected model below holds them as bool/int.
  resumption_frame = """{
    "sessionResumptionUpdate": {
      "newHandle": "test_handle",
      "resumable": "true",
      "lastConsumedClientMessageIndex": "123456789"
    }
  }"""
  mock_websocket.recv = AsyncMock(
      side_effect=[
          resumption_frame,
          '{"serverContent": {"turnComplete": true}}',
      ]
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  first_message = (await _async_iterator_to_list(session.receive()))[0]

  expected_update = types.LiveServerSessionResumptionUpdate(
      new_handle='test_handle',
      resumable=True,
      last_consumed_client_message_index=123456789,
  )
  assert isinstance(first_message, types.LiveServerMessage)
  assert first_message == types.LiveServerMessage(
      session_resumption_update=expected_update,
  )
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_start_stream(mock_websocket, vertexai):
  """Every message yielded by `start_stream` is a `LiveServerMessage`."""
  session = live.AsyncSession(
      mock_api_client(vertexai=vertexai), mock_websocket
  )

  async def audio_chunks():
    for chunk in (b'data1', b'data2'):
      yield chunk

  # NOTE(review): if start_stream yields nothing, this loop body never runs
  # and the test passes vacuously; consider asserting a message count.
  responses = session.start_stream(stream=audio_chunks(), mime_type='audio/pcm')
  async for message in responses:
    assert isinstance(message, types.LiveServerMessage)
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_vad_signal(mock_websocket, vertexai):
  """A `voiceActivityDetectionSignal` frame is parsed into a typed VAD signal."""
  mock_websocket.recv = mock.AsyncMock(
      side_effect=[
          '{"voiceActivityDetectionSignal": {"vadSignalType": "VAD_SIGNAL_TYPE_SOS"}}',
          # Trailing turnComplete lets receive() finish.
          '{"serverContent": {"turnComplete": true}}',
      ]
  )
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )

  messages = await _async_iterator_to_list(session.receive())

  assert messages
  vad_message = messages[0]
  assert isinstance(vad_message, types.LiveServerMessage)
  signal = vad_message.voice_activity_detection_signal
  assert signal is not None
  assert signal.vad_signal_type == types.VadSignalType.VAD_SIGNAL_TYPE_SOS

  # The stream should end on the turnComplete frame.
  assert messages[-1].server_content.turn_complete is True
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_close(mock_websocket, vertexai):
  """`close()` closes the underlying websocket exactly once."""
  api_client = mock_api_client(vertexai=vertexai)
  session = live.AsyncSession(api_client, mock_websocket)

  await session.close()

  mock_websocket.close.assert_called_once()
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_no_config(vertexai):
  """Without a config, setup holds only the model (plus AUDIO on Vertex)."""
  with warnings.catch_warnings():
    # Turn warnings into errors so any warning caused by default values
    # fails the test.
    warnings.simplefilter('error')
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model'
    )

  if vertexai:
    expected_setup = {
        'model': 'projects/test_project/locations/us-central1/publishers/google/models/test_model',
        'generationConfig': {'responseModalities': ['AUDIO']},
    }
  else:
    expected_setup = {'model': 'models/test_model'}
  assert result == {'setup': expected_setup}
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_speech_config(vertexai):
  """Speech and sampling options map into `setup` for dict and model configs."""
  expected_result = {
      'setup': {
          'model': 'models/test_model',
          'generationConfig': {
              'speechConfig': {
                  'voice_config': {
                      'prebuilt_voice_config': {'voice_name': 'en-default'}
                  },
                  'language_code': 'en-US',
              },
              'enableAffectiveDialog': True,
              'temperature': 0.7,
              'topP': 0.8,
              'topK': 9.0,
              'maxOutputTokens': 10,
              'mediaResolution': 'MEDIA_RESOLUTION_MEDIUM',
              'seed': 13,
          },
          'proactivity': {'proactive_audio': True},
          'systemInstruction': {
              'parts': [
                  {
                      'text': 'test instruction',
                  },
              ],
              'role': 'user',
          },
      }
  }
  expected_setup = expected_result['setup']
  if vertexai:
    expected_setup['model'] = (
        'projects/test_project/locations/us-central1/'
        'publishers/google/models/test_model'
    )
    expected_generation_config = expected_setup['generationConfig']
    expected_generation_config['responseModalities'] = ['AUDIO']
    # The Vertex expectation spells speechConfig's own keys in camelCase.
    expected_generation_config['speechConfig'] = {
        'voiceConfig': {
            'prebuilt_voice_config': {'voice_name': 'en-default'}
        },
        'languageCode': 'en-US',
    }
  else:
    expected_setup['model'] = 'models/test_model'

  config_dict = {
      'speech_config': {
          'voice_config': {
              'prebuilt_voice_config': {'voice_name': 'en-default'}
          },
          'language_code': 'en-US',
      },
      'enable_affective_dialog': True,
      'proactivity': {'proactive_audio': True},
      'temperature': 0.7,
      'top_p': 0.8,
      'top_k': 9,
      'max_output_tokens': 10,
      'seed': 13,
      'system_instruction': 'test instruction',
      'media_resolution': 'MEDIA_RESOLUTION_MEDIUM',
  }
  config_model = types.LiveConnectConfig(
      speech_config=types.SpeechConfig(
          voice_config=types.VoiceConfig(
              prebuilt_voice_config=types.PrebuiltVoiceConfig(
                  voice_name='en-default'
              )
          ),
          language_code='en-US',
      ),
      enable_affective_dialog=True,
      proactivity=types.ProactivityConfig(proactive_audio=True),
      temperature=0.7,
      top_p=0.8,
      top_k=9,
      max_output_tokens=10,
      media_resolution=types.MediaResolution.MEDIA_RESOLUTION_MEDIUM,
      seed=13,
      system_instruction='test instruction',
  )

  # The dict form and the LiveConnectConfig form must yield the same message.
  for config in (config_dict, config_model):
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai), model='test_model', config=config
    )
    assert types.LiveClientMessage._from_response(
        response=result, kwargs=None
    ) == types.LiveClientMessage._from_response(
        response=expected_result, kwargs=None
    )
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_error_if_multispeaker_voice_config(vertexai):
  """Live setup rejects `multi_speaker_voice_config` with a ValueError."""
  # Config is a dict
  config_dict = {
      'speech_config': {
          'multi_speaker_voice_config': {
              'speaker_voice_configs': [
                  {
                      'speaker': 'Alice',
                      'voice_config': {
                          'prebuilt_voice_config': {'voice_name': 'leda'}
                      },
                  },
                  {
                      'speaker': 'Bob',
                      'voice_config': {
                          'prebuilt_voice_config': {'voice_name': 'kore'}
                      },
                  },
              ],
          },
      },
      'temperature': 0.7,
      'top_p': 0.8,
      'top_k': 9,
      'max_output_tokens': 10,
      'seed': 13,
      'system_instruction': 'test instruction',
      'media_resolution': 'MEDIA_RESOLUTION_MEDIUM',
  }
  # `match` is applied with re.search, so no surrounding `.*` is needed. The
  # call is expected to raise, so its return value is not bound.
  with pytest.raises(ValueError, match='multi_speaker_voice_config'):
    await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=config_dict,
    )
|
|
840
|
+
|
|
841
|
+
|
|
842
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_replicated_voice_config(vertexai):
  """Replicated-voice sample bytes appear base64-encoded in the Vertex setup."""
  config_dict = {
      'speech_config': {
          'voice_config': {
              'replicated_voice_config': {
                  'mime_type': 'audio/pcm',
                  'voice_sample_audio': bytes([0, 0, 0]),
              },
          },
      },
  }
  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='test_model',
      config=config_dict,
  )
  if not vertexai:
    # For the Gemini API only the connect call is exercised; nothing is checked.
    return

  voice_config = result['setup']['generationConfig']['speechConfig'][
      'voiceConfig'
  ]
  # Accept either spelling of the nested key.
  if 'replicatedVoiceConfig' in voice_config:
    replicated_voice_config = voice_config['replicatedVoiceConfig']
  else:
    replicated_voice_config = voice_config['replicated_voice_config']
  # 'AAAA' is the base64 encoding of bytes([0, 0, 0]).
  assert replicated_voice_config == {
      'mime_type': 'audio/pcm',
      'voice_sample_audio': 'AAAA',
  }
|
|
876
|
+
|
|
877
|
+
|
|
878
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_explicit_vad(vertexai):
  """`explicit_vad_signal` reaches setup on Vertex.

  For the Gemini API the call is expected to raise ValueError, as enforced by
  `pytest_helper.exception_if_mldev`.
  """
  # Use one client for both the expectation helper and the call under test, so
  # they cannot disagree about which backend is in play.
  api_client = mock_api_client(vertexai=vertexai)
  config_dict = {'explicit_vad_signal': True}
  with pytest_helper.exception_if_mldev(api_client, ValueError):
    result = await get_connect_message(
        api_client, model='test_model', config=config_dict
    )
  if not vertexai:
    # The helper above already verified the ValueError; `result` is unset.
    return
  assert result['setup']['explicitVadSignal'] is True
|
|
892
|
+
|
|
893
|
+
|
|
894
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_explicit_vad_config(vertexai):
  """A dict config with `explicit_vad_signal` maps to `setup.explicitVadSignal`.

  On the Gemini API the call is expected to raise ValueError, as checked by
  `pytest_helper.exception_if_mldev`.
  """
  api_client = mock_api_client(vertexai=vertexai)

  # Config is a dict
  config_dict = {'explicit_vad_signal': True}
  with pytest_helper.exception_if_mldev(api_client, ValueError):
    # Reuse the same client the helper inspects instead of building a new one.
    result = await get_connect_message(
        api_client,
        model='test_model',
        config=config_dict,
    )
  if not vertexai:
    return
  assert result['setup']['explicitVadSignal'] is True
|
|
910
|
+
|
|
911
|
+
|
|
912
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_system_instruction_as_content_type(
    vertexai,
):
  """A Content-shaped `system_instruction` is emitted as `systemInstruction`."""
  config = types.LiveConnectConfig(
      system_instruction={
          'parts': [{'text': 'test instruction'}],
          'role': 'user',
      },
  )

  expected_setup = {
      'systemInstruction': {
          'parts': [{'text': 'test instruction'}],
          'role': 'user',
      },
  }
  if vertexai:
    expected_setup['model'] = (
        'projects/test_project/locations/us-central1/publishers/google/models/test_model'
    )
    expected_setup['generationConfig'] = {'responseModalities': ['AUDIO']}
  else:
    expected_setup['model'] = 'models/test_model'

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai), model='test_model', config=config
  )
  assert result == {'setup': expected_setup}
|
|
949
|
+
|
|
950
|
+
|
|
951
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_google_search(vertexai):
  """A google_search tool and generation settings map into setup."""
  config_dict = {
      'response_modalities': ['TEXT'],
      'system_instruction': 'test instruction',
      'generation_config': {'temperature': 0.7},
      'tools': [{'google_search': {}}],
  }
  config = types.LiveConnectConfig(**config_dict)

  if vertexai:
    model_name = 'projects/test_project/locations/us-central1/publishers/google/models/test_model'
  else:
    model_name = 'models/test_model'
  expected_result = {
      'setup': {
          'model': model_name,
          'generationConfig': {
              'temperature': 0.7,
              'responseModalities': ['TEXT'],
          },
          'systemInstruction': {
              'parts': [{'text': 'test instruction'}],
              'role': 'user',
          },
          'tools': [{'googleSearch': {}}],
      }
  }

  # Check the plain-dict config first, then the LiveConnectConfig.
  for connect_config in (config_dict, config):
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model', config=connect_config
    )
    assert result == expected_result
|
|
994
|
+
|
|
995
|
+
|
|
996
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_with_no_mcp(vertexai):
  """Setup is built correctly when live's MCP symbols are unavailable (None)."""
  config_dict = {
      'response_modalities': ['TEXT'],
      'system_instruction': 'test instruction',
      'generation_config': {'temperature': 0.7},
      'tools': [{'google_search': {}}],
  }

  config = types.LiveConnectConfig(**config_dict)
  expected_result = {
      'setup': {
          'generationConfig': {
              'temperature': 0.7,
              'responseModalities': ['TEXT'],
          },
          'systemInstruction': {
              'parts': [{'text': 'test instruction'}],
              'role': 'user',
          },
          'tools': [{'googleSearch': {}}],
      }
  }
  if vertexai:
    expected_result['setup']['model'] = 'projects/test_project/locations/us-central1/publishers/google/models/test_model'
  else:
    expected_result['setup']['model'] = 'models/test_model'

  # Patch live.McpClientSession / live.McpTool to None for the duration of the
  # call, simulating an environment without MCP support.
  @patch.object(live, "McpClientSession", new=None)
  @patch.object(live, "McpTool", new=None)
  async def get_connect_message_no_mcp(connect_config):
    return await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model', config=connect_config
    )

  # Cover both accepted config forms: the plain dict and the LiveConnectConfig.
  for connect_config in (config_dict, config):
    result = await get_connect_message_no_mcp(connect_config)
    assert result == expected_result
|
|
1038
|
+
|
|
1039
|
+
|
|
1040
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_context_window_compression(vertexai):
  """Context-window compression settings are emitted as `contextWindowCompression`."""
  compression = types.ContextWindowCompressionConfig(
      trigger_tokens=1000,
      sliding_window=types.SlidingWindow(target_tokens=10),
  )
  instruction = types.Content(
      parts=[types.Part(text='test instruction')], role='user'
  )
  config = types.LiveConnectConfig(
      generation_config=types.GenerationConfig(temperature=0.7),
      response_modalities=['TEXT'],
      system_instruction=instruction,
      context_window_compression=compression,
  )

  expected_setup = {
      'model': (
          'projects/test_project/locations/us-central1/publishers/google/models/test_model'
          if vertexai
          else 'models/test_model'
      ),
      'generationConfig': {
          'temperature': 0.7,
          'responseModalities': ['TEXT'],
      },
      'systemInstruction': {
          'parts': [{'text': 'test instruction'}],
          'role': 'user',
      },
      'contextWindowCompression': {
          'trigger_tokens': 1000,
          'sliding_window': {'target_tokens': 10},
      },
  }

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai), model='test_model', config=config
  )
  assert result == {'setup': expected_setup}
|
|
1082
|
+
|
|
1083
|
+
|
|
1084
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_function_declaration(
    vertexai
):
  """Function declarations in `tools` reach setup for dict and model configs."""
  config_dict = {
      'generation_config': {'temperature': 0.7},
      'tools': [{'function_declarations': function_declarations}],
  }
  config = types.LiveConnectConfig(**config_dict)
  # NOTE(review): only `description` is compared below; the rest of this dict
  # documents the expected shape but is not asserted.
  expected_result = {
      'setup': {
          'model': 'test_model',
          'tools': [{
              'functionDeclarations': [{
                  'parameters': {
                      'type': 'OBJECT',
                      'properties': {
                          'location': {
                              'type': 'STRING',
                              'description': (
                                  'The location to get the weather for'
                              ),
                          },
                          'unit': {'type': 'STRING', 'enum': ['C', 'F']},
                      },
                  },
                  'name': 'get_current_weather',
                  'description': 'Get the current weather in a city',
              }],
          }],
      }
  }
  expected_description = expected_result['setup']['tools'][0][
      'functionDeclarations'
  ][0]['description']

  # Exercise both accepted config forms. Previously `config` was passed twice,
  # so the dict path was never tested.
  for connect_config in (config_dict, config):
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model', config=connect_config
    )
    assert (
        result['setup']['tools'][0]['functionDeclarations'][0]['description']
        == expected_description
    )
|
|
1142
|
+
|
|
1143
|
+
|
|
1144
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_function_directly(
    vertexai,
):
    """A Python callable passed directly in `tools` becomes a declaration."""
    live_config = types.LiveConnectConfig(
        generation_config={'temperature': 0.7},
        tools=[get_current_weather],
    )
    # Only the description is asserted; the other fields document the shape
    # the converter is expected to derive from the callable.
    expected_declaration = {
        'parameters': {
            'type': 'OBJECT',
            'properties': {
                'location': {
                    'type': 'STRING',
                    'description': 'The location to get the weather for',
                },
                'unit': {'type': 'STRING', 'enum': ['C', 'F']},
            },
        },
        'name': 'get_current_weather',
        'description': 'Get the current weather in a city.',
    }

    # The same config is converted twice. Presumably this guards against the
    # conversion mutating `live_config` in place -- TODO confirm intent.
    for _ in range(2):
        message = await get_connect_message(
            mock_api_client(vertexai=vertexai),
            model='test_model',
            config=live_config,
        )
        declaration = message['setup']['tools'][0]['functionDeclarations'][0]
        assert declaration['description'] == expected_declaration['description']
|
|
1202
|
+
|
|
1203
|
+
|
|
1204
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_tools_function_behavior(vertexai):
    """`FunctionDeclaration.behavior` is forwarded on the Gemini API path."""
    api_client = mock_api_client(vertexai=vertexai)

    declaration = types.FunctionDeclaration.from_callable(
        client=api_client, callable=get_current_weather
    )
    declaration.behavior = types.Behavior.NON_BLOCKING
    live_config = types.LiveConnectConfig(
        generation_config={'temperature': 0.7},
        tools=[{'function_declarations': [declaration]}],
    )

    # As the helper's name suggests, a ValueError is presumably expected on
    # Vertex AI; in that case there is nothing further to inspect.
    with pytest_helper.exception_if_vertex(api_client, ValueError):
        message = await get_connect_message(
            mock_api_client(vertexai=vertexai),
            model='test_model',
            config=live_config,
        )
    if vertexai:
        return

    first_declaration = message['setup']['tools'][0]['functionDeclarations'][0]
    assert first_declaration['behavior'] == 'NON_BLOCKING'
|
|
1230
|
+
|
|
1231
|
+
|
|
1232
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_mcp_tools(
    vertexai,
):
    """MCP `Tool` objects in `tools` are converted to function declarations.

    Skipped (rather than silently passing) when the optional `mcp` package
    is not installed.
    """
    if mcp_types is None:
        pytest.skip('mcp is not installed')

    # The converted tool payload is identical for both backends; only the
    # model path and the Vertex-only default generation config differ.
    expected_tools = [{
        'functionDeclarations': [{
            'parameters': {
                'type': 'OBJECT',
                'properties': {
                    'location': {
                        'type': 'STRING',
                    },
                },
            },
            'name': 'get_weather',
            'description': 'Get the weather in a city.',
        }],
    }]
    if vertexai:
        expected_result = {
            'setup': {
                'generationConfig': {
                    'responseModalities': [
                        'AUDIO',
                    ],
                },
                'model': (
                    'projects/test_project/locations/us-central1/publishers/google/models/test_model'
                ),
                'tools': expected_tools,
            }
        }
    else:
        expected_result = {
            'setup': {
                'model': 'models/test_model',
                'tools': expected_tools,
            }
        }

    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config={
            'tools': [
                mcp_types.Tool(
                    name='get_weather',
                    description='Get the weather in a city.',
                    inputSchema={
                        'type': 'object',
                        'properties': {'location': {'type': 'string'}},
                    },
                )
            ],
        },
    )

    assert result == expected_result
|
|
1307
|
+
|
|
1308
|
+
|
|
1309
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_mcp_session(
    vertexai,
):
    """An MCP client session in `tools` has its listed tools converted.

    Skipped (rather than silently passing) when the optional `mcp` package
    is not installed.
    """
    if mcp_types is None:
        pytest.skip('mcp is not installed')

    class MockMcpClientSession(McpClientSession):
        """Stand-in session whose `list_tools` returns one fixed tool."""

        def __init__(self):
            # Deliberately bypasses McpClientSession.__init__ and sets only
            # these attributes (presumably all the converter reads -- confirm).
            self._read_stream = None
            self._write_stream = None

        async def list_tools(self):
            return mcp_types.ListToolsResult(
                tools=[
                    mcp_types.Tool(
                        name='get_weather',
                        description='Get the weather in a city.',
                        inputSchema={
                            'type': 'object',
                            'properties': {'location': {'type': 'string'}},
                        },
                    ),
                ]
            )

    # The converted tool payload is identical for both backends; only the
    # model path and the Vertex-only default generation config differ.
    expected_tools = [{
        'functionDeclarations': [{
            'parameters': {
                'type': 'OBJECT',
                'properties': {
                    'location': {
                        'type': 'STRING',
                    },
                },
            },
            'name': 'get_weather',
            'description': 'Get the weather in a city.',
        }],
    }]
    if vertexai:
        expected_result = {
            'setup': {
                'generationConfig': {
                    'responseModalities': [
                        'AUDIO',
                    ],
                },
                'model': (
                    'projects/test_project/locations/us-central1/publishers/google/models/test_model'
                ),
                'tools': expected_tools,
            }
        }
    else:
        expected_result = {
            'setup': {
                'model': 'models/test_model',
                'tools': expected_tools,
            }
        }

    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config={
            'tools': [MockMcpClientSession()],
        },
    )

    assert result == expected_result
|
|
1395
|
+
|
|
1396
|
+
|
|
1397
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_code_execution(vertexai):
    """A `code_execution` tool is emitted as a camelCase `codeExecution` tool."""
    live_config = types.LiveConnectConfig(tools=[{'code_execution': {}}])

    message = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=live_config,
    )

    first_tool = message['setup']['tools'][0]
    assert first_tool == {'codeExecution': {}}
|
|
1420
|
+
|
|
1421
|
+
|
|
1422
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_realtime_input_config(vertexai):
    """`realtime_input_config` is echoed back under `realtimeInputConfig`."""
    realtime_input_config = {
        'automatic_activity_detection': {
            'disabled': True,
            'start_of_speech_sensitivity': 'START_SENSITIVITY_HIGH',
            'end_of_speech_sensitivity': 'END_SENSITIVITY_HIGH',
            'prefix_padding_ms': 20,
            'silence_duration_ms': 100,
        },
        'activity_handling': 'NO_INTERRUPTION',
        'turn_coverage': 'TURN_INCLUDES_ALL_INPUT',
    }
    live_config = types.LiveConnectConfig(
        realtime_input_config=realtime_input_config
    )

    message = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=live_config,
    )

    # The expected value keeps the snake_case inner keys of the input dict.
    assert message['setup']['realtimeInputConfig'] == realtime_input_config
|
|
1456
|
+
|
|
1457
|
+
|
|
1458
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_input_transcription(vertexai):
    """An empty `input_audio_transcription` config is forwarded as `{}`."""
    live_config = types.LiveConnectConfig(input_audio_transcription={})

    message = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=live_config,
    )

    assert message['setup']['inputAudioTranscription'] == {}
|
|
1480
|
+
|
|
1481
|
+
|
|
1482
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_output_transcription(vertexai):
    """An empty `output_audio_transcription` config is forwarded as `{}`."""
    live_config = types.LiveConnectConfig(output_audio_transcription={})

    message = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=live_config,
    )

    assert message['setup']['outputAudioTranscription'] == {}
|
|
1505
|
+
|
|
1506
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_media_resolution(vertexai):
    """`media_resolution` is placed inside the setup's `generationConfig`."""
    live_config = types.LiveConnectConfig(media_resolution='MEDIA_RESOLUTION_LOW')

    message = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=live_config,
    )

    generation_config = message['setup']['generationConfig']
    assert generation_config['mediaResolution'] == 'MEDIA_RESOLUTION_LOW'
|
|
1529
|
+
|
|
1530
|
+
|
|
1531
|
+
@pytest.mark.parametrize('vertexai', [True])
@pytest.mark.asyncio
async def test_bidi_setup_publishers(vertexai):
    """A `publishers/...` model name is expanded to the full Vertex AI path."""
    message = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='publishers/google/models/test_model',
    )

    assert message == {
        'setup': {
            'generationConfig': {'responseModalities': ['AUDIO']},
            'model': (
                'projects/test_project/locations/us-central1/'
                'publishers/google/models/test_model'
            ),
        }
    }
|
|
1551
|
+
|
|
1552
|
+
|
|
1553
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_generation_config_warning(vertexai):
    """The deprecated `generation_config` warns but is still applied."""
    deprecated_config = {'generation_config': {'temperature': 0.7}}
    expected_message = (
        'Setting `LiveConnectConfig.generation_config` is deprecated'
    )

    with pytest.warns(DeprecationWarning, match=expected_message):
        message = await get_connect_message(
            mock_api_client(vertexai=vertexai),
            model='models/test_model',
            config=deprecated_config,
        )

    assert message['setup']['generationConfig']['temperature'] == 0.7
|
|
1568
|
+
|
|
1569
|
+
|
|
1570
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_session_resumption(vertexai):
    """A resumption handle is forwarded for both backends."""
    live_config = types.LiveConnectConfig(
        session_resumption={'handle': 'test_handle'}
    )

    message = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=live_config,
    )

    # Vertex AI additionally gets a default AUDIO response modality and a
    # fully-qualified model path.
    expected_setup = {'sessionResumption': {'handle': 'test_handle'}}
    if vertexai:
        expected_setup['generationConfig'] = {'responseModalities': ['AUDIO']}
        expected_setup['model'] = (
            'projects/test_project/locations/us-central1/publishers/google/models/test_model'
        )
    else:
        expected_setup['model'] = 'models/test_model'

    assert message == {'setup': expected_setup}
|
|
1600
|
+
|
|
1601
|
+
|
|
1602
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_transparent_session_resumption(vertexai):
    """`transparent` resumption is forwarded on Vertex AI only."""
    api_client = mock_api_client(vertexai=vertexai)
    live_config = types.LiveConnectConfig(
        session_resumption={'handle': 'test_handle', 'transparent': True}
    )

    # As the helper's name suggests, a ValueError is presumably expected on
    # the Gemini API (mldev) path; nothing further to inspect in that case.
    with pytest_helper.exception_if_mldev(api_client, ValueError):
        message = await get_connect_message(
            mock_api_client(vertexai=vertexai),
            model='test_model',
            config=live_config,
        )
    if not vertexai:
        return

    assert message == {
        'setup': {
            'sessionResumption': {
                'handle': 'test_handle',
                'transparent': True,
            },
            'generationConfig': {'responseModalities': ['AUDIO']},
            'model': (
                'projects/test_project/locations/us-central1/publishers/google/models/test_model'
            ),
        }
    }
|
|
1637
|
+
|
|
1638
|
+
|
|
1639
|
+
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_str(mock_websocket, vertexai):
    """A plain string is wrapped as one non-final user turn."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )

    parsed = session._parse_client_message('test')

    assert 'client_content' in parsed
    expected_turn = {'role': 'user', 'parts': [{'text': 'test'}]}
    assert parsed == {
        'client_content': {
            'turn_complete': False,
            'turns': [expected_turn],
        }
    }
    # The parsed value is a TypedDict-shaped dict, so it must also be
    # accepted by the LiveClientMessage model constructor.
    assert types.LiveClientMessage(**parsed)
|
|
1655
|
+
|
|
1656
|
+
|
|
1657
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_thinking_config(vertexai):
    """`thinking_config` is nested under `generationConfig.thinkingConfig`."""
    thinking_config = {
        'include_thoughts': True,
        'thinking_budget': 1024,
    }
    # Copy so the expectation cannot be affected by anything done to the
    # dict that is passed in as input.
    expected_generation_config = {'thinkingConfig': dict(thinking_config)}

    if vertexai:
        expected_generation_config['responseModalities'] = ['AUDIO']
        expected_model = (
            'projects/test_project/locations/us-central1/publishers/google/models/test_model'
        )
    else:
        expected_model = 'models/test_model'

    message = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config={'thinking_config': thinking_config},
    )

    assert message == {
        'setup': {
            'generationConfig': expected_generation_config,
            'model': expected_model,
        }
    }
|
|
1694
|
+
|
|
1695
|
+
|
|
1696
|
+
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_blob(mock_websocket, vertexai):
    """A `Blob` becomes a realtime media chunk with base64-encoded data."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    blob = types.Blob(data=bytes([0, 0, 0]), mime_type='text/plain')

    parsed = session._parse_client_message(blob)

    assert 'realtime_input' in parsed
    # bytes([0, 0, 0]) base64-encodes to 'AAAA'.
    expected_chunk = {'mime_type': 'text/plain', 'data': 'AAAA'}
    assert parsed == {'realtime_input': {'media_chunks': [expected_chunk]}}
|
|
1710
|
+
|
|
1711
|
+
|
|
1712
|
+
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_blob_dict(mock_websocket, vertexai):
    """A dumped `Blob` dict is parsed the same way as a `Blob` instance."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    blob_as_dict = types.Blob(
        data=bytes([0, 0, 0]), mime_type='text/plain'
    ).model_dump()

    parsed = session._parse_client_message(blob_as_dict)

    assert 'realtime_input' in parsed
    expected_chunk = {'mime_type': 'text/plain', 'data': 'AAAA'}
    assert parsed == {'realtime_input': {'media_chunks': [expected_chunk]}}
|
|
1729
|
+
|
|
1730
|
+
|
|
1731
|
+
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_client_content(mock_websocket, vertexai):
    """A `LiveClientContent` object is wrapped under `client_content`."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    user_turn = types.Content(parts=[types.Part(text='test')], role='user')
    client_content = types.LiveClientContent(
        turn_complete=False, turns=[user_turn]
    )

    parsed = session._parse_client_message(client_content)

    assert 'client_content' in parsed
    assert parsed == {
        'client_content': {
            'turn_complete': False,
            'turns': [{'role': 'user', 'parts': [{'text': 'test'}]}],
        }
    }
|
|
1751
|
+
|
|
1752
|
+
|
|
1753
|
+
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_client_content_blob(mock_websocket, vertexai):
    """Inline `Blob` bytes inside client content are serialized to base64 text."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    client_content = types.LiveClientContent(
        turn_complete=False,
        turns=[
            types.Content(
                parts=[
                    types.Part(
                        inline_data=types.Blob(
                            data=bytes([0, 0, 0]), mime_type='text/plain'
                        )
                    )
                ],
                role='user',
            )
        ],
    )

    result = session._parse_client_message(client_content)

    assert 'client_content' in result
    inline_data = result['client_content']['turns'][0]['parts'][0]['inline_data']
    # The raw bytes must already have been encoded to text (presumably because
    # the message is later sent as JSON).
    assert isinstance(inline_data['data'], str)
    assert result == {
        'client_content': {
            'turn_complete': False,
            'turns': [{
                'role': 'user',
                'parts': [
                    {'inline_data': {'mime_type': 'text/plain', 'data': 'AAAA'}}
                ],
            }],
        }
    }
|
|
1796
|
+
|
|
1797
|
+
|
|
1798
|
+
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_client_content_dict(mock_websocket, vertexai):
    """A JSON-dumped client-content dict keeps base64 text for inline data."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    client_content = types.LiveClientContent(
        turn_complete=False,
        turns=[
            types.Content(
                parts=[
                    types.Part(
                        inline_data=types.Blob(
                            data=bytes([0, 0, 0]), mime_type='text/plain'
                        )
                    )
                ],
                role='user',
            )
        ],
    )

    result = session._parse_client_message(
        client_content.model_dump(mode='json', exclude_none=True)
    )

    assert 'client_content' in result
    inline_data = result['client_content']['turns'][0]['parts'][0]['inline_data']
    assert isinstance(inline_data['data'], str)
    assert result == {
        'client_content': {
            'turn_complete': False,
            'turns': [{
                'role': 'user',
                'parts': [
                    {'inline_data': {'mime_type': 'text/plain', 'data': 'AAAA'}}
                ],
            }],
        }
    }
|
|
1843
|
+
|
|
1844
|
+
|
|
1845
|
+
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_realtime_input(mock_websocket, vertexai):
    """A `LiveClientRealtimeInput` is wrapped under `realtime_input`."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    # Named `realtime_input` rather than `input` to avoid shadowing the builtin.
    realtime_input = types.LiveClientRealtimeInput(
        media_chunks=[types.Blob(data=bytes([0, 0, 0]), mime_type='text/plain')]
    )

    result = session._parse_client_message(realtime_input)

    assert 'realtime_input' in result
    assert result == {
        'realtime_input': {
            'media_chunks': [{'mime_type': 'text/plain', 'data': 'AAAA'}],
        }
    }
|
|
1862
|
+
|
|
1863
|
+
|
|
1864
|
+
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_realtime_input_dict(mock_websocket, vertexai):
    """A JSON-dumped realtime-input dict is parsed like the model object."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    # Named `realtime_input` rather than `input` to avoid shadowing the builtin.
    realtime_input = types.LiveClientRealtimeInput(
        media_chunks=[types.Blob(data=bytes([0, 0, 0]), mime_type='text/plain')]
    )

    result = session._parse_client_message(
        realtime_input.model_dump(mode='json', exclude_none=True)
    )

    assert 'realtime_input' in result
    assert result == {
        'realtime_input': {
            'media_chunks': [{'mime_type': 'text/plain', 'data': 'AAAA'}],
        }
    }
|
|
1883
|
+
|
|
1884
|
+
|
|
1885
|
+
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_tool_response(mock_websocket, vertexai):
    """A `LiveClientToolResponse` is wrapped under `tool_response`."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    # Named `tool_response` rather than `input` to avoid shadowing the builtin.
    tool_response = types.LiveClientToolResponse(
        function_responses=[
            types.FunctionResponse(
                id='test_id',
                name='test_name',
                response={'result': 'test_response'},
            )
        ]
    )

    result = session._parse_client_message(tool_response)

    assert 'tool_response' in result
    assert result == {
        'tool_response': {
            'function_responses': [
                {
                    'id': 'test_id',
                    'name': 'test_name',
                    'response': {
                        'result': 'test_response',
                    },
                },
            ],
        }
    }
|
|
1916
|
+
|
|
1917
|
+
|
|
1918
|
+
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_function_response(mock_websocket, vertexai):
    """A bare `FunctionResponse` is wrapped into a one-item tool response.

    Keys inside `response` (both snake_case and camelCase) pass through
    unchanged.
    """
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    # Named `function_response` rather than `input` to avoid shadowing the
    # builtin.
    function_response = types.FunctionResponse(
        id='test_id',
        name='test_name',
        response={
            'result': 'test_response',
            'user_name': 'test_user_name',
            'userEmail': 'test_user_email',
        },
    )

    result = session._parse_client_message(function_response)

    assert 'tool_response' in result
    assert result == {
        'tool_response': {
            'function_responses': [
                {
                    'id': 'test_id',
                    'name': 'test_name',
                    'response': {
                        'result': 'test_response',
                        'user_name': 'test_user_name',
                        'userEmail': 'test_user_email',
                    },
                },
            ],
        }
    }
|
|
1951
|
+
|
|
1952
|
+
|
|
1953
|
+
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_tool_response_dict_with_only_response(
    mock_websocket, vertexai
):
    """A plain function-response dict is wrapped into a tool response."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    # Named `function_response` rather than `input` to avoid shadowing the
    # builtin.
    function_response = {
        'id': 'test_id',
        'name': 'test_name',
        'response': {
            'result': 'test_response',
        },
    }

    result = session._parse_client_message(function_response)

    assert 'tool_response' in result
    assert result == {
        'tool_response': {
            'function_responses': [
                {
                    'id': 'test_id',
                    'name': 'test_name',
                    'response': {
                        'result': 'test_response',
                    },
                },
            ],
        }
    }
|
|
1982
|
+
|
|
1983
|
+
|
|
1984
|
+
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_realtime_tool_response(mock_websocket, vertexai):
    """A JSON-dumped `LiveClientToolResponse` dict parses like the object."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    # Named `tool_response` rather than `input` to avoid shadowing the builtin.
    tool_response = types.LiveClientToolResponse(
        function_responses=[
            types.FunctionResponse(
                id='test_id',
                name='test_name',
                response={'result': 'test_response'},
            )
        ]
    )

    result = session._parse_client_message(
        tool_response.model_dump(mode='json', exclude_none=True)
    )

    assert 'tool_response' in result
    assert result == {
        'tool_response': {
            'function_responses': [
                {
                    'id': 'test_id',
                    'name': 'test_name',
                    'response': {
                        'result': 'test_response',
                    },
                },
            ],
        }
    }
|
|
2018
|
+
|
|
2019
|
+
|
|
2020
|
+
@pytest.mark.asyncio
async def test_connect_with_provided_credentials(mock_websocket):
    """Explicit OAuth2 credentials yield a matching Bearer header."""
    client = mock_api_client(
        vertexai=True, credentials=Credentials(token='provided_fake_token')
    )
    seen = {}

    @contextlib.asynccontextmanager
    async def mock_connect(uri, additional_headers=None, **kwargs):
        # Record the headers the live module hands to the websocket layer.
        seen['headers'] = additional_headers
        yield mock_websocket

    with patch.object(live, 'ws_connect', new=mock_connect):
        async with live.AsyncLive(client).connect(model='test-model'):
            pass

        headers = seen['headers']
        assert 'Authorization' in headers
        assert headers['Authorization'] == 'Bearer provided_fake_token'
|
|
2043
|
+
|
|
2044
|
+
|
|
2045
|
+
@pytest.mark.asyncio
async def test_connect_with_default_credentials(mock_websocket):
    """Without explicit credentials, the token from `google.auth.default` is used."""
    client = mock_api_client(vertexai=True, credentials=None)
    # google.auth.default() returns a (credentials, project_id) tuple; build
    # the mock with that return value directly instead of overwriting a
    # placeholder one.
    mock_creds = Mock(token='default_test_token')
    mock_google_auth_default = Mock(return_value=(mock_creds, None))
    capture = {}

    @contextlib.asynccontextmanager
    async def mock_connect(uri, additional_headers=None, **kwargs):
        # Record the headers the live module hands to the websocket layer.
        capture['headers'] = additional_headers
        yield mock_websocket

    @patch('google.auth.default', new=mock_google_auth_default)
    @patch.object(live, 'ws_connect', new=mock_connect)
    async def _test_connect():
        live_module = live.AsyncLive(client)
        async with live_module.connect(model='test-model'):
            pass

        assert 'Authorization' in capture['headers']
        assert capture['headers']['Authorization'] == 'Bearer default_test_token'

    await _test_connect()
|
|
2071
|
+
|
|
2072
|
+
|
|
2073
|
+
@pytest.mark.asyncio
async def test_connect_with_custom_base_url(mock_websocket):
  """A custom base URL is used verbatim, with caller-supplied auth headers.

  The client is a real ``BaseApiClient`` configured with a custom
  ``base_url`` and an ``Authorization`` header; no explicit credentials are
  passed to it. The test checks that the websocket is opened against exactly
  that URL and that the supplied header is forwarded unchanged.
  """
  recorded = {}

  @contextlib.asynccontextmanager
  async def fake_ws_connect(uri, additional_headers=None, **kwargs):
    # Record the target URI and headers chosen by AsyncLive.connect.
    recorded['uri'] = uri
    recorded['headers'] = additional_headers
    yield mock_websocket

  custom_options = {
      'base_url': 'https://custom-base-url.com',
      'headers': {'Authorization': 'Bearer custom_test_token'},
  }
  client = gl_client.BaseApiClient(vertexai=True, http_options=custom_options)

  with patch.object(live, 'ws_connect', new=fake_ws_connect):
    async with live.AsyncLive(client).connect(model='test-model'):
      pass

  assert 'Authorization' in recorded['headers']
  assert recorded['headers']['Authorization'] == 'Bearer custom_test_token'
  assert recorded['uri'] == 'https://custom-base-url.com'
|
|
2103
|
+
|
|
2104
|
+
|
|
2105
|
+
@pytest.mark.parametrize('vertexai', [False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_auth_tokens(mock_websocket, vertexai):
  """An ``auth_tokens/`` API key uses token auth and the constrained method.

  The client's api_key is set to an ``auth_tokens/...`` value. The test
  asserts that the websocket connect call receives
  ``Authorization: Token <api_key>`` and that the URI targets
  ``BidiGenerateContentConstrained``.
  """
  api_client_mock = mock_api_client(vertexai=vertexai)
  api_client_mock.api_key = 'auth_tokens/TEST_AUTH_TOKEN'
  # Exercise setup-message construction for this client. Only the connect
  # headers/URI are asserted below, so the returned message is not kept.
  await get_connect_message(api_client_mock, model='test_model')

  # Dedicated websocket double (the mock_websocket fixture is not used here).
  mock_ws = AsyncMock()
  mock_ws.send = AsyncMock()
  # Canned ``setupComplete`` reply returned by recv().
  mock_ws.recv = AsyncMock(
      return_value=(
          b'{\n "setupComplete": {"sessionId": "test_session_id"}\n}\n'
      )
  )
  capture = {}

  @contextlib.asynccontextmanager
  async def mock_connect(uri, additional_headers=None, **kwargs):
    # Record the target URI and headers chosen by AsyncLive.connect.
    capture['uri'] = uri
    capture['headers'] = additional_headers
    yield mock_ws

  with patch.object(live, 'ws_connect', new=mock_connect):
    live_module = live.AsyncLive(api_client_mock)
    async with live_module.connect(
        model='test_model',
    ):
      pass

  assert (
      'Authorization' in capture['headers']
  ), 'Authorization key is missing from headers'
  assert (
      capture['headers']['Authorization'] == 'Token auth_tokens/TEST_AUTH_TOKEN'
  )
  assert 'BidiGenerateContentConstrained' in capture['uri']
|
|
2141
|
+
|
|
2142
|
+
|
|
2143
|
+
@pytest.mark.parametrize('vertexai', [False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_api_key(mock_websocket, vertexai):
  """A plain API key header is forwarded and the standard method is used.

  The client's HTTP options carry an ``x-goog-api-key`` header. The test
  asserts that the header reaches the websocket connect call unchanged and
  that the URI targets ``BidiGenerateContent``, not the
  ``BidiGenerateContentConstrained`` method that the ``auth_tokens/`` case
  uses.
  """
  api_client_mock = mock_api_client(vertexai=vertexai)
  api_client_mock._http_options = types.HttpOptions.model_validate(
      {'headers': {'x-goog-api-key': 'TEST_API_KEY'}}
  )
  # Exercise setup-message construction for this client. Only the connect
  # headers/URI are asserted below, so the returned message is not kept.
  await get_connect_message(api_client_mock, model='test_model')

  # Dedicated websocket double (the mock_websocket fixture is not used here).
  mock_ws = AsyncMock()
  mock_ws.send = AsyncMock()
  # Canned ``setupComplete`` reply returned by recv().
  mock_ws.recv = AsyncMock(
      return_value=(
          b'{\n "setupComplete": {"sessionId": "test_session_id"}\n}\n'
      )
  )
  capture = {}

  @contextlib.asynccontextmanager
  async def mock_connect(uri, additional_headers=None, **kwargs):
    # Record the target URI and headers chosen by AsyncLive.connect.
    capture['uri'] = uri
    capture['headers'] = additional_headers
    yield mock_ws

  with patch.object(live, 'ws_connect', new=mock_connect):
    live_module = live.AsyncLive(api_client_mock)
    async with live_module.connect(
        model='test_model',
    ):
      pass

  assert (
      'x-goog-api-key' in capture['headers']
  ), 'x-goog-api-key is missing from headers'
  assert capture['headers']['x-goog-api-key'] == 'TEST_API_KEY'
  assert 'BidiGenerateContent' in capture['uri']
  # 'BidiGenerateContent' is also a substring of the constrained method name,
  # so the check above alone would pass for the auth-token endpoint too.
  assert 'BidiGenerateContentConstrained' not in capture['uri']
|
|
2177
|
+
|