google-genai 1.56.0__py3-none-any.whl → 1.58.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +49 -26
- google/genai/_interactions/__init__.py +3 -0
- google/genai/_interactions/_base_client.py +1 -1
- google/genai/_interactions/_client.py +57 -3
- google/genai/_interactions/_client_adapter.py +48 -0
- google/genai/_interactions/types/__init__.py +6 -0
- google/genai/_interactions/types/audio_content.py +2 -0
- google/genai/_interactions/types/audio_content_param.py +2 -0
- google/genai/_interactions/types/content.py +65 -0
- google/genai/_interactions/types/content_delta.py +10 -2
- google/genai/_interactions/types/content_param.py +63 -0
- google/genai/_interactions/types/content_start.py +5 -46
- google/genai/_interactions/types/content_stop.py +1 -2
- google/genai/_interactions/types/document_content.py +2 -0
- google/genai/_interactions/types/document_content_param.py +2 -0
- google/genai/_interactions/types/error_event.py +1 -2
- google/genai/_interactions/types/file_search_call_content.py +32 -0
- google/genai/_interactions/types/file_search_call_content_param.py +31 -0
- google/genai/_interactions/types/generation_config.py +4 -0
- google/genai/_interactions/types/generation_config_param.py +4 -0
- google/genai/_interactions/types/image_config.py +31 -0
- google/genai/_interactions/types/image_config_param.py +30 -0
- google/genai/_interactions/types/image_content.py +2 -0
- google/genai/_interactions/types/image_content_param.py +2 -0
- google/genai/_interactions/types/interaction.py +6 -52
- google/genai/_interactions/types/interaction_create_params.py +4 -22
- google/genai/_interactions/types/interaction_event.py +1 -2
- google/genai/_interactions/types/interaction_sse_event.py +5 -3
- google/genai/_interactions/types/interaction_status_update.py +1 -2
- google/genai/_interactions/types/model.py +1 -0
- google/genai/_interactions/types/model_param.py +1 -0
- google/genai/_interactions/types/turn.py +3 -44
- google/genai/_interactions/types/turn_param.py +4 -40
- google/genai/_interactions/types/usage.py +1 -1
- google/genai/_interactions/types/usage_param.py +1 -1
- google/genai/_interactions/types/video_content.py +2 -0
- google/genai/_interactions/types/video_content_param.py +2 -0
- google/genai/_live_converters.py +118 -34
- google/genai/_local_tokenizer_loader.py +1 -0
- google/genai/_tokens_converters.py +14 -14
- google/genai/_transformers.py +15 -21
- google/genai/batches.py +27 -22
- google/genai/caches.py +42 -42
- google/genai/chats.py +0 -2
- google/genai/client.py +61 -55
- google/genai/files.py +224 -0
- google/genai/live.py +1 -1
- google/genai/models.py +56 -44
- google/genai/tests/__init__.py +21 -0
- google/genai/tests/afc/__init__.py +21 -0
- google/genai/tests/afc/test_convert_if_exist_pydantic_model.py +309 -0
- google/genai/tests/afc/test_convert_number_values_for_function_call_args.py +63 -0
- google/genai/tests/afc/test_find_afc_incompatible_tool_indexes.py +240 -0
- google/genai/tests/afc/test_generate_content_stream_afc.py +530 -0
- google/genai/tests/afc/test_generate_content_stream_afc_thoughts.py +77 -0
- google/genai/tests/afc/test_get_function_map.py +176 -0
- google/genai/tests/afc/test_get_function_response_parts.py +277 -0
- google/genai/tests/afc/test_get_max_remote_calls_for_afc.py +130 -0
- google/genai/tests/afc/test_invoke_function_from_dict_args.py +241 -0
- google/genai/tests/afc/test_raise_error_for_afc_incompatible_config.py +159 -0
- google/genai/tests/afc/test_should_append_afc_history.py +53 -0
- google/genai/tests/afc/test_should_disable_afc.py +214 -0
- google/genai/tests/batches/__init__.py +17 -0
- google/genai/tests/batches/test_cancel.py +77 -0
- google/genai/tests/batches/test_create.py +78 -0
- google/genai/tests/batches/test_create_with_bigquery.py +113 -0
- google/genai/tests/batches/test_create_with_file.py +82 -0
- google/genai/tests/batches/test_create_with_gcs.py +125 -0
- google/genai/tests/batches/test_create_with_inlined_requests.py +255 -0
- google/genai/tests/batches/test_delete.py +86 -0
- google/genai/tests/batches/test_embedding.py +157 -0
- google/genai/tests/batches/test_get.py +78 -0
- google/genai/tests/batches/test_list.py +79 -0
- google/genai/tests/caches/__init__.py +17 -0
- google/genai/tests/caches/constants.py +29 -0
- google/genai/tests/caches/test_create.py +210 -0
- google/genai/tests/caches/test_create_custom_url.py +105 -0
- google/genai/tests/caches/test_delete.py +54 -0
- google/genai/tests/caches/test_delete_custom_url.py +52 -0
- google/genai/tests/caches/test_get.py +94 -0
- google/genai/tests/caches/test_get_custom_url.py +52 -0
- google/genai/tests/caches/test_list.py +68 -0
- google/genai/tests/caches/test_update.py +70 -0
- google/genai/tests/caches/test_update_custom_url.py +58 -0
- google/genai/tests/chats/__init__.py +1 -0
- google/genai/tests/chats/test_get_history.py +598 -0
- google/genai/tests/chats/test_send_message.py +844 -0
- google/genai/tests/chats/test_validate_response.py +90 -0
- google/genai/tests/client/__init__.py +17 -0
- google/genai/tests/client/test_async_stream.py +427 -0
- google/genai/tests/client/test_client_close.py +197 -0
- google/genai/tests/client/test_client_initialization.py +1687 -0
- google/genai/tests/client/test_client_requests.py +221 -0
- google/genai/tests/client/test_custom_client.py +104 -0
- google/genai/tests/client/test_http_options.py +178 -0
- google/genai/tests/client/test_replay_client_equality.py +168 -0
- google/genai/tests/client/test_retries.py +846 -0
- google/genai/tests/client/test_upload_errors.py +136 -0
- google/genai/tests/common/__init__.py +17 -0
- google/genai/tests/common/test_common.py +954 -0
- google/genai/tests/conftest.py +162 -0
- google/genai/tests/documents/__init__.py +17 -0
- google/genai/tests/documents/test_delete.py +51 -0
- google/genai/tests/documents/test_get.py +85 -0
- google/genai/tests/documents/test_list.py +72 -0
- google/genai/tests/errors/__init__.py +1 -0
- google/genai/tests/errors/test_api_error.py +417 -0
- google/genai/tests/file_search_stores/__init__.py +17 -0
- google/genai/tests/file_search_stores/test_create.py +66 -0
- google/genai/tests/file_search_stores/test_delete.py +64 -0
- google/genai/tests/file_search_stores/test_get.py +94 -0
- google/genai/tests/file_search_stores/test_import_file.py +112 -0
- google/genai/tests/file_search_stores/test_list.py +57 -0
- google/genai/tests/file_search_stores/test_upload_to_file_search_store.py +141 -0
- google/genai/tests/files/__init__.py +17 -0
- google/genai/tests/files/test_delete.py +46 -0
- google/genai/tests/files/test_download.py +85 -0
- google/genai/tests/files/test_get.py +46 -0
- google/genai/tests/files/test_list.py +72 -0
- google/genai/tests/files/test_register.py +272 -0
- google/genai/tests/files/test_register_table.py +70 -0
- google/genai/tests/files/test_upload.py +255 -0
- google/genai/tests/imports/test_no_optional_imports.py +28 -0
- google/genai/tests/interactions/test_auth.py +476 -0
- google/genai/tests/interactions/test_integration.py +84 -0
- google/genai/tests/interactions/test_paths.py +105 -0
- google/genai/tests/live/__init__.py +16 -0
- google/genai/tests/live/test_live.py +2143 -0
- google/genai/tests/live/test_live_music.py +362 -0
- google/genai/tests/live/test_live_response.py +163 -0
- google/genai/tests/live/test_send_client_content.py +147 -0
- google/genai/tests/live/test_send_realtime_input.py +268 -0
- google/genai/tests/live/test_send_tool_response.py +222 -0
- google/genai/tests/local_tokenizer/__init__.py +17 -0
- google/genai/tests/local_tokenizer/test_local_tokenizer.py +343 -0
- google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py +235 -0
- google/genai/tests/mcp/__init__.py +17 -0
- google/genai/tests/mcp/test_has_mcp_tool_usage.py +89 -0
- google/genai/tests/mcp/test_mcp_to_gemini_tools.py +191 -0
- google/genai/tests/mcp/test_parse_config_for_mcp_sessions.py +201 -0
- google/genai/tests/mcp/test_parse_config_for_mcp_usage.py +130 -0
- google/genai/tests/mcp/test_set_mcp_usage_header.py +72 -0
- google/genai/tests/models/__init__.py +17 -0
- google/genai/tests/models/constants.py +8 -0
- google/genai/tests/models/test_compute_tokens.py +120 -0
- google/genai/tests/models/test_count_tokens.py +159 -0
- google/genai/tests/models/test_delete.py +107 -0
- google/genai/tests/models/test_edit_image.py +264 -0
- google/genai/tests/models/test_embed_content.py +94 -0
- google/genai/tests/models/test_function_call_streaming.py +442 -0
- google/genai/tests/models/test_generate_content.py +2501 -0
- google/genai/tests/models/test_generate_content_cached_content.py +132 -0
- google/genai/tests/models/test_generate_content_config_zero_value.py +103 -0
- google/genai/tests/models/test_generate_content_from_apikey.py +44 -0
- google/genai/tests/models/test_generate_content_http_options.py +40 -0
- google/genai/tests/models/test_generate_content_image_generation.py +143 -0
- google/genai/tests/models/test_generate_content_mcp.py +343 -0
- google/genai/tests/models/test_generate_content_media_resolution.py +97 -0
- google/genai/tests/models/test_generate_content_model.py +139 -0
- google/genai/tests/models/test_generate_content_part.py +821 -0
- google/genai/tests/models/test_generate_content_thought.py +76 -0
- google/genai/tests/models/test_generate_content_tools.py +1761 -0
- google/genai/tests/models/test_generate_images.py +191 -0
- google/genai/tests/models/test_generate_videos.py +759 -0
- google/genai/tests/models/test_get.py +104 -0
- google/genai/tests/models/test_list.py +233 -0
- google/genai/tests/models/test_recontext_image.py +189 -0
- google/genai/tests/models/test_segment_image.py +148 -0
- google/genai/tests/models/test_update.py +95 -0
- google/genai/tests/models/test_upscale_image.py +157 -0
- google/genai/tests/operations/__init__.py +17 -0
- google/genai/tests/operations/test_get.py +38 -0
- google/genai/tests/public_samples/__init__.py +17 -0
- google/genai/tests/public_samples/test_gemini_text_only.py +34 -0
- google/genai/tests/pytest_helper.py +246 -0
- google/genai/tests/shared/__init__.py +16 -0
- google/genai/tests/shared/batches/__init__.py +14 -0
- google/genai/tests/shared/batches/test_create_delete.py +57 -0
- google/genai/tests/shared/batches/test_create_get_cancel.py +56 -0
- google/genai/tests/shared/batches/test_list.py +40 -0
- google/genai/tests/shared/caches/__init__.py +14 -0
- google/genai/tests/shared/caches/test_create_get_delete.py +67 -0
- google/genai/tests/shared/caches/test_create_update_get.py +71 -0
- google/genai/tests/shared/caches/test_list.py +40 -0
- google/genai/tests/shared/chats/__init__.py +14 -0
- google/genai/tests/shared/chats/test_send_message.py +48 -0
- google/genai/tests/shared/chats/test_send_message_stream.py +50 -0
- google/genai/tests/shared/files/__init__.py +14 -0
- google/genai/tests/shared/files/test_list.py +41 -0
- google/genai/tests/shared/files/test_upload_get_delete.py +54 -0
- google/genai/tests/shared/models/__init__.py +14 -0
- google/genai/tests/shared/models/test_compute_tokens.py +41 -0
- google/genai/tests/shared/models/test_count_tokens.py +40 -0
- google/genai/tests/shared/models/test_edit_image.py +67 -0
- google/genai/tests/shared/models/test_embed.py +40 -0
- google/genai/tests/shared/models/test_generate_content.py +39 -0
- google/genai/tests/shared/models/test_generate_content_stream.py +54 -0
- google/genai/tests/shared/models/test_generate_images.py +40 -0
- google/genai/tests/shared/models/test_generate_videos.py +38 -0
- google/genai/tests/shared/models/test_list.py +37 -0
- google/genai/tests/shared/models/test_recontext_image.py +55 -0
- google/genai/tests/shared/models/test_segment_image.py +52 -0
- google/genai/tests/shared/models/test_upscale_image.py +52 -0
- google/genai/tests/shared/tunings/__init__.py +16 -0
- google/genai/tests/shared/tunings/test_create.py +46 -0
- google/genai/tests/shared/tunings/test_create_get_cancel.py +56 -0
- google/genai/tests/shared/tunings/test_list.py +39 -0
- google/genai/tests/tokens/__init__.py +16 -0
- google/genai/tests/tokens/test_create.py +154 -0
- google/genai/tests/transformers/__init__.py +17 -0
- google/genai/tests/transformers/test_blobs.py +84 -0
- google/genai/tests/transformers/test_bytes.py +15 -0
- google/genai/tests/transformers/test_duck_type.py +96 -0
- google/genai/tests/transformers/test_function_responses.py +72 -0
- google/genai/tests/transformers/test_schema.py +653 -0
- google/genai/tests/transformers/test_t_batch.py +286 -0
- google/genai/tests/transformers/test_t_content.py +160 -0
- google/genai/tests/transformers/test_t_contents.py +398 -0
- google/genai/tests/transformers/test_t_part.py +85 -0
- google/genai/tests/transformers/test_t_parts.py +87 -0
- google/genai/tests/transformers/test_t_tool.py +157 -0
- google/genai/tests/transformers/test_t_tools.py +195 -0
- google/genai/tests/tunings/__init__.py +16 -0
- google/genai/tests/tunings/test_cancel.py +39 -0
- google/genai/tests/tunings/test_end_to_end.py +106 -0
- google/genai/tests/tunings/test_get.py +67 -0
- google/genai/tests/tunings/test_list.py +75 -0
- google/genai/tests/tunings/test_tune.py +268 -0
- google/genai/tests/types/__init__.py +16 -0
- google/genai/tests/types/test_bytes_internal.py +271 -0
- google/genai/tests/types/test_bytes_type.py +152 -0
- google/genai/tests/types/test_future.py +101 -0
- google/genai/tests/types/test_optional_types.py +36 -0
- google/genai/tests/types/test_part_type.py +616 -0
- google/genai/tests/types/test_schema_from_json_schema.py +417 -0
- google/genai/tests/types/test_schema_json_schema.py +468 -0
- google/genai/tests/types/test_types.py +2903 -0
- google/genai/types.py +631 -488
- google/genai/version.py +1 -1
- {google_genai-1.56.0.dist-info → google_genai-1.58.0.dist-info}/METADATA +6 -11
- google_genai-1.58.0.dist-info/RECORD +358 -0
- google_genai-1.56.0.dist-info/RECORD +0 -162
- /google/genai/{_interactions/py.typed → tests/interactions/__init__.py} +0 -0
- {google_genai-1.56.0.dist-info → google_genai-1.58.0.dist-info}/WHEEL +0 -0
- {google_genai-1.56.0.dist-info → google_genai-1.58.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.56.0.dist-info → google_genai-1.58.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1687 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
"""Tests for client initialization."""
|
|
18
|
+
|
|
19
|
+
import asyncio
|
|
20
|
+
import concurrent.futures
|
|
21
|
+
import logging
|
|
22
|
+
import os
|
|
23
|
+
import ssl
|
|
24
|
+
from unittest import mock
|
|
25
|
+
|
|
26
|
+
import certifi
|
|
27
|
+
import google.auth
|
|
28
|
+
from google.auth import credentials
|
|
29
|
+
import httpx
|
|
30
|
+
import pytest
|
|
31
|
+
|
|
32
|
+
from ... import _api_client as api_client
|
|
33
|
+
from ... import _base_url as base_url
|
|
34
|
+
from ... import _replay_api_client as replay_api_client
|
|
35
|
+
from ... import Client
|
|
36
|
+
from ... import types
|
|
37
|
+
# aiohttp is an optional dependency. When it cannot be imported, a MagicMock is
# bound to the name `aiohttp` so the module-level name always exists, and
# AIOHTTP_NOT_INSTALLED records the outcome for the `requires_aiohttp` marker.
try:
  import aiohttp
except ImportError:
  aiohttp = mock.MagicMock()
  AIOHTTP_NOT_INSTALLED = True
else:
  AIOHTTP_NOT_INSTALLED = False


# Decorator for tests that need the real aiohttp package.
requires_aiohttp = pytest.mark.skipif(
    AIOHTTP_NOT_INSTALLED, reason="aiohttp is not installed, skipping test."
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@pytest.fixture(autouse=True)
def reset_has_aiohttp():
  """Autouse fixture: after every test, force api_client.has_aiohttp to False.

  Nothing happens before the test body; the reset runs as fixture teardown,
  so a test that toggles the module-level flag cannot leak it into the next.
  """
  yield
  api_client.has_aiohttp = False
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def test_ml_dev_from_gemini_env_only(monkeypatch):
  """GEMINI_API_KEY alone yields a Gemini Developer API (non-Vertex) client."""
  expected_key = "gemini_api_key"
  monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
  monkeypatch.setenv("GEMINI_API_KEY", expected_key)

  client = Client()
  inner = client.models._api_client

  assert not inner.vertexai
  assert inner.api_key == expected_key
  assert isinstance(inner, api_client.BaseApiClient)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def test_ml_dev_from_gemini_env_with_google_env_empty(monkeypatch):
  """An empty GOOGLE_API_KEY does not mask a populated GEMINI_API_KEY."""
  expected_key = "gemini_api_key"
  monkeypatch.setenv("GEMINI_API_KEY", expected_key)
  monkeypatch.setenv("GOOGLE_API_KEY", "")

  client = Client()
  inner = client.models._api_client

  assert not inner.vertexai
  assert inner.api_key == expected_key
  assert isinstance(inner, api_client.BaseApiClient)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def test_ml_dev_from_google_env_only(monkeypatch):
  """GOOGLE_API_KEY alone yields a Gemini Developer API (non-Vertex) client."""
  expected_key = "google_api_key"
  monkeypatch.setenv("GOOGLE_API_KEY", expected_key)
  monkeypatch.delenv("GEMINI_API_KEY", raising=False)

  client = Client()
  inner = client.models._api_client

  assert not inner.vertexai
  assert inner.api_key == expected_key
  assert isinstance(inner, api_client.BaseApiClient)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def test_ml_dev_both_env_key_set(monkeypatch, caplog):
  """With both env keys set, GOOGLE_API_KEY is used and the choice is logged."""
  caplog.set_level(logging.DEBUG, logger="google_genai._api_client")
  env_keys = {
      "GOOGLE_API_KEY": "google_api_key",
      "GEMINI_API_KEY": "gemini_api_key",
  }
  for env_name, env_value in env_keys.items():
    monkeypatch.setenv(env_name, env_value)

  client = Client()
  inner = client.models._api_client

  assert not inner.vertexai
  assert inner.api_key == env_keys["GOOGLE_API_KEY"]
  assert isinstance(inner, api_client.BaseApiClient)
  expected_log = (
      "Both GOOGLE_API_KEY and GEMINI_API_KEY are set. Using GOOGLE_API_KEY."
  )
  assert expected_log in caplog.text
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def test_api_key_with_new_line(monkeypatch, caplog):
  """A trailing "\\r\\n" on the env-provided API key is not kept by the client."""
  caplog.set_level(logging.DEBUG, logger="google_genai._api_client")
  raw_key = "gemini_api_key\r\n"
  monkeypatch.setenv("GOOGLE_API_KEY", raw_key)

  client = Client()
  stored_key = client.models._api_client.api_key

  assert stored_key == "gemini_api_key"
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def test_ml_dev_from_constructor():
  """An api_key passed to the constructor produces a non-Vertex client."""
  expected_key = "google_api_key"

  client = Client(api_key=expected_key)
  inner = client.models._api_client

  assert not inner.vertexai
  assert inner.api_key == expected_key
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def test_constructor_with_http_options():
  """Dict-form http_options are reflected in each backend's read-only options."""
  mldev_http_options = {
      "api_version": "v1main",
      "base_url": "https://placeholder-fake-url.com/",
      "headers": {"X-Custom-Header": "custom_value_mldev"},
      "timeout": 10000,
  }
  vertexai_http_options = {
      "api_version": "v1",
      "base_url": (
          "https://{self.location}-aiplatform.googleapis.com/{{api_version}}/"
      ),
      "headers": {"X-Custom-Header": "custom_value_vertexai"},
      "timeout": 11000,
  }

  # Gemini Developer API client.
  mldev_client = Client(
      api_key="google_api_key", http_options=mldev_http_options
  )
  mldev_api = mldev_client.models._api_client
  assert not mldev_api.vertexai
  mldev_opts = mldev_api.get_read_only_http_options()
  assert mldev_opts["base_url"] == "https://placeholder-fake-url.com/"
  assert mldev_opts["api_version"] == "v1main"
  assert mldev_opts["headers"]["X-Custom-Header"] == "custom_value_mldev"
  assert mldev_opts["timeout"] == 10000

  # Vertex AI client. The base_url above is a plain (non f-string) literal, and
  # the test expects it back unchanged, braces included.
  vertexai_client = Client(
      vertexai=True,
      project="fake_project_id",
      location="fake-location",
      http_options=vertexai_http_options,
  )
  vertex_api = vertexai_client.models._api_client
  assert vertex_api.vertexai
  vertex_opts = vertex_api.get_read_only_http_options()
  assert vertex_opts["base_url"] == (
      "https://{self.location}-aiplatform.googleapis.com/{{api_version}}/"
  )
  assert vertex_opts["api_version"] == "v1"
  assert vertex_opts["headers"]["X-Custom-Header"] == "custom_value_vertexai"
  assert vertex_opts["timeout"] == 11000
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def test_constructor_with_invalid_http_options_key():
  """An unknown key in dict-form http_options must raise ValueError.

  The error message is expected to name the offending key, for both the
  Gemini Developer API and the Vertex AI backends.
  """
  mldev_http_options = {
      "invalid_version_key": "v1",
      "base_url": "https://placeholder-fake-url.com/",
      "headers": {"X-Custom-Header": "custom_value"},
  }
  vertexai_http_options = {
      "api_version": "v1",
      "base_url": (
          "https://{self.location}-aiplatform.googleapis.com/{{api_version}}/"
      ),
      "invalid_header_key": {"X-Custom-Header": "custom_value"},
  }

  # pytest.raises fails the test when nothing is raised; a bare try/except
  # would silently pass if construction unexpectedly succeeded.
  with pytest.raises(ValueError, match="invalid_version_key"):
    Client(api_key="google_api_key", http_options=mldev_http_options)

  with pytest.raises(ValueError, match="invalid_header_key"):
    Client(
        vertexai=True,
        project="fake_project_id",
        location="fake-location",
        http_options=vertexai_http_options,
    )
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def test_constructor_with_http_options_as_pydantic_type():
  """types.HttpOptions instances are reflected in each backend's options."""
  mldev_http_options = types.HttpOptions(
      api_version="v1",
      base_url="https://placeholder-fake-url.com/",
      headers={"X-Custom-Header": "custom_value"},
  )
  vertexai_http_options = types.HttpOptions(
      api_version="v1",
      base_url=(
          "https://{self.location}-aiplatform.googleapis.com/{{api_version}}/"
      ),
      headers={"X-Custom-Header": "custom_value"},
  )

  # Gemini Developer API client.
  mldev_client = Client(
      api_key="google_api_key", http_options=mldev_http_options
  )
  mldev_api = mldev_client.models._api_client
  assert not mldev_api.vertexai
  mldev_opts = mldev_api.get_read_only_http_options()
  assert mldev_opts["base_url"] == mldev_http_options.base_url
  assert mldev_opts["api_version"] == mldev_http_options.api_version
  assert (
      mldev_opts["headers"]["X-Custom-Header"]
      == mldev_http_options.headers["X-Custom-Header"]
  )

  # Vertex AI client.
  vertexai_client = Client(
      vertexai=True,
      project="fake_project_id",
      location="fake-location",
      http_options=vertexai_http_options,
  )
  vertex_api = vertexai_client.models._api_client
  assert vertex_api.vertexai
  vertex_opts = vertex_api.get_read_only_http_options()
  assert vertex_opts["base_url"] == vertexai_http_options.base_url
  assert vertex_opts["api_version"] == vertexai_http_options.api_version
  assert (
      vertex_opts["headers"]["X-Custom-Header"]
      == vertexai_http_options.headers["X-Custom-Header"]
  )
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def test_vertexai_from_env_1(monkeypatch):
  """GOOGLE_GENAI_USE_VERTEXAI="1" plus project/location env selects Vertex AI."""
  expected_project = "fake_project_id"
  expected_location = "fake-location"
  env = {
      "GOOGLE_GENAI_USE_VERTEXAI": "1",
      "GOOGLE_CLOUD_PROJECT": expected_project,
      "GOOGLE_CLOUD_LOCATION": expected_location,
  }
  for env_name, env_value in env.items():
    monkeypatch.setenv(env_name, env_value)

  client = Client()
  inner = client.models._api_client

  assert inner.vertexai
  assert inner.project == expected_project
  assert inner.location == expected_location
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def test_vertexai_from_env_true(monkeypatch):
  """GOOGLE_GENAI_USE_VERTEXAI="true" plus project/location env selects Vertex."""
  expected_project = "fake_project_id"
  expected_location = "fake-location"
  env = {
      "GOOGLE_GENAI_USE_VERTEXAI": "true",
      "GOOGLE_CLOUD_PROJECT": expected_project,
      "GOOGLE_CLOUD_LOCATION": expected_location,
  }
  for env_name, env_value in env.items():
    monkeypatch.setenv(env_name, env_value)

  client = Client()
  inner = client.models._api_client

  assert inner.vertexai
  assert inner.project == expected_project
  assert inner.location == expected_location
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def test_vertexai_from_constructor():
  """Explicit vertexai/project/location constructor args are honored."""
  expected_project = "fake_project_id"
  expected_location = "fake-location"

  client = Client(
      vertexai=True, project=expected_project, location=expected_location
  )
  inner = client.models._api_client

  assert inner.vertexai
  assert inner.project == expected_project
  assert inner.location == expected_location
  assert isinstance(inner, api_client.BaseApiClient)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def test_invalid_vertexai_constructor_empty(monkeypatch):
  """Vertex AI with no project, location, API key or ADC project raises.

  All relevant env vars are blanked and google.auth.default is stubbed to
  return (None, None), so Client(vertexai=True) has nothing to work with and
  is expected to raise ValueError.
  """
  for env_name in (
      "GOOGLE_CLOUD_PROJECT",
      "GOOGLE_CLOUD_LOCATION",
      "GOOGLE_API_KEY",
      "GEMINI_API_KEY",
  ):
    monkeypatch.setenv(env_name, "")

  def mock_auth_default(scopes=None):
    # Stand-in for google.auth.default: no credentials and no project.
    return None, None

  monkeypatch.setattr(google.auth, "default", mock_auth_default)

  # Only the constructor call sits inside pytest.raises, so a ValueError
  # raised during test setup cannot make this test pass spuriously.
  with pytest.raises(ValueError):
    Client(vertexai=True)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def test_vertexai_constructor_empty_base_url_override(monkeypatch):
  """A base_url override lets Vertex AI construct without project/location/key."""
  for env_name in (
      "GOOGLE_CLOUD_PROJECT",
      "GOOGLE_CLOUD_LOCATION",
      "GOOGLE_API_KEY",
      "GEMINI_API_KEY",
  ):
    monkeypatch.setenv(env_name, "")

  def mock_auth_default(scopes=None):
    # Stand-in for google.auth.default: no credentials and no project.
    return None, None

  monkeypatch.setattr(google.auth, "default", mock_auth_default)

  # Supplying base_url skips the "need project/location or api_key" check.
  override_options = {"base_url": "https://override.com/"}
  client = Client(vertexai=True, http_options=override_options)

  assert client.models._api_client.location is None
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def test_invalid_mldev_constructor_empty(monkeypatch):
  """A Gemini Developer API client with no API key available raises ValueError."""
  monkeypatch.setenv("GOOGLE_API_KEY", "")
  monkeypatch.setenv("GEMINI_API_KEY", "")
  # Make sure an ambient GOOGLE_GENAI_USE_VERTEXAI does not route Client()
  # to the Vertex AI backend, which would bypass the API-key check under test.
  monkeypatch.delenv("GOOGLE_GENAI_USE_VERTEXAI", raising=False)

  # Keep only the constructor inside pytest.raises so setup errors cannot
  # satisfy the expectation.
  with pytest.raises(ValueError):
    Client()
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def test_invalid_vertexai_constructor1():
  """Passing project/location together with api_key for Vertex AI must fail.

  The exception type expected is ValueError.
  """
  project_id = "fake_project_id"
  location = "fake-location"
  api_key = "fake-api_key"

  # pytest.raises fails the test when no exception is raised; a bare
  # try/except would silently pass if the constructor accepted these args.
  with pytest.raises(ValueError):
    Client(
        vertexai=True,
        project=project_id,
        location=location,
        api_key=api_key,
    )
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def test_invalid_vertexai_constructor2():
  """Passing both credentials and api_key for Vertex AI raises ValueError."""
  conflicting_kwargs = dict(
      vertexai=True,
      credentials=credentials.AnonymousCredentials(),
      api_key="fake-api_key",
  )

  with pytest.raises(ValueError):
    Client(**conflicting_kwargs)
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def test_vertexai_default_location_to_global(monkeypatch):
  """With a project but no location configured, Vertex AI uses "global"."""
  with monkeypatch.context() as env:
    env.delenv("GOOGLE_CLOUD_LOCATION", raising=False)

    client = Client(vertexai=True, project="fake_project_id")

    assert client.models._api_client.location == "global"
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def test_vertexai_default_location_to_global_with_credentials(monkeypatch):
  """Credentials + project but no location: location falls back to "global"."""
  expected_project = "fake_project_id"
  anon_creds = credentials.AnonymousCredentials()

  with monkeypatch.context() as env:
    env.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
    env.setenv("GOOGLE_API_KEY", "")

    client = Client(
        vertexai=True, credentials=anon_creds, project=expected_project
    )
    inner = client.models._api_client

    assert inner.location == "global"
    assert inner.project == expected_project
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
def test_vertexai_default_location_to_global_with_explicit_project_and_env_apikey(
    monkeypatch,
):
  """An explicit project beats an env API key; location defaults to "global"."""
  expected_project = "explicit_project_id"

  with monkeypatch.context() as env:
    env.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
    env.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
    env.setenv("GOOGLE_API_KEY", "env_api_key")

    client = Client(vertexai=True, project=expected_project)
    inner = client.models._api_client

    # The explicit project takes precedence over the implicit (env) api_key,
    # so no API key is retained.
    assert inner.location == "global"
    assert inner.project == expected_project
    assert not inner.api_key
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
def test_vertexai_default_location_to_global_with_env_project_and_env_apikey(
    monkeypatch,
):
  """A project from the env wins over an API key from the env."""
  env_project = "env_project_id"
  env_key = "env_api_key"

  with monkeypatch.context() as patcher:
    patcher.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
    for name, value in (
        ("GOOGLE_CLOUD_PROJECT", env_project),
        ("GOOGLE_API_KEY", env_key),
    ):
      patcher.setenv(name, value)

    vertex_client = Client(vertexai=True)
    api = vertex_client.models._api_client

    assert api.location == "global"
    assert api.project == env_project
    assert not api.api_key
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def test_vertexai_no_default_location_when_location_explicitly_set(monkeypatch):
  """An explicitly passed location is kept rather than replaced by `global`."""
  expected_project, expected_location = "fake_project_id", "us-central1"

  with monkeypatch.context() as patcher:
    patcher.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
    vertex_client = Client(
        vertexai=True, project=expected_project, location=expected_location
    )
    api = vertex_client.models._api_client
    assert api.location == expected_location
    assert api.project == expected_project
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def test_vertexai_no_default_location_when_env_location_set(monkeypatch):
  """A location from GOOGLE_CLOUD_LOCATION is kept rather than `global`."""
  env_values = {
      "GOOGLE_CLOUD_LOCATION": "us-west1",
      "GOOGLE_CLOUD_PROJECT": "fake_project_id",
  }

  with monkeypatch.context() as patcher:
    for name, value in env_values.items():
      patcher.setenv(name, value)
    vertex_client = Client(vertexai=True)
    api = vertex_client.models._api_client
    assert api.location == env_values["GOOGLE_CLOUD_LOCATION"]
    assert api.project == env_values["GOOGLE_CLOUD_PROJECT"]
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def test_vertexai_no_default_location_with_apikey_only(monkeypatch):
  """API-key-only (express) mode leaves both project and location unset."""
  express_key = "vertexai_api_key"

  with monkeypatch.context() as patcher:
    for var in ("GOOGLE_CLOUD_LOCATION", "GOOGLE_CLOUD_PROJECT"):
      patcher.delenv(var, raising=False)
    patcher.setenv("GOOGLE_API_KEY", "")

    vertex_client = Client(vertexai=True, api_key=express_key)
    api = vertex_client.models._api_client

    assert not api.location
    assert not api.project
    assert api.api_key == express_key
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def test_vertexai_explicit_credentials(monkeypatch):
  """Explicit credentials are kept, the env API key is not used.

  Project and location are still populated from the environment.
  """
  anon_creds = credentials.AnonymousCredentials()
  for name, value in (
      ("GOOGLE_CLOUD_PROJECT", "fake_project_id"),
      ("GOOGLE_CLOUD_LOCATION", "fake-location"),
      ("GOOGLE_API_KEY", "env_api_key"),
  ):
    monkeypatch.setenv(name, value)

  vertex_client = Client(vertexai=True, credentials=anon_creds)
  api = vertex_client.models._api_client

  assert api.vertexai
  assert api.project
  assert api.location
  assert not api.api_key
  assert api._credentials is anon_creds
  assert isinstance(api, api_client.BaseApiClient)
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def test_vertexai_explicit_arg_precedence1(monkeypatch):
  """Constructor project/location override the GOOGLE_CLOUD_* env values."""
  ctor_project = "constructor_project_id"
  ctor_location = "constructor-location"

  monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "env_project_id")
  monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "env_location")
  monkeypatch.setenv("GOOGLE_API_KEY", "")

  vertex_client = Client(
      vertexai=True, project=ctor_project, location=ctor_location
  )
  api = vertex_client.models._api_client

  assert api.vertexai
  assert (api.project, api.location) == (ctor_project, ctor_location)
  assert not api.api_key
  assert isinstance(api, api_client.BaseApiClient)
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
def test_vertexai_explicit_arg_precedence2(monkeypatch):
  """A constructor API key overrides the env API key (express mode)."""
  ctor_key = "constructor_apikey"

  for name, value in (
      ("GOOGLE_CLOUD_PROJECT", ""),
      ("GOOGLE_CLOUD_LOCATION", ""),
      ("GOOGLE_API_KEY", "env_api_key"),
  ):
    monkeypatch.setenv(name, value)

  vertex_client = Client(vertexai=True, api_key=ctor_key)
  api = vertex_client.models._api_client

  assert api.vertexai
  assert not api.project
  assert not api.location
  assert api.api_key == ctor_key
  assert isinstance(api, api_client.BaseApiClient)
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
def test_invalid_mldev_constructor():
  """Passing project/location together with an API key must raise ValueError.

  The previous version wrapped the call in try/except and only asserted inside
  the `except` branch. If the constructor did not raise, the test passed
  silently. `pytest.raises` fails the test when no ValueError is raised.
  """
  project_id = "fake_project_id"
  location = "fake-location"
  api_key = "fake-api_key"
  with pytest.raises(ValueError):
    Client(
        project=project_id,
        location=location,
        api_key=api_key,
    )
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
def test_mldev_explicit_arg_precedence(monkeypatch, caplog):
  """A constructor API key beats both env keys; the dual-key notice is logged."""
  caplog.set_level(logging.DEBUG, logger="google_genai._api_client")
  ctor_key = "constructor_api_key"

  monkeypatch.setenv("GOOGLE_API_KEY", "google_env_api_key")
  monkeypatch.setenv("GEMINI_API_KEY", "gemini_env_api_key")

  mldev_client = Client(api_key=ctor_key)
  api = mldev_client.models._api_client

  assert not api.vertexai
  assert api.api_key == ctor_key
  assert isinstance(api, api_client.BaseApiClient)
  expected_log = (
      "Both GOOGLE_API_KEY and GEMINI_API_KEY are set. Using GOOGLE_API_KEY."
  )
  assert expected_log in caplog.text
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
def test_replay_client_ml_dev_from_env(monkeypatch, use_vertex: bool):
  """GOOGLE_GENAI_CLIENT_MODE=replay yields a Gemini API ReplayApiClient."""
  env_key = "google_api_key"
  suffix = "vertex" if use_vertex else "mldev"
  env = {
      "GOOGLE_API_KEY": env_key,
      "GOOGLE_GENAI_CLIENT_MODE": "replay",
      "GOOGLE_GENAI_REPLAY_ID": "test_replay_id." + suffix,
      "GOOGLE_GENAI_REPLAYS_DIRECTORY": "test_replay_data",
  }
  for name, value in env.items():
    monkeypatch.setenv(name, value)

  replay_client = Client()
  api = replay_client.models._api_client

  assert not api.vertexai
  assert api.api_key == env_key
  assert isinstance(api, replay_api_client.ReplayApiClient)
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
def test_replay_client_vertexai_from_env(monkeypatch, use_vertex: bool):
  """GOOGLE_GENAI_CLIENT_MODE=replay yields a Vertex AI ReplayApiClient."""
  env_project = "fake_project_id"
  env_location = "fake-location"
  suffix = "vertex" if use_vertex else "mldev"
  env = {
      "GOOGLE_GENAI_USE_VERTEXAI": "1",
      "GOOGLE_CLOUD_PROJECT": env_project,
      "GOOGLE_CLOUD_LOCATION": env_location,
      "GOOGLE_GENAI_CLIENT_MODE": "replay",
      "GOOGLE_GENAI_REPLAY_ID": "test_replay_id." + suffix,
      "GOOGLE_GENAI_REPLAYS_DIRECTORY": "test_replay_data",
  }
  for name, value in env.items():
    monkeypatch.setenv(name, value)

  replay_client = Client()
  api = replay_client.models._api_client

  assert api.vertexai
  assert api.project == env_project
  assert api.location == env_location
  assert isinstance(api, replay_api_client.ReplayApiClient)
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def test_change_client_mode_from_env(monkeypatch, use_vertex: bool):
  """Each new Client reflects the GOOGLE_GENAI_CLIENT_MODE value at creation."""
  monkeypatch.setenv("GOOGLE_API_KEY", "google_api_key")

  monkeypatch.setenv("GOOGLE_GENAI_CLIENT_MODE", "replay")
  replay_client = Client()
  assert isinstance(
      replay_client.models._api_client, replay_api_client.ReplayApiClient
  )

  monkeypatch.setenv("GOOGLE_GENAI_CLIENT_MODE", "")
  default_client = Client()
  assert isinstance(default_client.models._api_client, api_client.BaseApiClient)
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
def test_vertexai_apikey_from_constructor(monkeypatch):
  """Vertex AI Express mode: a constructor API key targets aiplatform."""
  express_key = "vertexai_api_key"

  # Project/location take precedence over an API key, so blank them out.
  for var in ("GOOGLE_CLOUD_LOCATION", "GOOGLE_CLOUD_PROJECT"):
    monkeypatch.setenv(var, "")

  client = Client(api_key=express_key, vertexai=True)
  api = client.models._api_client

  assert api.vertexai
  assert not api.project
  assert not api.location
  assert api.api_key == express_key
  assert "aiplatform" in client._api_client._http_options.base_url
  assert isinstance(api, api_client.BaseApiClient)
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
def test_vertexai_apikey_from_env_google_api_key_only(monkeypatch):
  """Vertex AI Express mode picks up GOOGLE_API_KEY when it is the only key."""
  express_key = "vertexai_api_key"
  monkeypatch.setenv("GOOGLE_API_KEY", express_key)
  monkeypatch.delenv("GEMINI_API_KEY", raising=False)

  # Project/location take precedence over an API key, so blank them out.
  for var in ("GOOGLE_CLOUD_LOCATION", "GOOGLE_CLOUD_PROJECT"):
    monkeypatch.setenv(var, "")

  client = Client(vertexai=True)
  api = client.models._api_client

  assert api.vertexai
  assert api.api_key == express_key
  assert not api.project
  assert not api.location
  assert "aiplatform" in client._api_client._http_options.base_url
  assert isinstance(api, api_client.BaseApiClient)
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def test_vertexai_apikey_from_env_gemini_api_key_only(monkeypatch):
  """Vertex AI Express mode picks up GEMINI_API_KEY when it is the only key."""
  express_key = "vertexai_api_key"
  monkeypatch.setenv("GEMINI_API_KEY", express_key)
  monkeypatch.delenv("GOOGLE_API_KEY", raising=False)

  # Project/location take precedence over an API key, so blank them out.
  for var in ("GOOGLE_CLOUD_LOCATION", "GOOGLE_CLOUD_PROJECT"):
    monkeypatch.setenv(var, "")

  client = Client(vertexai=True)
  api = client.models._api_client

  assert api.vertexai
  assert api.api_key == express_key
  assert not api.project
  assert not api.location
  assert "aiplatform" in client._api_client._http_options.base_url
  assert isinstance(api, api_client.BaseApiClient)
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
def test_vertexai_apikey_from_env_gemini_api_key_with_google_api_key_empty(
    monkeypatch,
):
  """An empty GOOGLE_API_KEY does not hide GEMINI_API_KEY in express mode."""
  express_key = "vertexai_api_key"
  monkeypatch.setenv("GEMINI_API_KEY", express_key)
  monkeypatch.setenv("GOOGLE_API_KEY", "")

  # Project/location take precedence over an API key, so blank them out.
  for var in ("GOOGLE_CLOUD_LOCATION", "GOOGLE_CLOUD_PROJECT"):
    monkeypatch.setenv(var, "")

  client = Client(vertexai=True)
  api = client.models._api_client

  assert api.vertexai
  assert api.api_key == express_key
  assert not api.project
  assert not api.location
  assert "aiplatform" in client._api_client._http_options.base_url
  assert isinstance(api, api_client.BaseApiClient)
