google-genai 1.54.0__py3-none-any.whl → 1.55.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +1 -0
- google/genai/_interactions/__init__.py +117 -0
- google/genai/_interactions/_base_client.py +2019 -0
- google/genai/_interactions/_client.py +511 -0
- google/genai/_interactions/_compat.py +234 -0
- google/genai/_interactions/_constants.py +29 -0
- google/genai/_interactions/_exceptions.py +122 -0
- google/genai/_interactions/_files.py +139 -0
- google/genai/_interactions/_models.py +873 -0
- google/genai/_interactions/_qs.py +165 -0
- google/genai/_interactions/_resource.py +58 -0
- google/genai/_interactions/_response.py +847 -0
- google/genai/_interactions/_streaming.py +354 -0
- google/genai/_interactions/_types.py +276 -0
- google/genai/_interactions/_utils/__init__.py +79 -0
- google/genai/_interactions/_utils/_compat.py +61 -0
- google/genai/_interactions/_utils/_datetime_parse.py +151 -0
- google/genai/_interactions/_utils/_logs.py +40 -0
- google/genai/_interactions/_utils/_proxy.py +80 -0
- google/genai/_interactions/_utils/_reflection.py +57 -0
- google/genai/_interactions/_utils/_resources_proxy.py +39 -0
- google/genai/_interactions/_utils/_streams.py +27 -0
- google/genai/_interactions/_utils/_sync.py +73 -0
- google/genai/_interactions/_utils/_transform.py +472 -0
- google/genai/_interactions/_utils/_typing.py +172 -0
- google/genai/_interactions/_utils/_utils.py +437 -0
- google/genai/_interactions/_version.py +18 -0
- google/genai/_interactions/resources/__init__.py +34 -0
- google/genai/_interactions/resources/interactions.py +1350 -0
- google/genai/_interactions/types/__init__.py +107 -0
- google/genai/_interactions/types/allowed_tools.py +33 -0
- google/genai/_interactions/types/allowed_tools_param.py +35 -0
- google/genai/_interactions/types/annotation.py +42 -0
- google/genai/_interactions/types/annotation_param.py +42 -0
- google/genai/_interactions/types/audio_content.py +38 -0
- google/genai/_interactions/types/audio_content_param.py +45 -0
- google/genai/_interactions/types/audio_mime_type.py +25 -0
- google/genai/_interactions/types/audio_mime_type_param.py +27 -0
- google/genai/_interactions/types/code_execution_call_arguments.py +33 -0
- google/genai/_interactions/types/code_execution_call_arguments_param.py +32 -0
- google/genai/_interactions/types/code_execution_call_content.py +37 -0
- google/genai/_interactions/types/code_execution_call_content_param.py +37 -0
- google/genai/_interactions/types/code_execution_result_content.py +42 -0
- google/genai/_interactions/types/code_execution_result_content_param.py +41 -0
- google/genai/_interactions/types/content_delta.py +358 -0
- google/genai/_interactions/types/content_start.py +79 -0
- google/genai/_interactions/types/content_stop.py +35 -0
- google/genai/_interactions/types/deep_research_agent_config.py +33 -0
- google/genai/_interactions/types/deep_research_agent_config_param.py +32 -0
- google/genai/_interactions/types/document_content.py +36 -0
- google/genai/_interactions/types/document_content_param.py +43 -0
- google/genai/_interactions/types/dynamic_agent_config.py +44 -0
- google/genai/_interactions/types/dynamic_agent_config_param.py +33 -0
- google/genai/_interactions/types/error_event.py +46 -0
- google/genai/_interactions/types/file_search_result_content.py +46 -0
- google/genai/_interactions/types/file_search_result_content_param.py +46 -0
- google/genai/_interactions/types/function.py +38 -0
- google/genai/_interactions/types/function_call_content.py +39 -0
- google/genai/_interactions/types/function_call_content_param.py +39 -0
- google/genai/_interactions/types/function_param.py +37 -0
- google/genai/_interactions/types/function_result_content.py +52 -0
- google/genai/_interactions/types/function_result_content_param.py +54 -0
- google/genai/_interactions/types/generation_config.py +57 -0
- google/genai/_interactions/types/generation_config_param.py +59 -0
- google/genai/_interactions/types/google_search_call_arguments.py +29 -0
- google/genai/_interactions/types/google_search_call_arguments_param.py +31 -0
- google/genai/_interactions/types/google_search_call_content.py +37 -0
- google/genai/_interactions/types/google_search_call_content_param.py +37 -0
- google/genai/_interactions/types/google_search_result.py +35 -0
- google/genai/_interactions/types/google_search_result_content.py +43 -0
- google/genai/_interactions/types/google_search_result_content_param.py +44 -0
- google/genai/_interactions/types/google_search_result_param.py +35 -0
- google/genai/_interactions/types/image_content.py +41 -0
- google/genai/_interactions/types/image_content_param.py +48 -0
- google/genai/_interactions/types/image_mime_type.py +23 -0
- google/genai/_interactions/types/image_mime_type_param.py +25 -0
- google/genai/_interactions/types/interaction.py +165 -0
- google/genai/_interactions/types/interaction_create_params.py +212 -0
- google/genai/_interactions/types/interaction_event.py +37 -0
- google/genai/_interactions/types/interaction_get_params.py +46 -0
- google/genai/_interactions/types/interaction_sse_event.py +32 -0
- google/genai/_interactions/types/interaction_status_update.py +37 -0
- google/genai/_interactions/types/mcp_server_tool_call_content.py +42 -0
- google/genai/_interactions/types/mcp_server_tool_call_content_param.py +42 -0
- google/genai/_interactions/types/mcp_server_tool_result_content.py +52 -0
- google/genai/_interactions/types/mcp_server_tool_result_content_param.py +54 -0
- google/genai/_interactions/types/model.py +36 -0
- google/genai/_interactions/types/model_param.py +38 -0
- google/genai/_interactions/types/speech_config.py +35 -0
- google/genai/_interactions/types/speech_config_param.py +35 -0
- google/genai/_interactions/types/text_content.py +37 -0
- google/genai/_interactions/types/text_content_param.py +38 -0
- google/genai/_interactions/types/thinking_level.py +22 -0
- google/genai/_interactions/types/thought_content.py +41 -0
- google/genai/_interactions/types/thought_content_param.py +47 -0
- google/genai/_interactions/types/tool.py +100 -0
- google/genai/_interactions/types/tool_choice.py +26 -0
- google/genai/_interactions/types/tool_choice_config.py +28 -0
- google/genai/_interactions/types/tool_choice_config_param.py +29 -0
- google/genai/_interactions/types/tool_choice_param.py +28 -0
- google/genai/_interactions/types/tool_choice_type.py +22 -0
- google/genai/_interactions/types/tool_param.py +97 -0
- google/genai/_interactions/types/turn.py +76 -0
- google/genai/_interactions/types/turn_param.py +73 -0
- google/genai/_interactions/types/url_context_call_arguments.py +29 -0
- google/genai/_interactions/types/url_context_call_arguments_param.py +31 -0
- google/genai/_interactions/types/url_context_call_content.py +37 -0
- google/genai/_interactions/types/url_context_call_content_param.py +37 -0
- google/genai/_interactions/types/url_context_result.py +33 -0
- google/genai/_interactions/types/url_context_result_content.py +43 -0
- google/genai/_interactions/types/url_context_result_content_param.py +44 -0
- google/genai/_interactions/types/url_context_result_param.py +32 -0
- google/genai/_interactions/types/usage.py +106 -0
- google/genai/_interactions/types/usage_param.py +106 -0
- google/genai/_interactions/types/video_content.py +41 -0
- google/genai/_interactions/types/video_content_param.py +48 -0
- google/genai/_interactions/types/video_mime_type.py +36 -0
- google/genai/_interactions/types/video_mime_type_param.py +38 -0
- google/genai/_live_converters.py +31 -0
- google/genai/_tokens_converters.py +5 -0
- google/genai/batches.py +7 -0
- google/genai/client.py +223 -0
- google/genai/interactions.py +17 -0
- google/genai/live.py +4 -3
- google/genai/models.py +12 -0
- google/genai/tests/__init__.py +21 -0
- google/genai/tests/afc/__init__.py +21 -0
- google/genai/tests/afc/test_convert_if_exist_pydantic_model.py +309 -0
- google/genai/tests/afc/test_convert_number_values_for_function_call_args.py +63 -0
- google/genai/tests/afc/test_find_afc_incompatible_tool_indexes.py +240 -0
- google/genai/tests/afc/test_generate_content_stream_afc.py +530 -0
- google/genai/tests/afc/test_generate_content_stream_afc_thoughts.py +77 -0
- google/genai/tests/afc/test_get_function_map.py +176 -0
- google/genai/tests/afc/test_get_function_response_parts.py +277 -0
- google/genai/tests/afc/test_get_max_remote_calls_for_afc.py +130 -0
- google/genai/tests/afc/test_invoke_function_from_dict_args.py +241 -0
- google/genai/tests/afc/test_raise_error_for_afc_incompatible_config.py +159 -0
- google/genai/tests/afc/test_should_append_afc_history.py +53 -0
- google/genai/tests/afc/test_should_disable_afc.py +214 -0
- google/genai/tests/batches/__init__.py +17 -0
- google/genai/tests/batches/test_cancel.py +77 -0
- google/genai/tests/batches/test_create.py +78 -0
- google/genai/tests/batches/test_create_with_bigquery.py +113 -0
- google/genai/tests/batches/test_create_with_file.py +82 -0
- google/genai/tests/batches/test_create_with_gcs.py +125 -0
- google/genai/tests/batches/test_create_with_inlined_requests.py +255 -0
- google/genai/tests/batches/test_delete.py +86 -0
- google/genai/tests/batches/test_embedding.py +157 -0
- google/genai/tests/batches/test_get.py +78 -0
- google/genai/tests/batches/test_list.py +79 -0
- google/genai/tests/caches/__init__.py +17 -0
- google/genai/tests/caches/constants.py +29 -0
- google/genai/tests/caches/test_create.py +210 -0
- google/genai/tests/caches/test_create_custom_url.py +105 -0
- google/genai/tests/caches/test_delete.py +54 -0
- google/genai/tests/caches/test_delete_custom_url.py +52 -0
- google/genai/tests/caches/test_get.py +94 -0
- google/genai/tests/caches/test_get_custom_url.py +52 -0
- google/genai/tests/caches/test_list.py +68 -0
- google/genai/tests/caches/test_update.py +70 -0
- google/genai/tests/caches/test_update_custom_url.py +58 -0
- google/genai/tests/chats/__init__.py +1 -0
- google/genai/tests/chats/test_get_history.py +597 -0
- google/genai/tests/chats/test_send_message.py +844 -0
- google/genai/tests/chats/test_validate_response.py +90 -0
- google/genai/tests/client/__init__.py +17 -0
- google/genai/tests/client/test_async_stream.py +427 -0
- google/genai/tests/client/test_client_close.py +197 -0
- google/genai/tests/client/test_client_initialization.py +1687 -0
- google/genai/tests/client/test_client_requests.py +355 -0
- google/genai/tests/client/test_custom_client.py +77 -0
- google/genai/tests/client/test_http_options.py +178 -0
- google/genai/tests/client/test_replay_client_equality.py +168 -0
- google/genai/tests/client/test_retries.py +846 -0
- google/genai/tests/client/test_upload_errors.py +136 -0
- google/genai/tests/common/__init__.py +17 -0
- google/genai/tests/common/test_common.py +954 -0
- google/genai/tests/conftest.py +162 -0
- google/genai/tests/documents/__init__.py +17 -0
- google/genai/tests/documents/test_delete.py +51 -0
- google/genai/tests/documents/test_get.py +85 -0
- google/genai/tests/documents/test_list.py +72 -0
- google/genai/tests/errors/__init__.py +1 -0
- google/genai/tests/errors/test_api_error.py +417 -0
- google/genai/tests/file_search_stores/__init__.py +17 -0
- google/genai/tests/file_search_stores/test_create.py +66 -0
- google/genai/tests/file_search_stores/test_delete.py +64 -0
- google/genai/tests/file_search_stores/test_get.py +94 -0
- google/genai/tests/file_search_stores/test_import_file.py +112 -0
- google/genai/tests/file_search_stores/test_list.py +57 -0
- google/genai/tests/file_search_stores/test_upload_to_file_search_store.py +141 -0
- google/genai/tests/files/__init__.py +17 -0
- google/genai/tests/files/test_delete.py +46 -0
- google/genai/tests/files/test_download.py +85 -0
- google/genai/tests/files/test_get.py +46 -0
- google/genai/tests/files/test_list.py +72 -0
- google/genai/tests/files/test_upload.py +255 -0
- google/genai/tests/imports/test_no_optional_imports.py +28 -0
- google/genai/tests/interactions/__init__.py +0 -0
- google/genai/tests/interactions/test_integration.py +80 -0
- google/genai/tests/live/__init__.py +16 -0
- google/genai/tests/live/test_live.py +2177 -0
- google/genai/tests/live/test_live_music.py +362 -0
- google/genai/tests/live/test_live_response.py +163 -0
- google/genai/tests/live/test_send_client_content.py +147 -0
- google/genai/tests/live/test_send_realtime_input.py +268 -0
- google/genai/tests/live/test_send_tool_response.py +222 -0
- google/genai/tests/local_tokenizer/__init__.py +17 -0
- google/genai/tests/local_tokenizer/test_local_tokenizer.py +343 -0
- google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py +235 -0
- google/genai/tests/mcp/__init__.py +17 -0
- google/genai/tests/mcp/test_has_mcp_tool_usage.py +89 -0
- google/genai/tests/mcp/test_mcp_to_gemini_tools.py +191 -0
- google/genai/tests/mcp/test_parse_config_for_mcp_sessions.py +201 -0
- google/genai/tests/mcp/test_parse_config_for_mcp_usage.py +130 -0
- google/genai/tests/mcp/test_set_mcp_usage_header.py +72 -0
- google/genai/tests/models/__init__.py +17 -0
- google/genai/tests/models/constants.py +8 -0
- google/genai/tests/models/test_compute_tokens.py +120 -0
- google/genai/tests/models/test_count_tokens.py +159 -0
- google/genai/tests/models/test_delete.py +107 -0
- google/genai/tests/models/test_edit_image.py +264 -0
- google/genai/tests/models/test_embed_content.py +94 -0
- google/genai/tests/models/test_function_call_streaming.py +442 -0
- google/genai/tests/models/test_generate_content.py +2502 -0
- google/genai/tests/models/test_generate_content_cached_content.py +132 -0
- google/genai/tests/models/test_generate_content_config_zero_value.py +103 -0
- google/genai/tests/models/test_generate_content_from_apikey.py +44 -0
- google/genai/tests/models/test_generate_content_http_options.py +40 -0
- google/genai/tests/models/test_generate_content_image_generation.py +143 -0
- google/genai/tests/models/test_generate_content_mcp.py +343 -0
- google/genai/tests/models/test_generate_content_media_resolution.py +97 -0
- google/genai/tests/models/test_generate_content_model.py +139 -0
- google/genai/tests/models/test_generate_content_part.py +821 -0
- google/genai/tests/models/test_generate_content_thought.py +76 -0
- google/genai/tests/models/test_generate_content_tools.py +1761 -0
- google/genai/tests/models/test_generate_images.py +191 -0
- google/genai/tests/models/test_generate_videos.py +759 -0
- google/genai/tests/models/test_get.py +104 -0
- google/genai/tests/models/test_list.py +233 -0
- google/genai/tests/models/test_recontext_image.py +189 -0
- google/genai/tests/models/test_segment_image.py +148 -0
- google/genai/tests/models/test_update.py +95 -0
- google/genai/tests/models/test_upscale_image.py +157 -0
- google/genai/tests/operations/__init__.py +17 -0
- google/genai/tests/operations/test_get.py +38 -0
- google/genai/tests/public_samples/__init__.py +17 -0
- google/genai/tests/public_samples/test_gemini_text_only.py +34 -0
- google/genai/tests/pytest_helper.py +229 -0
- google/genai/tests/shared/__init__.py +16 -0
- google/genai/tests/shared/batches/__init__.py +14 -0
- google/genai/tests/shared/batches/test_create_delete.py +57 -0
- google/genai/tests/shared/batches/test_create_get_cancel.py +56 -0
- google/genai/tests/shared/batches/test_list.py +40 -0
- google/genai/tests/shared/caches/__init__.py +14 -0
- google/genai/tests/shared/caches/test_create_get_delete.py +67 -0
- google/genai/tests/shared/caches/test_create_update_get.py +71 -0
- google/genai/tests/shared/caches/test_list.py +40 -0
- google/genai/tests/shared/chats/__init__.py +14 -0
- google/genai/tests/shared/chats/test_send_message.py +48 -0
- google/genai/tests/shared/chats/test_send_message_stream.py +50 -0
- google/genai/tests/shared/files/__init__.py +14 -0
- google/genai/tests/shared/files/test_list.py +41 -0
- google/genai/tests/shared/files/test_upload_get_delete.py +54 -0
- google/genai/tests/shared/models/__init__.py +14 -0
- google/genai/tests/shared/models/test_compute_tokens.py +41 -0
- google/genai/tests/shared/models/test_count_tokens.py +40 -0
- google/genai/tests/shared/models/test_edit_image.py +67 -0
- google/genai/tests/shared/models/test_embed.py +40 -0
- google/genai/tests/shared/models/test_generate_content.py +39 -0
- google/genai/tests/shared/models/test_generate_content_stream.py +54 -0
- google/genai/tests/shared/models/test_generate_images.py +40 -0
- google/genai/tests/shared/models/test_generate_videos.py +38 -0
- google/genai/tests/shared/models/test_list.py +37 -0
- google/genai/tests/shared/models/test_recontext_image.py +55 -0
- google/genai/tests/shared/models/test_segment_image.py +52 -0
- google/genai/tests/shared/models/test_upscale_image.py +52 -0
- google/genai/tests/shared/tunings/__init__.py +16 -0
- google/genai/tests/shared/tunings/test_create.py +46 -0
- google/genai/tests/shared/tunings/test_create_get_cancel.py +56 -0
- google/genai/tests/shared/tunings/test_list.py +39 -0
- google/genai/tests/tokens/__init__.py +16 -0
- google/genai/tests/tokens/test_create.py +154 -0
- google/genai/tests/transformers/__init__.py +17 -0
- google/genai/tests/transformers/test_blobs.py +71 -0
- google/genai/tests/transformers/test_bytes.py +15 -0
- google/genai/tests/transformers/test_duck_type.py +96 -0
- google/genai/tests/transformers/test_function_responses.py +72 -0
- google/genai/tests/transformers/test_schema.py +653 -0
- google/genai/tests/transformers/test_t_batch.py +286 -0
- google/genai/tests/transformers/test_t_content.py +160 -0
- google/genai/tests/transformers/test_t_contents.py +398 -0
- google/genai/tests/transformers/test_t_part.py +85 -0
- google/genai/tests/transformers/test_t_parts.py +87 -0
- google/genai/tests/transformers/test_t_tool.py +157 -0
- google/genai/tests/transformers/test_t_tools.py +195 -0
- google/genai/tests/tunings/__init__.py +16 -0
- google/genai/tests/tunings/test_cancel.py +39 -0
- google/genai/tests/tunings/test_end_to_end.py +106 -0
- google/genai/tests/tunings/test_get.py +67 -0
- google/genai/tests/tunings/test_list.py +75 -0
- google/genai/tests/tunings/test_tune.py +268 -0
- google/genai/tests/types/__init__.py +16 -0
- google/genai/tests/types/test_bytes_internal.py +271 -0
- google/genai/tests/types/test_bytes_type.py +152 -0
- google/genai/tests/types/test_future.py +101 -0
- google/genai/tests/types/test_optional_types.py +36 -0
- google/genai/tests/types/test_part_type.py +616 -0
- google/genai/tests/types/test_schema_from_json_schema.py +417 -0
- google/genai/tests/types/test_schema_json_schema.py +468 -0
- google/genai/tests/types/test_types.py +2903 -0
- google/genai/types.py +72 -0
- google/genai/version.py +1 -1
- {google_genai-1.54.0.dist-info → google_genai-1.55.0.dist-info}/METADATA +3 -1
- google_genai-1.55.0.dist-info/RECORD +345 -0
- google_genai-1.54.0.dist-info/RECORD +0 -41
- {google_genai-1.54.0.dist-info → google_genai-1.55.0.dist-info}/WHEEL +0 -0
- {google_genai-1.54.0.dist-info → google_genai-1.55.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.54.0.dist-info → google_genai-1.55.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
"""Tests for live_music.py."""
|
|
18
|
+
import contextlib
|
|
19
|
+
import json
|
|
20
|
+
from typing import AsyncIterator
|
|
21
|
+
from unittest import mock
|
|
22
|
+
from unittest.mock import AsyncMock
|
|
23
|
+
from unittest.mock import Mock
|
|
24
|
+
from unittest.mock import patch
|
|
25
|
+
import warnings
|
|
26
|
+
|
|
27
|
+
from google.oauth2.credentials import Credentials
|
|
28
|
+
import pytest
|
|
29
|
+
from websockets import client
|
|
30
|
+
|
|
31
|
+
from ... import _api_client as api_client
|
|
32
|
+
from ... import _common
|
|
33
|
+
from ... import Client
|
|
34
|
+
from ... import client as gl_client
|
|
35
|
+
from ... import live
|
|
36
|
+
from ... import live_music
|
|
37
|
+
from ... import types
|
|
38
|
+
from .. import pytest_helper
|
|
39
|
+
try:
  import aiohttp
except ImportError:
  # aiohttp may be absent in this environment; substitute a MagicMock so any
  # module-level reference to `aiohttp` still resolves.
  AIOHTTP_NOT_INSTALLED = True
  aiohttp = mock.MagicMock()
else:
  AIOHTTP_NOT_INSTALLED = False


# Skip marker for tests that need a real aiohttp installation.
requires_aiohttp = pytest.mark.skipif(
    AIOHTTP_NOT_INSTALLED,
    reason="aiohttp is not installed, skipping test.",
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def mock_api_client(vertexai=False, credentials=None):
  """Builds a MagicMock that stands in for `BaseApiClient` in session tests.

  Args:
    vertexai: When true, mimic a Vertex AI client (project and location set,
      no API key). Otherwise mimic an API-key client.
    credentials: Value stored on the mock's `_credentials` attribute.

  Returns:
    A `MagicMock` specced on `BaseApiClient` with the attributes above filled
    in. Its `_api_client` attribute points back at the mock itself.
  """
  fake_client = mock.MagicMock(spec=gl_client.BaseApiClient)
  if vertexai:
    fake_client.api_key = None
    fake_client.location = 'us-central1'
    fake_client.project = 'test_project'
  else:
    fake_client.api_key = 'TEST_API_KEY'
    fake_client.location = None
    fake_client.project = None

  fake_client._host = lambda: 'test_host'
  fake_client._credentials = credentials
  # Ensure headers exist.
  fake_client._http_options = types.HttpOptions.model_validate({'headers': {}})
  fake_client.vertexai = vertexai
  # Presumably so code reaching through `._api_client` lands on this same
  # mock -- confirm against callers.
  fake_client._api_client = fake_client
  return fake_client
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@pytest.fixture
def mock_websocket():
  """Returns an AsyncMock websocket preloaded with one music server message.

  `recv` always resolves to the same serialized `serverContent` payload. That
  payload holds a single audio chunk whose base64 data "Z2VsYmFuYW5h" decodes
  to b'gelbanana'. The chunk's source metadata echoes a "Jazz" weighted
  prompt and a 140 bpm music generation config. `send` and `close` are plain
  AsyncMocks so tests can inspect their calls.
  """
  default_response = b"""{
    "serverContent": {
      "audioChunks": [
        {
          "data": "Z2VsYmFuYW5h",
          "mimeType": "audio/l16;rate=48000;channels=2",
          "sourceMetadata": {
            "clientContent": {
              "weightedPrompts": [
                {
                  "text": "Jazz",
                  "weight": 1
                }
              ]
            },
            "musicGenerationConfig": {
              "seed": -957124937,
              "bpm": 140,
              "scale": "A_FLAT_MAJOR_F_MINOR"
            }
          }
        }
      ]
    }
  }"""
  fake_ws = AsyncMock(spec=client.ClientConnection)
  fake_ws.send = AsyncMock()
  fake_ws.recv = AsyncMock(return_value=default_response)
  fake_ws.close = AsyncMock()
  return fake_ws
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
async def get_connect_message(api_client, model):
  """Opens a live music session on a fake websocket; returns the setup message.

  `google.auth.default` and `live_music.connect` are patched for the
  duration of the connect, so no network or real credentials are used.

  Args:
    api_client: The (mock) API client used to build `live.AsyncLive`.
    model: Model name passed to `music.connect`.

  Returns:
    The JSON-decoded payload of the single message sent while connecting.

  Raises:
    Whatever `music.connect` raises. For example, the Vertex AI case in
    `test_setup_to_api` expects NotImplementedError to propagate.
  """
  mock_ws = AsyncMock()
  mock_ws.send = AsyncMock()
  mock_ws.recv = AsyncMock(return_value=b'some response')

  # Configure the auth mock once; the original built it with a placeholder
  # return value that was immediately overwritten.
  mock_creds = Mock(token='test_token')
  mock_google_auth_default = Mock(return_value=(mock_creds, None))

  @contextlib.asynccontextmanager
  async def mock_connect(uri, additional_headers=None):
    yield mock_ws

  with patch('google.auth.default', new=mock_google_auth_default):
    with patch.object(live_music, 'connect', new=mock_connect):
      live_module = live.AsyncLive(api_client)
      async with live_module.music.connect(model=model):
        pass

  mock_ws.send.assert_called_once()
  return json.loads(mock_ws.send.call_args[0][0])
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def test_mldev_from_env(monkeypatch):
  """A Client built with only GOOGLE_API_KEY set is not in Vertex AI mode."""
  expected_key = 'google_api_key'
  monkeypatch.setenv('GOOGLE_API_KEY', expected_key)

  genai_client = Client()
  music_api_client = genai_client.aio.live.music._api_client

  assert not music_api_client.vertexai
  assert music_api_client.api_key == expected_key
  assert isinstance(
      genai_client.aio.live._api_client, api_client.BaseApiClient
  )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@requires_aiohttp
def test_vertex_from_env(monkeypatch):
  """Vertex AI env vars put the live music API client into Vertex mode."""
  expected_project = 'fake_project_id'
  env = {
      'GOOGLE_GENAI_USE_VERTEXAI': 'true',
      'GOOGLE_CLOUD_PROJECT': expected_project,
      'GOOGLE_CLOUD_LOCATION': 'fake-location',
  }
  for name, value in env.items():
    monkeypatch.setenv(name, value)

  genai_client = Client()
  music_api_client = genai_client.aio.live.music._api_client

  assert music_api_client.vertexai
  assert music_api_client.project == expected_project
  assert isinstance(
      genai_client.aio.live._api_client, api_client.BaseApiClient
  )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def test_websocket_base_url():
  """An https base_url is mapped to the matching wss:// websocket base URL."""
  base_client = gl_client.BaseApiClient(
      api_key='google_api_key',
      http_options={'base_url': 'https://test.com'},
  )
  assert base_client._websocket_base_url() == 'wss://test.com'
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_weighted_prompts(mock_websocket, vertexai):
  """set_weighted_prompts sends a clientContent message.

  On Vertex AI the call must raise NotImplementedError instead.
  """
  session = live_music.AsyncMusicSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  prompts = [types.WeightedPrompt(text='Jazz', weight=1)]

  if vertexai:
    with pytest.raises(NotImplementedError):
      await session.set_weighted_prompts(prompts=prompts)
    return

  await session.set_weighted_prompts(prompts=prompts)

  mock_websocket.send.assert_called_once()
  sent = json.loads(mock_websocket.send.call_args[0][0])
  assert 'clientContent' in sent
  first_prompt = sent['clientContent']['weightedPrompts'][0]
  assert first_prompt['text'] == 'Jazz'
  assert first_prompt['weight'] == 1
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_config(mock_websocket, vertexai):
  """set_music_generation_config sends a musicGenerationConfig message.

  On Vertex AI the call must raise NotImplementedError instead.
  """
  session = live_music.AsyncMusicSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  config = types.LiveMusicGenerationConfig(
      bpm=140,
      music_generation_mode=types.MusicGenerationMode.VOCALIZATION,
  )

  if vertexai:
    with pytest.raises(NotImplementedError):
      await session.set_music_generation_config(config=config)
    return

  await session.set_music_generation_config(config=config)

  mock_websocket.send.assert_called_once()
  sent = json.loads(mock_websocket.send.call_args[0][0])
  assert 'musicGenerationConfig' in sent
  assert sent['musicGenerationConfig']['bpm'] == 140
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_control_signal_play(mock_websocket, vertexai):
  """play() sends a PLAY playbackControl message.

  On Vertex AI the call must raise NotImplementedError instead.

  The sent value is compared with `==` rather than `in`. A substring check
  would also accept any longer value that merely contains "PLAY".
  """
  session = live_music.AsyncMusicSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  if vertexai:
    with pytest.raises(NotImplementedError):
      await session.play()
    return

  await session.play()

  mock_websocket.send.assert_called_once()
  sent_data = json.loads(mock_websocket.send.call_args[0][0])
  assert 'playbackControl' in sent_data
  assert sent_data['playbackControl'] == 'PLAY'
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_control_signal_pause(mock_websocket, vertexai):
  """pause() sends a PAUSE playbackControl message.

  On Vertex AI the call must raise NotImplementedError instead.

  The sent value is compared with `==` rather than `in`, so a value that
  merely contains "PAUSE" as a substring does not pass.
  """
  session = live_music.AsyncMusicSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  if vertexai:
    with pytest.raises(NotImplementedError):
      await session.pause()
    return

  await session.pause()

  mock_websocket.send.assert_called_once()
  sent_data = json.loads(mock_websocket.send.call_args[0][0])
  assert 'playbackControl' in sent_data
  assert sent_data['playbackControl'] == 'PAUSE'
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_control_signal_stop(mock_websocket, vertexai):
  """stop() sends a STOP playbackControl message.

  On Vertex AI the call must raise NotImplementedError instead.

  The sent value is compared with `==` rather than `in`, so a value that
  merely contains "STOP" as a substring does not pass.
  """
  session = live_music.AsyncMusicSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  if vertexai:
    with pytest.raises(NotImplementedError):
      await session.stop()
    return

  await session.stop()

  mock_websocket.send.assert_called_once()
  sent_data = json.loads(mock_websocket.send.call_args[0][0])
  assert 'playbackControl' in sent_data
  assert sent_data['playbackControl'] == 'STOP'
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_control_signal_reset_context(
    mock_websocket, vertexai
):
  """reset_context() sends a RESET_CONTEXT playbackControl message.

  On Vertex AI the call must raise NotImplementedError instead.

  The sent value is compared with `==` rather than `in`, so a value that
  merely contains "RESET_CONTEXT" as a substring does not pass.
  """
  session = live_music.AsyncMusicSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  if vertexai:
    with pytest.raises(NotImplementedError):
      await session.reset_context()
    return

  await session.reset_context()

  mock_websocket.send.assert_called_once()
  sent_data = json.loads(mock_websocket.send.call_args[0][0])
  assert 'playbackControl' in sent_data
  assert sent_data['playbackControl'] == 'RESET_CONTEXT'
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive(mock_websocket, vertexai):
  """receive() decodes the fixture payload into a LiveMusicServerMessage.

  On Vertex AI, iterating receive() must raise NotImplementedError.

  The first message is pulled with `__anext__()` rather than an
  `async for ... break` loop. If the stream yields nothing, the test now
  fails with StopAsyncIteration instead of passing with zero assertions run.
  """
  session = live_music.AsyncMusicSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  if vertexai:
    with pytest.raises(NotImplementedError):
      async for _ in session.receive():
        pass
    return

  response = await session.receive().__anext__()

  assert isinstance(response, types.LiveMusicServerMessage)
  audio_chunk = response.server_content.audio_chunks[0]
  # Data contains decoded b64 audio
  assert audio_chunk.data == b'gelbanana'
  assert audio_chunk.mime_type == 'audio/l16;rate=48000;channels=2'
  source = audio_chunk.source_metadata
  assert source.client_content.weighted_prompts[0].text == 'Jazz'
  assert source.client_content.weighted_prompts[0].weight == 1
  assert source.music_generation_config.bpm == 140
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_error(mock_websocket, vertexai):
  """A frame that is not valid JSON makes receive() raise ValueError.

  Both backends are expected to surface the ValueError.
  """
  # Replace recv wholesale so the first frame is unparseable.
  mock_websocket.recv = AsyncMock(return_value='invalid json')
  music_session = live_music.AsyncMusicSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  with pytest.raises(ValueError):
    await music_session.receive().__anext__()
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_close(mock_websocket, vertexai):
  """close() closes the underlying websocket exactly once on either backend."""
  music_session = live_music.AsyncMusicSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )
  await music_session.close()
  mock_websocket.close.assert_called_once()
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_setup_to_api(vertexai):
  """get_connect_message() builds the setup frame for the music session.

  Vertex AI is not supported yet, so building the message must raise
  NotImplementedError there. On the Gemini API the model name gets the
  ``models/`` prefix.
  """
  client_mock = mock_api_client(vertexai=vertexai)
  if vertexai:
    # Vertex is not supported yet.
    with pytest.raises(NotImplementedError):
      await get_connect_message(client_mock, model='test_model')
    return

  result = await get_connect_message(client_mock, model='test_model')
  assert result == {'setup': {'model': 'models/test_model'}}
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
|
|
16
|
+
"""Tests for live response handling."""
|
|
17
|
+
import json
|
|
18
|
+
from typing import cast
|
|
19
|
+
from unittest import mock
|
|
20
|
+
from unittest.mock import AsyncMock
|
|
21
|
+
|
|
22
|
+
import pytest
|
|
23
|
+
|
|
24
|
+
from ... import _api_client as api_client
|
|
25
|
+
from ... import _common
|
|
26
|
+
from ... import Client
|
|
27
|
+
from ... import client as gl_client
|
|
28
|
+
from ... import live
|
|
29
|
+
from ... import types
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def mock_api_client(vertexai=False):
  """Builds a MagicMock standing in for BaseApiClient.

  Args:
    vertexai: When truthy, the mock looks like a Vertex AI client (project and
      location set, no API key). Otherwise it looks like a Gemini API client
      (API key set, no project or location).

  Returns:
    The configured mock client.
  """
  mock_client = mock.MagicMock(spec=gl_client.BaseApiClient)
  mock_client.api_key = None if vertexai else 'TEST_API_KEY'
  mock_client.location = 'us-central1' if vertexai else None
  mock_client.project = 'test_project' if vertexai else None

  mock_client._host = lambda: 'test_host'
  mock_client._http_options = types.HttpOptions.model_validate(
      {'headers': {}}
  )
  mock_client.vertexai = vertexai
  return mock_client
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@pytest.fixture
def mock_websocket():
  """Returns an AsyncMock websocket whose recv() yields '{}' by default.

  Individual tests override ``recv.return_value`` with the frame they need.
  """
  # Spec'd on live.ClientConnection, presumably the type AsyncSession
  # annotates its websocket with; verify if that annotation changes.
  connection = AsyncMock(spec=live.ClientConnection)
  connection.send = AsyncMock()
  connection.recv = AsyncMock(return_value='{}')
  connection.close = AsyncMock()
  return connection
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_receive_server_content(mock_websocket, vertexai):
  """_receive() parses serverContent, grounding metadata and usageMetadata.

  The same wire frame is fed to both backends. The backends differ only in
  which token-count fields end up in response_token_count and
  response_tokens_details.
  """
  usage_metadata_json = {
      "promptTokenCount": 15,
      "responseTokenCount": 25,
      "candidatesTokenCount": 50,
      "totalTokenCount": 200,
      "responseTokensDetails": [
          {
              "tokenCount": 20,
              "modality": "TEXT",
          }
      ],
      "candidatesTokensDetails": [
          {
              "tokenCount": 10,
              "modality": "TEXT",
          }
      ],
  }
  server_content_json = {
      "modelTurn": {
          "parts": [{"text": "This is a simple response."}]
      },
      "turnComplete": True,
      "groundingMetadata": {
          "web_search_queries": ["test query"],
          "groundingChunks": [{
              "web": {
                  "domain": "google.com",
                  "title": "Search results",
              }
          }]
      }
  }
  mock_websocket.recv.return_value = json.dumps({
      "usageMetadata": usage_metadata_json,
      "serverContent": server_content_json,
  })

  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  message = await live_session._receive()

  assert isinstance(message, types.LiveServerMessage)

  content = message.server_content
  assert content.model_turn.parts[0].text == "This is a simple response."
  assert content.turn_complete
  grounding = content.grounding_metadata
  assert grounding.web_search_queries == ["test query"]
  first_chunk = grounding.grounding_chunks[0]
  assert first_chunk.web.domain == "google.com"
  assert first_chunk.web.title == "Search results"

  # usageMetadata must come back as a typed UsageMetadata model.
  usage = message.usage_metadata
  assert isinstance(usage, types.UsageMetadata)
  assert usage.prompt_token_count == 15
  assert usage.total_token_count == 200
  # Vertex AI maps candidatesTokenCount / candidatesTokensDetails onto the
  # response_* fields; the Gemini API uses the responseToken* fields as-is.
  expected_count, expected_detail = (50, 10) if vertexai else (25, 20)
  assert usage.response_token_count == expected_count
  assert usage.response_tokens_details[0].token_count == expected_detail
|
|
133
|
+
|
|
134
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_receive_server_content_with_turn_reason(mock_websocket, vertexai):
  """serverContent's turnCompleteReason and waitingForInput are parsed.

  "NEED_MORE_INPUT" must map to types.TurnCompleteReason.NEED_MORE_INPUT,
  and both boolean flags must come back as real booleans on either backend.
  """
  server_content_json = {
      "modelTurn": {"parts": [{"text": "Please provide more details."}]},
      "turnComplete": True,
      "turnCompleteReason": "NEED_MORE_INPUT",
      "waitingForInput": True,
  }
  mock_websocket.recv.return_value = json.dumps(
      {"serverContent": server_content_json}
  )

  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  message = await live_session._receive()

  assert isinstance(message, types.LiveServerMessage)
  content = message.server_content
  assert content is not None

  assert content.model_turn.parts[0].text == "Please provide more details."
  assert content.turn_complete is True
  assert (
      content.turn_complete_reason
      == types.TurnCompleteReason.NEED_MORE_INPUT
  )
  assert content.waiting_for_input is True
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
"""Tests for live.py."""
|
|
18
|
+
import base64
|
|
19
|
+
import json
|
|
20
|
+
from unittest import mock
|
|
21
|
+
|
|
22
|
+
import pytest
|
|
23
|
+
from websockets import client
|
|
24
|
+
|
|
25
|
+
from .. import pytest_helper
|
|
26
|
+
from ... import client as gl_client
|
|
27
|
+
from ... import live
|
|
28
|
+
from ... import types
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def mock_api_client(vertexai=False):
  """Returns a MagicMock shaped like BaseApiClient for AsyncSession tests.

  The mock always carries a test API key, a fixed host and an empty header
  mapping. Only the ``vertexai`` flag varies between backends.
  """
  client_stub = mock.MagicMock(spec=gl_client.BaseApiClient)
  client_stub.api_key = 'TEST_API_KEY'
  client_stub._host = lambda: 'test_host'
  # A plain dict (not HttpOptions) that guarantees a 'headers' entry exists.
  client_stub._http_options = {'headers': {}}
  client_stub.vertexai = vertexai
  return client_stub
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@pytest.fixture
def mock_websocket():
  """Provides an AsyncMock websocket spec'd on websockets' ClientConnection.

  Unless a test overrides it, recv() returns a single turn-complete
  serverContent frame.
  """
  default_frame = '{"serverContent": {"turnComplete": true}}'
  ws = mock.AsyncMock(spec=client.ClientConnection)
  ws.send = mock.AsyncMock()
  ws.recv = mock.AsyncMock(return_value=default_frame)
  ws.close = mock.AsyncMock()
  return ws
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_send_content_dict(mock_websocket, vertexai):
  """A single ContentDict passed as ``turns`` is sent as one client turn.

  test_send_content_dict_list covers the list-of-dicts form. This test
  exercises the bare-dict form, which must be wrapped into a one-element
  ``turns`` list on the wire.
  """
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # A bare dict, not a list: previously this duplicated the list test.
  content = {'parts': [{'text': 'test'}]}

  await session.send_client_content(turns=content)
  mock_websocket.send.assert_called_once()
  sent_data = json.loads(mock_websocket.send.call_args[0][0])
  assert 'client_content' in sent_data

  assert sent_data['client_content']['turns'][0]['parts'][0]['text'] == 'test'
|
|
65
|
+
|
|
66
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_send_content_dict_list(mock_websocket, vertexai):
  """A list of ContentDicts passed as ``turns`` lands in client_content.turns."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  turn_dicts = [{'parts': [{'text': 'test'}]}]

  await live_session.send_client_content(turns=turn_dicts)

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'client_content' in payload
  first_turn = payload['client_content']['turns'][0]
  assert first_turn['parts'][0]['text'] == 'test'
|
|
80
|
+
|
|
81
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_send_content_content(mock_websocket, vertexai):
  """A single types.Content passed as ``turns`` is sent as the first turn."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  turn = types.Content.model_validate({'parts': [{'text': 'test'}]})

  await live_session.send_client_content(turns=turn)

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'client_content' in payload
  first_turn = payload['client_content']['turns'][0]
  assert first_turn['parts'][0]['text'] == 'test'
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_send_content_with_blob(mock_websocket, vertexai):
  """Inline bytes in a Content part are sent base64-encoded.

  The part's key is looked up case-insensitively, so the test does not pin
  whether it is serialized as inline_data or in camelCase.
  """
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  blob_turn = types.Content.model_validate(
      {'parts': [{'inline_data': {'data': b'test'}}]}
  )

  await live_session.send_client_content(turns=blob_turn)

  mock_websocket.send.assert_called_once()
  payload = json.loads(mock_websocket.send.call_args.args[0])
  assert 'client_content' in payload

  first_part = payload['client_content']['turns'][0]['parts'][0]
  expected_blob = {'data': base64.b64encode(b'test').decode()}
  assert (
      pytest_helper.get_value_ignore_key_case(first_part, 'inline_data')
      == expected_blob
  )
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_send_client_content_turn_complete_false(
    mock_websocket, vertexai
):
  """turn_complete=False is serialized as a JSON ``false`` in turnComplete."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )

  await session.send_client_content(turn_complete=False)
  mock_websocket.send.assert_called_once()
  sent_data = json.loads(mock_websocket.send.call_args[0][0])
  assert 'client_content' in sent_data
  # `is False` pins the JSON boolean; `== False` would also accept 0.
  assert sent_data['client_content']['turnComplete'] is False
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_send_client_content_empty(
    mock_websocket, vertexai
):
  """send_client_content() with no arguments still sends turnComplete=true."""
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )

  await session.send_client_content()
  mock_websocket.send.assert_called_once()
  sent_data = json.loads(mock_websocket.send.call_args[0][0])
  assert 'client_content' in sent_data
  # `is True` pins the JSON boolean; `== True` would also accept 1.
  assert sent_data['client_content']['turnComplete'] is True
|