|
|
735
|
+
|
|
736
|
+
|
|
737
|
+
def test_vertexai_apikey_from_env_both_api_keys(monkeypatch, caplog):
  """With both env keys set, express mode uses GOOGLE_API_KEY and logs it."""
  caplog.set_level(logging.DEBUG, logger="google_genai._api_client")
  google_key = "google_api_key"
  monkeypatch.setenv("GEMINI_API_KEY", "vertexai_api_key")
  monkeypatch.setenv("GOOGLE_API_KEY", google_key)

  # Project/location take precedence over an API key, so blank them out.
  for var in ("GOOGLE_CLOUD_LOCATION", "GOOGLE_CLOUD_PROJECT"):
    monkeypatch.setenv(var, "")

  client = Client(vertexai=True)
  api = client.models._api_client

  assert api.vertexai
  assert api.api_key == google_key
  assert not api.project
  assert not api.location
  assert "aiplatform" in client._api_client._http_options.base_url
  assert isinstance(api, api_client.BaseApiClient)
  expected_log = (
      "Both GOOGLE_API_KEY and GEMINI_API_KEY are set. Using GOOGLE_API_KEY."
  )
  assert expected_log in caplog.text
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
def test_vertexai_apikey_invalid_constructor1():
  """An API key combined with project/location is rejected for Vertex AI."""
  conflicting_kwargs = dict(
      api_key="vertexai_api_key",
      project="fake_project_id",
      location="fake-location",
      vertexai=True,
  )
  with pytest.raises(ValueError):
    Client(**conflicting_kwargs)
|
|
777
|
+
|
|
778
|
+
|
|
779
|
+
def test_vertexai_apikey_combo1(monkeypatch):
  """An explicit API key beats project/location that only come from the env."""
  express_key = "vertexai_api_key"
  for name, value in (
      ("GOOGLE_CLOUD_PROJECT", "fake_project_id"),
      ("GOOGLE_CLOUD_LOCATION", "fake-location"),
      ("GOOGLE_API_KEY", ""),
  ):
    monkeypatch.setenv(name, value)

  client = Client(vertexai=True, api_key=express_key)
  api = client.models._api_client

  assert api.vertexai
  assert api.api_key == express_key
  assert not api.project
  assert not api.location
  assert "aiplatform" in client._api_client._http_options.base_url
  assert isinstance(api, api_client.BaseApiClient)
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
def test_vertexai_apikey_combo2(monkeypatch):
  """Explicit project/location beat an API key that only comes from the env."""
  expected_project = "fake_project_id"
  expected_location = "fake-location"
  for name, value in (
      ("GOOGLE_CLOUD_PROJECT", ""),
      ("GOOGLE_CLOUD_LOCATION", ""),
      ("GOOGLE_API_KEY", "vertexai_api_key"),
  ):
    monkeypatch.setenv(name, value)

  client = Client(
      vertexai=True, project=expected_project, location=expected_location
  )
  api = client.models._api_client

  assert api.vertexai
  assert not api.api_key
  assert api.project == expected_project
  assert api.location == expected_location
  assert "aiplatform" in client._api_client._http_options.base_url
  assert isinstance(api, api_client.BaseApiClient)
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
def test_vertexai_apikey_combo3(monkeypatch):
  """Env project/location beat an env API key when nothing is passed."""
  env_project, env_location = "fake_project_id", "fake-location"
  for name, value in (
      ("GOOGLE_CLOUD_PROJECT", env_project),
      ("GOOGLE_CLOUD_LOCATION", env_location),
      ("GOOGLE_API_KEY", "vertexai_api_key"),
  ):
    monkeypatch.setenv(name, value)

  client = Client(vertexai=True)
  api = client.models._api_client

  assert api.vertexai
  assert not api.api_key
  assert api.project == env_project
  assert api.location == env_location
  assert "aiplatform" in client._api_client._http_options.base_url
  assert isinstance(api, api_client.BaseApiClient)
|
|
837
|
+
|
|
838
|
+
|
|
839
|
+
def test_vertexai_global_endpoint(monkeypatch):
  """location="global" maps to the region-less aiplatform base URL."""
  expected_project = "fake_project_id"
  global_location = "global"
  monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", expected_project)
  monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", global_location)

  client = Client(vertexai=True, location=global_location)
  api = client.models._api_client

  assert api.vertexai
  assert api.project == expected_project
  assert api.location == global_location
  assert api._http_options.base_url == "https://aiplatform.googleapis.com/"
  assert isinstance(api, api_client.BaseApiClient)
|
|
855
|
+
|
|
856
|
+
|
|
857
|
+
def test_client_logs_to_logger_instance(monkeypatch, caplog):
  """The API-key precedence decision is logged via `google_genai._api_client`."""
  caplog.set_level(logging.DEBUG, logger="google_genai._api_client")

  monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake_project_id")
  monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "fake-location")

  # Keep a reference so the client stays alive until the end of the test.
  _unused_client = Client(vertexai=True, api_key="vertexai_api_key")

  logged = caplog.text
  assert "INFO" in logged
  assert "The user provided Vertex AI API key will take precedence" in logged
|
|
872
|
+
|
|
873
|
+
|
|
874
|
+
def test_client_ssl_context_implicit_initialization():
  """With no user args, sync and async transport args get an SSLContext."""
  options = types.HttpOptions()
  client_args, async_client_args = (
      api_client.BaseApiClient._ensure_httpx_ssl_ctx(options)
  )

  sync_verify = client_args["verify"]
  assert sync_verify
  assert isinstance(sync_verify, ssl.SSLContext)

  try:
    import aiohttp  # pylint: disable=g-import-not-at-top

    async_client_args = api_client.BaseApiClient._ensure_aiohttp_ssl_ctx(
        types.HttpOptions()
    )
    ssl_key = "ssl"
  except ImportError:
    # No aiohttp: check the httpx-style `verify` entry instead.
    ssl_key = "verify"

  assert async_client_args[ssl_key]
  assert isinstance(async_client_args[ssl_key], ssl.SSLContext)
|
|
892
|
+
|
|
893
|
+
|
|
894
|
+
def test_client_ssl_context_explicit_initialization_same_args():
  """One user SSLContext given to both sync and async args.

  The sync args keep that exact context, and the async args end up with an
  SSLContext.
  """
  shared_ctx = ssl.create_default_context(
      cafile=os.environ.get("SSL_CERT_FILE", certifi.where()),
      capath=os.environ.get("SSL_CERT_DIR"),
  )
  options = types.HttpOptions(
      client_args={"verify": shared_ctx},
      async_client_args={"verify": shared_ctx},
  )

  client_args, async_client_args = (
      api_client.BaseApiClient._ensure_httpx_ssl_ctx(options)
  )
  assert client_args["verify"] == shared_ctx

  try:
    import aiohttp  # pylint: disable=g-import-not-at-top

    async_client_args = api_client.BaseApiClient._ensure_aiohttp_ssl_ctx(
        options
    )
    ssl_key = "ssl"
  except ImportError:
    # No aiohttp: check the httpx-style `verify` entry instead.
    ssl_key = "verify"

  assert async_client_args[ssl_key]
  assert isinstance(async_client_args[ssl_key], ssl.SSLContext)
|
|
919
|
+
|
|
920
|
+
|
|
921
|
+
def test_client_ssl_context_explicit_initialization_separate_args():
  """Distinct user SSLContexts for sync and async args.

  The sync args keep their own context, and the async args end up with an
  SSLContext.
  """

  def _make_ctx():
    return ssl.create_default_context(
        cafile=os.environ.get("SSL_CERT_FILE", certifi.where()),
        capath=os.environ.get("SSL_CERT_DIR"),
    )

  sync_ctx = _make_ctx()
  async_ctx = _make_ctx()

  options = types.HttpOptions(
      client_args={"verify": sync_ctx},
      async_client_args={"verify": async_ctx},
  )
  client_args, async_client_args = (
      api_client.BaseApiClient._ensure_httpx_ssl_ctx(options)
  )
  assert client_args["verify"] == sync_ctx

  try:
    import aiohttp  # pylint: disable=g-import-not-at-top

    async_client_args = api_client.BaseApiClient._ensure_aiohttp_ssl_ctx(
        options
    )
    ssl_key = "ssl"
  except ImportError:
    # No aiohttp: check the httpx-style `verify` entry instead.
    ssl_key = "verify"

  assert async_client_args[ssl_key]
  assert isinstance(async_client_args[ssl_key], ssl.SSLContext)
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
def test_client_ssl_context_explicit_initialization_sync_args():
  """A user SSLContext given only in client_args.

  The sync args keep that context, and the async args still end up with an
  SSLContext.
  """
  user_ctx = ssl.create_default_context(
      cafile=os.environ.get("SSL_CERT_FILE", certifi.where()),
      capath=os.environ.get("SSL_CERT_DIR"),
  )
  options = types.HttpOptions(client_args={"verify": user_ctx})

  client_args, async_client_args = (
      api_client.BaseApiClient._ensure_httpx_ssl_ctx(options)
  )
  assert client_args["verify"] == user_ctx

  try:
    import aiohttp  # pylint: disable=g-import-not-at-top

    async_client_args = api_client.BaseApiClient._ensure_aiohttp_ssl_ctx(
        options
    )
    ssl_key = "ssl"
  except ImportError:
    # No aiohttp: check the httpx-style `verify` entry instead.
    ssl_key = "verify"

  assert async_client_args[ssl_key]
  assert isinstance(async_client_args[ssl_key], ssl.SSLContext)
|
|
976
|
+
|
|
977
|
+
|
|
978
|
+
def test_client_ssl_context_explicit_initialization_async_args():
  """A user SSLContext given only in async_client_args.

  The sync args end up with that same context, and the async args end up with
  an SSLContext.
  """
  user_ctx = ssl.create_default_context(
      cafile=os.environ.get("SSL_CERT_FILE", certifi.where()),
      capath=os.environ.get("SSL_CERT_DIR"),
  )
  options = types.HttpOptions(async_client_args={"verify": user_ctx})

  client_args, async_client_args = (
      api_client.BaseApiClient._ensure_httpx_ssl_ctx(options)
  )
  assert client_args["verify"] == user_ctx

  try:
    import aiohttp  # pylint: disable=g-import-not-at-top

    async_client_args = api_client.BaseApiClient._ensure_aiohttp_ssl_ctx(
        options
    )
    ssl_key = "ssl"
  except ImportError:
    # No aiohttp: check the httpx-style `verify` entry instead.
    ssl_key = "verify"

  assert async_client_args[ssl_key]
  assert isinstance(async_client_args[ssl_key], ssl.SSLContext)
|
|
1001
|
+
|
|
1002
|
+
|
|
1003
|
+
def test_constructor_with_base_url_from_http_options():
  """A base_url in http_options is reported back unchanged for both backends."""
  gemini_url = "https://placeholder-fake-url.com/"
  vertex_url = (
      "https://{self.location}-aiplatform.googleapis.com/{{api_version}}/"
  )

  mldev_client = Client(
      api_key="google_api_key", http_options={"base_url": gemini_url}
  )
  mldev_api = mldev_client.models._api_client
  assert not mldev_api.vertexai
  assert mldev_api.get_read_only_http_options()["base_url"] == gemini_url

  vertexai_client = Client(
      vertexai=True,
      project="fake_project_id",
      location="fake-location",
      http_options={"base_url": vertex_url},
  )
  vertex_api = vertexai_client.models._api_client
  assert vertex_api.vertexai
  assert vertex_api.get_read_only_http_options()["base_url"] == vertex_url
|
|
1035
|
+
|
|
1036
|
+
|
|
1037
|
+
def test_constructor_with_base_url_from_set_default_base_urls():
  """Defaults installed via set_default_base_urls are used by new clients.

  The defaults outlive this test unless they are reset. The reset therefore
  runs in a `finally` block, so a failing assertion cannot leak the fake URLs
  into later tests.
  """
  base_url.set_default_base_urls(
      gemini_url="https://gemini-base-url.com/",
      vertex_url="https://vertex-base-url.com/",
  )
  try:
    mldev_client = Client(api_key="google_api_key")
    assert not mldev_client.models._api_client.vertexai
    assert (
        mldev_client.models._api_client.get_read_only_http_options()["base_url"]
        == "https://gemini-base-url.com/"
    )

    vertexai_client = Client(
        vertexai=True,
        project="fake_project_id",
        location="fake-location",
    )
    assert vertexai_client.models._api_client.vertexai
    assert (
        vertexai_client.models._api_client.get_read_only_http_options()[
            "base_url"
        ]
        == "https://vertex-base-url.com/"
    )
  finally:
    base_url.set_default_base_urls(gemini_url=None, vertex_url=None)
|
|
1062
|
+
|
|
1063
|
+
|
|
1064
|
+
def test_constructor_with_constructor_base_url_overrides_set_default_base_urls():
  """A base_url in http_options wins over set_default_base_urls defaults.

  The defaults installed here outlive this test unless they are reset. The
  reset therefore runs in a `finally` block, so a failing assertion cannot
  leak them into later tests.
  """
  mldev_http_options = {
      "base_url": "https://gemini-constructor-base-url.com/",
  }
  vertexai_http_options = {
      "base_url": "https://vertex-constructor-base-url.com/",
  }

  base_url.set_default_base_urls(
      gemini_url="https://gemini-base-url.com/",
      vertex_url="https://vertex-base-url.com/",
  )
  try:
    mldev_client = Client(
        api_key="google_api_key", http_options=mldev_http_options
    )
    assert not mldev_client.models._api_client.vertexai
    assert (
        mldev_client.models._api_client.get_read_only_http_options()["base_url"]
        == "https://gemini-constructor-base-url.com/"
    )

    vertexai_client = Client(
        vertexai=True,
        project="fake_project_id",
        location="fake-location",
        http_options=vertexai_http_options,
    )
    assert vertexai_client.models._api_client.vertexai
    assert (
        vertexai_client.models._api_client.get_read_only_http_options()[
            "base_url"
        ]
        == "https://vertex-constructor-base-url.com/"
    )
  finally:
    base_url.set_default_base_urls(gemini_url=None, vertex_url=None)
|
|
1099
|
+
|
|
1100
|
+
|
|
1101
|
+
def test_constructor_with_constructor_base_url_overrides_environment_variables(
    monkeypatch,
):
  """A base_url in http_options wins over the GOOGLE_*_BASE_URL env vars."""
  for name, value in (
      ("GOOGLE_GEMINI_BASE_URL", "https://gemini-env-base-url.com/"),
      ("GOOGLE_VERTEX_BASE_URL", "https://vertex-env-base-url.com/"),
  ):
    monkeypatch.setenv(name, value)

  gemini_ctor_url = "https://gemini-constructor-base-url.com/"
  vertex_ctor_url = "https://vertex-constructor-base-url.com/"

  mldev_client = Client(
      api_key="google_api_key", http_options={"base_url": gemini_ctor_url}
  )
  mldev_api = mldev_client.models._api_client
  assert not mldev_api.vertexai
  assert mldev_api.get_read_only_http_options()["base_url"] == gemini_ctor_url

  vertexai_client = Client(
      vertexai=True,
      project="fake_project_id",
      location="fake-location",
      http_options={"base_url": vertex_ctor_url},
  )
  vertex_api = vertexai_client.models._api_client
  assert vertex_api.vertexai
  assert vertex_api.get_read_only_http_options()["base_url"] == vertex_ctor_url

  # Reset any default base URLs installed via set_default_base_urls.
  base_url.set_default_base_urls(gemini_url=None, vertex_url=None)
|
|
1141
|
+
|
|
1142
|
+
|
|
1143
|
+
def test_constructor_with_base_url_from_set_default_base_urls_overrides_environment_variables(
    monkeypatch,
):
  """set_default_base_urls defaults win over the GOOGLE_*_BASE_URL env vars.

  monkeypatch restores the environment automatically. The defaults installed
  via set_default_base_urls are not restored that way, so they are reset in a
  `finally` block to keep a failing assertion from leaking them into later
  tests.
  """
  monkeypatch.setenv(
      "GOOGLE_GEMINI_BASE_URL", "https://gemini-env-base-url.com/"
  )
  monkeypatch.setenv(
      "GOOGLE_VERTEX_BASE_URL", "https://vertex-env-base-url.com/"
  )

  base_url.set_default_base_urls(
      gemini_url="https://gemini-base-url.com/",
      vertex_url="https://vertex-base-url.com/",
  )
  try:
    mldev_client = Client(api_key="google_api_key")
    assert not mldev_client.models._api_client.vertexai
    assert (
        mldev_client.models._api_client.get_read_only_http_options()["base_url"]
        == "https://gemini-base-url.com/"
    )

    vertexai_client = Client(
        vertexai=True,
        project="fake_project_id",
        location="fake-location",
    )
    assert vertexai_client.models._api_client.vertexai
    assert (
        vertexai_client.models._api_client.get_read_only_http_options()[
            "base_url"
        ]
        == "https://vertex-base-url.com/"
    )
  finally:
    base_url.set_default_base_urls(gemini_url=None, vertex_url=None)
|
|
1177
|
+
|
|
1178
|
+
|
|
1179
|
+
def test_constructor_with_base_url_from_environment_variables(monkeypatch):
  """Checks that base URLs supplied through environment variables are used.

  GOOGLE_GEMINI_BASE_URL must drive the Gemini API (non-Vertex) client and
  GOOGLE_VERTEX_BASE_URL must drive the Vertex AI client.
  """
  gemini_url = "https://gemini-base-url.com/"
  vertex_url = "https://vertex-base-url.com/"
  monkeypatch.setenv("GOOGLE_GEMINI_BASE_URL", gemini_url)
  monkeypatch.setenv("GOOGLE_VERTEX_BASE_URL", vertex_url)

  mldev_client = Client(api_key="google_api_key")
  gemini_api = mldev_client.models._api_client
  assert not gemini_api.vertexai
  assert gemini_api.get_read_only_http_options()["base_url"] == gemini_url

  vertexai_client = Client(
      vertexai=True,
      project="fake_project_id",
      location="fake-location",
  )
  vertex_api = vertexai_client.models._api_client
  assert vertex_api.vertexai
  assert vertex_api.get_read_only_http_options()["base_url"] == vertex_url
|
|
1202
|
+
|
|
1203
|
+
|
|
1204
|
+
def test_async_transport_absence_allows_aiohttp_to_be_used():
  """Tests that, with no custom async transport, aiohttp use follows availability.

  _use_aiohttp() must be False when api_client.has_aiohttp is False and True
  when it is True.
  """
  client = Client(
      vertexai=True,
      project="fake_project_id",
      location="fake-location",
  )

  # ``has_aiohttp`` is a module-level flag shared by every test. Restore it so
  # the forced values below cannot leak into later tests.
  original_has_aiohttp = api_client.has_aiohttp
  try:
    api_client.has_aiohttp = False
    assert not client._api_client._use_aiohttp()

    api_client.has_aiohttp = True
    assert client._api_client._use_aiohttp()
  finally:
    api_client.has_aiohttp = original_has_aiohttp
|
|
1216
|
+
|
|
1217
|
+
|
|
1218
|
+
def test_async_async_client_args_without_transport_allows_aiohttp_to_be_used():
  """Tests that async_client_args without a ``transport`` key keep aiohttp usable.

  With an empty async_client_args dict, _use_aiohttp() must still follow
  api_client.has_aiohttp.
  """
  client = Client(
      vertexai=True,
      project="fake_project_id",
      location="fake-location",
      http_options=types.HttpOptions(async_client_args={}),
  )

  # ``has_aiohttp`` is a module-level flag shared by every test. Restore it so
  # the forced values below cannot leak into later tests.
  original_has_aiohttp = api_client.has_aiohttp
  try:
    api_client.has_aiohttp = False
    assert not client._api_client._use_aiohttp()

    api_client.has_aiohttp = True
    assert client._api_client._use_aiohttp()
  finally:
    api_client.has_aiohttp = original_has_aiohttp
|
|
1231
|
+
|
|
1232
|
+
|
|
1233
|
+
def test_async_transport_forces_httpx_regardless_of_aiohttp_availability():
  """Tests that a custom async ``transport`` disables aiohttp unconditionally.

  When async_client_args carries an httpx transport, _use_aiohttp() must be
  False whether or not api_client.has_aiohttp is set.
  """
  client = Client(
      vertexai=True,
      project="fake_project_id",
      location="fake-location",
      http_options=types.HttpOptions(
          async_client_args={"transport": httpx.AsyncBaseTransport()}
      ),
  )

  # ``has_aiohttp`` is a module-level flag shared by every test. Restore it so
  # the forced values below cannot leak into later tests.
  original_has_aiohttp = api_client.has_aiohttp
  try:
    api_client.has_aiohttp = False
    assert not client._api_client._use_aiohttp()

    api_client.has_aiohttp = True
    assert not client._api_client._use_aiohttp()
  finally:
    api_client.has_aiohttp = original_has_aiohttp
|
|
1249
|
+
|
|
1250
|
+
|
|
1251
|
+
@pytest.mark.asyncio
async def test_get_async_auth_lock_basic_functionality():
  """Checks that the first lookup yields an asyncio.Lock cached on the client."""
  client = Client(
      vertexai=True, project="fake_project_id", location="fake-location"
  )
  inner = client._api_client

  auth_lock = await inner._get_async_auth_lock()

  assert isinstance(auth_lock, asyncio.Lock)
  assert inner._async_auth_lock is auth_lock
|
|
1261
|
+
|
|
1262
|
+
|
|
1263
|
+
@pytest.mark.asyncio
async def test_get_async_auth_lock_returns_same_instance():
  """Checks that three sequential lookups all return one shared lock object."""
  client = Client(
      vertexai=True, project="fake_project_id", location="fake-location"
  )

  # Awaited one after another, in order.
  fetched = [await client._api_client._get_async_auth_lock() for _ in range(3)]

  assert fetched[0] is fetched[1]
  assert fetched[1] is fetched[2]
  assert isinstance(fetched[0], asyncio.Lock)
|
|
1275
|
+
|
|
1276
|
+
|
|
1277
|
+
def test_threaded_generate_content_locking(monkeypatch):
  """Checks that concurrent sync generate_content calls share one auth setup.

  Ten requests are issued from a three-worker thread pool, twice. In the first
  wave no credentials are cached, and google.auth.default must run exactly
  once. In the second wave the credentials are marked expired, and
  Credentials.refresh must run exactly once.
  """
  monkeypatch.delenv("GOOGLE_GENAI_CLIENT_MODE", raising=False)

  fake_creds = mock.Mock(spec=credentials.Credentials)
  fake_creds.token = "initial-token"
  fake_creds.expired = False
  fake_creds.quota_project_id = None

  fake_auth_default = mock.Mock(return_value=(fake_creds, "test-project"))
  monkeypatch.setattr(google.auth, "default", fake_auth_default)

  def _refresh(request):
    # Mimic a successful token refresh.
    fake_creds.token = "refreshed-token"
    fake_creds.expired = False

  fake_refresh = mock.Mock(side_effect=_refresh)
  fake_creds.refresh = fake_refresh

  # Stub the HTTP layer so no network calls are made.
  canned_response = httpx.Response(
      status_code=200,
      headers={},
      text='{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}',
  )
  fake_request = mock.Mock(return_value=canned_response)
  monkeypatch.setattr(api_client.SyncHttpxClient, "request", fake_request)

  client = Client(
      vertexai=True, project="fake_project_id", location="fake-location"
  )
  # Clear any cached credentials so the first wave has to load them through
  # the sync lock.
  client._api_client._credentials = None

  def _fire_concurrent_requests():
    """Runs 10 generate_content calls on 3 threads and checks each reply."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as pool:
      pending = [
          pool.submit(
              client.models.generate_content,
              model="gemini-pro",
              contents=str(idx),
          )
          for idx in range(10)
      ]
      for done in concurrent.futures.as_completed(pending):
        assert done.result().text == "response"

  # Wave 1: initial credential loading from several threads.
  _fire_concurrent_requests()
  fake_auth_default.assert_called_once()
  fake_refresh.assert_not_called()
  assert fake_request.call_count == 10

  # Wave 2: credential refresh from several threads.
  fake_creds.expired = True
  _fire_concurrent_requests()
  fake_auth_default.assert_called_once()
  fake_refresh.assert_called_once()
  assert fake_request.call_count == 20
|
|
1343
|
+
|
|
1344
|
+
|
|
1345
|
+
@pytest.mark.asyncio
async def test_async_access_token_locking(monkeypatch):
  """Checks that concurrent _async_access_token calls share one auth setup.

  Three coroutines request a token at once, twice. In the first round no
  credentials are cached, and google.auth.default must run exactly once. In
  the second round the credentials are marked expired, and
  Credentials.refresh must run exactly once.
  """
  fake_creds = mock.Mock(spec=credentials.Credentials)
  fake_creds.token = "initial-token"
  fake_creds.expired = False
  fake_creds.quota_project_id = None

  fake_auth_default = mock.Mock(return_value=(fake_creds, "test-project"))
  monkeypatch.setattr(google.auth, "default", fake_auth_default)

  def _refresh(request):
    # Mimic a successful token refresh.
    fake_creds.token = "refreshed-token"
    fake_creds.expired = False

  fake_refresh = mock.Mock(side_effect=_refresh)
  fake_creds.refresh = fake_refresh

  client = Client(
      vertexai=True, project="fake_project_id", location="fake-location"
  )
  # Clear cached credentials so the first round has to load them through the
  # async lock.
  client._api_client._credentials = None

  async def _three_concurrent_tokens():
    """Requests three access tokens concurrently and returns them in order."""
    return await asyncio.gather(
        *(client._api_client._async_access_token() for _ in range(3))
    )

  # Round 1: concurrent initial loading must hit google.auth.default once.
  first_round = await _three_concurrent_tokens()
  assert first_round == ["initial-token", "initial-token", "initial-token"]
  mock_calls_default = fake_auth_default
  mock_calls_default.assert_called_once()
  fake_refresh.assert_not_called()

  # Round 2: with the token expired, concurrent callers must refresh once.
  fake_creds.expired = True
  second_round = await _three_concurrent_tokens()
  assert second_round == [
      "refreshed-token",
      "refreshed-token",
      "refreshed-token",
  ]
  # google.auth.default is still called only once across both rounds.
  fake_auth_default.assert_called_once()
  fake_refresh.assert_called_once()
|
|
1399
|
+
|
|
1400
|
+
|
|
1401
|
+
@pytest.mark.asyncio
async def test_get_async_auth_lock_concurrent_access():
  """Checks that 20 concurrently gathered coroutines all get the same lock.

  Every coroutine must complete and report the identity of an identical lock
  object.
  """
  client = Client(
      vertexai=True, project="fake_project_id", location="fake-location"
  )

  async def fetch_lock_id(index: int):
    acquired = await client._api_client._get_async_auth_lock()
    return index, id(acquired)

  outcomes = await asyncio.gather(*(fetch_lock_id(n) for n in range(20)))

  first_lock_id = outcomes[0][1]
  assert all(
      lock_id == first_lock_id for _, lock_id in outcomes
  ), "All tasks should get the same lock instance"

  finished = sorted(index for index, _ in outcomes)
  assert finished == list(range(20)), "All tasks should complete"
|
|
1425
|
+
|
|
1426
|
+
|
|
1427
|
+
@pytest.mark.asyncio
async def test_get_async_auth_lock_doesnt_block_other_operations():
  """Tests that _get_async_auth_lock doesn't interfere with other async operations.

  Ten lock lookups are gathered together with fifteen coroutines that each
  sleep 10ms. All must finish, and the lookups must agree on one lock. The
  whole batch must take under 100ms, which is below the 150ms that serialised
  sleeps would need.
  """
  client = Client(
      vertexai=True, project="fake_project_id", location="fake-location"
  )
  # Inside a coroutine, get_running_loop() is the preferred way to reach the
  # current loop; fetch it once and reuse it for timing.
  loop = asyncio.get_running_loop()

  # Track completion of other async operations
  completed_operations = []

  async def mock_async_operation(op_id: int):
    await asyncio.sleep(0.01)  # Small delay to simulate async work
    completed_operations.append(op_id)
    return f"operation_{op_id}"

  # Start auth lock requests and other operations simultaneously
  start_time = loop.time()

  auth_tasks = [client._api_client._get_async_auth_lock() for _ in range(10)]
  work_tasks = [mock_async_operation(i) for i in range(15)]

  auth_results, work_results = await asyncio.gather(
      asyncio.gather(*auth_tasks), asyncio.gather(*work_tasks)
  )

  total_time = loop.time() - start_time

  # Verify all operations completed
  assert len(auth_results) == 10, "All auth lock requests should complete"
  assert len(work_results) == 15, "All work tasks should complete"
  assert len(completed_operations) == 15, "All async operations should complete"

  # All auth requests should return the same lock
  lock_ids = [id(lock) for lock in auth_results]
  assert all(lock_id == lock_ids[0] for lock_id in lock_ids)

  # Should complete quickly since operations run concurrently
  assert total_time < 0.1, (
      f"Operations took too long ({total_time:.3f}s), suggesting blocking"
      " occurred"
  )
|
|
1469
|
+
|
|
1470
|
+
|
|
1471
|
+
@pytest.mark.asyncio
async def test_get_async_auth_lock_creation_lock_lifecycle():
  """Checks lazy creation and reuse of the auth lock and its creation lock.

  Neither lock exists before the first _get_async_auth_lock() call. Afterwards
  both exist, are distinct asyncio.Lock objects, and are reused on later calls.
  """
  client = Client(
      vertexai=True, project="fake_project_id", location="fake-location"
  )
  inner = client._api_client

  # Nothing is allocated up front.
  assert inner._async_auth_lock is None
  assert inner._async_auth_lock_creation_lock is None

  # The first lookup populates both attributes.
  first_lock = await inner._get_async_auth_lock()
  assert inner._async_auth_lock is not None
  assert inner._async_auth_lock_creation_lock is not None
  assert isinstance(first_lock, asyncio.Lock)

  # The creation lock is a separate object from the auth lock itself.
  creation_lock = inner._async_auth_lock_creation_lock
  assert creation_lock is not first_lock
  assert isinstance(creation_lock, asyncio.Lock)

  # Later lookups reuse both objects.
  second_lock = await inner._get_async_auth_lock()
  assert second_lock is first_lock
  assert inner._async_auth_lock_creation_lock is creation_lock
|
|
1497
|
+
|
|
1498
|
+
|
|
1499
|
+
@pytest.mark.asyncio
async def test_get_async_auth_lock_under_load():
  """Tests _get_async_auth_lock under heavy concurrent load.

  One hundred coroutines fetch the lock concurrently. All must receive the same
  instance, the batch must finish within 1s, and no single call may take 100ms
  or more.
  """
  client = Client(
      vertexai=True, project="fake_project_id", location="fake-location"
  )
  # Inside a coroutine, get_running_loop() is the preferred way to reach the
  # current loop; fetch it once and reuse it for all timing below.
  loop = asyncio.get_running_loop()

  num_concurrent_calls = 100

  async def get_lock_with_timing(call_id: int):
    start = loop.time()
    lock = await client._api_client._get_async_auth_lock()
    end = loop.time()
    return call_id, id(lock), end - start

  # Run many concurrent calls
  start_time = loop.time()
  tasks = [get_lock_with_timing(i) for i in range(num_concurrent_calls)]
  results = await asyncio.gather(*tasks)
  total_time = loop.time() - start_time

  # Verify all calls succeeded and got the same lock
  call_ids = [r[0] for r in results]
  lock_ids = [r[1] for r in results]
  call_times = [r[2] for r in results]

  assert len(results) == num_concurrent_calls
  assert sorted(call_ids) == list(range(num_concurrent_calls))
  assert all(
      lock_id == lock_ids[0] for lock_id in lock_ids
  ), "All calls should get same lock"

  # Performance checks
  max_call_time = max(call_times)
  assert total_time < 1.0, f"Total time ({total_time:.3f}s) suggests blocking"
  assert (
      max_call_time < 0.1
  ), f"Max individual call time ({max_call_time:.3f}s) too high"
|
|
1537
|
+
|
|
1538
|
+
|
|
1539
|
+
@pytest.mark.asyncio
async def test_get_async_auth_lock_interleaved_with_auth_operations():
  """Checks the auth lock under a mix of lock holders and plain lookups.

  Ten coroutines acquire the lock and hold it briefly, simulating the
  access-token path, while ten others only look the lock up. Everything must
  complete, and every lookup must return the same lock.
  """
  client = Client(
      vertexai=True, project="fake_project_id", location="fake-location"
  )
  inner = client._api_client

  # Pre-populate the client with non-expired fake credentials.
  fake_creds = mock.Mock(spec=credentials.Credentials)
  fake_creds.token = "test-token"
  fake_creds.expired = False
  fake_creds.quota_project_id = None
  inner._credentials = fake_creds

  async def hold_lock(op_id: int):
    # Acquire and hold the lock for a moment, as a token fetch would.
    shared_lock = await inner._get_async_auth_lock()
    async with shared_lock:
      await asyncio.sleep(0.001)
      return f"auth_op_{op_id}"

  async def peek_lock(req_id: int):
    shared_lock = await inner._get_async_auth_lock()
    return req_id, id(shared_lock)

  holder_results, peek_results = await asyncio.gather(
      asyncio.gather(*(hold_lock(n) for n in range(10))),
      asyncio.gather(*(peek_lock(n) for n in range(10))),
  )

  assert len(holder_results) == 10
  assert len(peek_results) == 10

  # Every lookup saw the same lock object.
  seen_ids = [lock_id for _, lock_id in peek_results]
  assert all(lock_id == seen_ids[0] for lock_id in seen_ids)

  # Every holder ran to completion.
  assert all(label.startswith("auth_op_") for label in holder_results)
|
|
1583
|
+
|
|
1584
|
+
|
|
1585
|
+
@pytest.mark.asyncio
async def test_get_async_auth_lock_with_event_loop_switch():
  """Checks that a lock fetched during client setup is reused and usable.

  Everything runs on the single event loop provided by the test. The lock
  obtained right after constructing the client must be returned again by a
  later lookup, and it must work with ``async with``.
  """

  async def build_client_with_lock():
    fresh_client = Client(
        vertexai=True, project="fake_project_id", location="fake-location"
    )
    fresh_lock = await fresh_client._api_client._get_async_auth_lock()
    return fresh_client, fresh_lock

  # Build the client and fetch its lock in the current event loop.
  client, lock1 = await build_client_with_lock()

  # A second lookup on the same loop yields the very same lock.
  lock2 = await client._api_client._get_async_auth_lock()
  assert lock1 is lock2
  assert isinstance(lock1, asyncio.Lock)

  async def _exercise_lock():
    # Not a pytest test: a helper that acquires and releases the lock once.
    async with lock1:
      await asyncio.sleep(0.001)
      return "success"

  assert await _exercise_lock() == "success"
|
|
1613
|
+
|
|
1614
|
+
|
|
1615
|
+
@pytest.mark.asyncio
async def test_get_async_auth_lock_double_checked_locking():
  """Tests that concurrent first calls create only a handful of asyncio locks.

  asyncio.Lock.__init__ is temporarily wrapped to count constructions while 50
  concurrent _get_async_auth_lock() calls run. All calls must return the same
  lock, and at most 5 asyncio.Lock objects may be created. Ideally only 2 are
  created: the creation lock plus the auth lock.
  """
  client = Client(
      vertexai=True, project="fake_project_id", location="fake-location"
  )

  original_lock_init = asyncio.Lock.__init__
  lock_creation_count = 0

  def counting_lock_init(self, *args, **kwargs):
    nonlocal lock_creation_count
    lock_creation_count += 1
    # Forward every argument: asyncio.Lock.__init__ has a different signature
    # on older Python versions (e.g. a ``loop`` keyword), and a self-only
    # wrapper would raise TypeError for any Lock built with arguments.
    return original_lock_init(self, *args, **kwargs)

  # Patch asyncio.Lock to count creations; the ``finally`` always restores it.
  asyncio.Lock.__init__ = counting_lock_init

  try:
    # Run many concurrent requests
    tasks = [client._api_client._get_async_auth_lock() for _ in range(50)]
    locks = await asyncio.gather(*tasks)

    # All should be the same instance
    assert all(lock is locks[0] for lock in locks)

    # Should only create 2 locks: creation_lock + auth_lock
    # (Could be slightly more due to asyncio internals, but should be minimal)
    assert (
        lock_creation_count <= 5
    ), f"Created {lock_creation_count} locks, expected ~2"

  finally:
    asyncio.Lock.__init__ = original_lock_init
|
|
1648
|
+
|
|
1649
|
+
|
|
1650
|
+
@pytest.mark.asyncio
async def test_get_async_auth_lock_memory_efficiency():
  """Checks that repeated lookups never replace the cached locks.

  After the first call, 100 further calls must keep returning the very same
  auth lock and leave the same creation lock in place. One final lookup after
  the loop must still see both unchanged.
  """
  client = Client(
      vertexai=True, project="fake_project_id", location="fake-location"
  )
  inner = client._api_client

  baseline_lock = await inner._get_async_auth_lock()
  baseline_creation_lock = inner._async_auth_lock_creation_lock

  for _ in range(100):
    assert await inner._get_async_auth_lock() is baseline_lock
    assert inner._async_auth_lock_creation_lock is baseline_creation_lock

  # One last lookup after the loop: still no new objects.
  assert await inner._get_async_auth_lock() is baseline_lock
  assert inner._async_auth_lock_creation_lock is baseline_creation_lock
|
|
1673
|
+
|
|
1674
|
+
|
|
1675
|
+
@requires_aiohttp
|
|
1676
|
+
@pytest.mark.asyncio
|
|
1677
|
+
async def test_get_aiohttp_session():
|
|
1678
|
+
"""Tests that _get_async_auth_lock works correctly with aiohttp session lock."""
|
|
1679
|
+
|
|
1680
|
+
client = Client(
|
|
1681
|
+
vertexai=True, project="fake_project_id", location="fake-location"
|
|
1682
|
+
)
|
|
1683
|
+
api_client.has_aiohttp = True
|
|
1684
|
+
initial_session = await client._api_client._get_aiohttp_session()
|
|
1685
|
+
assert initial_session is not None
|
|
1686
|
+
session = await client._api_client._get_aiohttp_session()
|
|
1687
|
+
assert session is initial_session
|