atlas-chat 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas/__init__.py +40 -0
- atlas/application/__init__.py +7 -0
- atlas/application/chat/__init__.py +7 -0
- atlas/application/chat/agent/__init__.py +10 -0
- atlas/application/chat/agent/act_loop.py +179 -0
- atlas/application/chat/agent/factory.py +142 -0
- atlas/application/chat/agent/protocols.py +46 -0
- atlas/application/chat/agent/react_loop.py +338 -0
- atlas/application/chat/agent/think_act_loop.py +171 -0
- atlas/application/chat/approval_manager.py +151 -0
- atlas/application/chat/elicitation_manager.py +191 -0
- atlas/application/chat/events/__init__.py +1 -0
- atlas/application/chat/events/agent_event_relay.py +112 -0
- atlas/application/chat/modes/__init__.py +1 -0
- atlas/application/chat/modes/agent.py +125 -0
- atlas/application/chat/modes/plain.py +74 -0
- atlas/application/chat/modes/rag.py +81 -0
- atlas/application/chat/modes/tools.py +179 -0
- atlas/application/chat/orchestrator.py +213 -0
- atlas/application/chat/policies/__init__.py +1 -0
- atlas/application/chat/policies/tool_authorization.py +99 -0
- atlas/application/chat/preprocessors/__init__.py +1 -0
- atlas/application/chat/preprocessors/message_builder.py +92 -0
- atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
- atlas/application/chat/service.py +454 -0
- atlas/application/chat/utilities/__init__.py +6 -0
- atlas/application/chat/utilities/error_handler.py +367 -0
- atlas/application/chat/utilities/event_notifier.py +546 -0
- atlas/application/chat/utilities/file_processor.py +613 -0
- atlas/application/chat/utilities/tool_executor.py +789 -0
- atlas/atlas_chat_cli.py +347 -0
- atlas/atlas_client.py +238 -0
- atlas/core/__init__.py +0 -0
- atlas/core/auth.py +205 -0
- atlas/core/authorization_manager.py +27 -0
- atlas/core/capabilities.py +123 -0
- atlas/core/compliance.py +215 -0
- atlas/core/domain_whitelist.py +147 -0
- atlas/core/domain_whitelist_middleware.py +82 -0
- atlas/core/http_client.py +28 -0
- atlas/core/log_sanitizer.py +102 -0
- atlas/core/metrics_logger.py +59 -0
- atlas/core/middleware.py +131 -0
- atlas/core/otel_config.py +242 -0
- atlas/core/prompt_risk.py +200 -0
- atlas/core/rate_limit.py +0 -0
- atlas/core/rate_limit_middleware.py +64 -0
- atlas/core/security_headers_middleware.py +51 -0
- atlas/domain/__init__.py +37 -0
- atlas/domain/chat/__init__.py +1 -0
- atlas/domain/chat/dtos.py +85 -0
- atlas/domain/errors.py +96 -0
- atlas/domain/messages/__init__.py +12 -0
- atlas/domain/messages/models.py +160 -0
- atlas/domain/rag_mcp_service.py +664 -0
- atlas/domain/sessions/__init__.py +7 -0
- atlas/domain/sessions/models.py +36 -0
- atlas/domain/unified_rag_service.py +371 -0
- atlas/infrastructure/__init__.py +10 -0
- atlas/infrastructure/app_factory.py +135 -0
- atlas/infrastructure/events/__init__.py +1 -0
- atlas/infrastructure/events/cli_event_publisher.py +140 -0
- atlas/infrastructure/events/websocket_publisher.py +140 -0
- atlas/infrastructure/sessions/in_memory_repository.py +56 -0
- atlas/infrastructure/transport/__init__.py +7 -0
- atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
- atlas/init_cli.py +226 -0
- atlas/interfaces/__init__.py +15 -0
- atlas/interfaces/events.py +134 -0
- atlas/interfaces/llm.py +54 -0
- atlas/interfaces/rag.py +40 -0
- atlas/interfaces/sessions.py +75 -0
- atlas/interfaces/tools.py +57 -0
- atlas/interfaces/transport.py +24 -0
- atlas/main.py +564 -0
- atlas/mcp/api_key_demo/README.md +76 -0
- atlas/mcp/api_key_demo/main.py +172 -0
- atlas/mcp/api_key_demo/run.sh +56 -0
- atlas/mcp/basictable/main.py +147 -0
- atlas/mcp/calculator/main.py +149 -0
- atlas/mcp/code-executor/execution_engine.py +98 -0
- atlas/mcp/code-executor/execution_environment.py +95 -0
- atlas/mcp/code-executor/main.py +528 -0
- atlas/mcp/code-executor/result_processing.py +276 -0
- atlas/mcp/code-executor/script_generation.py +195 -0
- atlas/mcp/code-executor/security_checker.py +140 -0
- atlas/mcp/corporate_cars/main.py +437 -0
- atlas/mcp/csv_reporter/main.py +545 -0
- atlas/mcp/duckduckgo/main.py +182 -0
- atlas/mcp/elicitation_demo/README.md +171 -0
- atlas/mcp/elicitation_demo/main.py +262 -0
- atlas/mcp/env-demo/README.md +158 -0
- atlas/mcp/env-demo/main.py +199 -0
- atlas/mcp/file_size_test/main.py +284 -0
- atlas/mcp/filesystem/main.py +348 -0
- atlas/mcp/image_demo/main.py +113 -0
- atlas/mcp/image_demo/requirements.txt +4 -0
- atlas/mcp/logging_demo/README.md +72 -0
- atlas/mcp/logging_demo/main.py +103 -0
- atlas/mcp/many_tools_demo/main.py +50 -0
- atlas/mcp/order_database/__init__.py +0 -0
- atlas/mcp/order_database/main.py +369 -0
- atlas/mcp/order_database/signal_data.csv +1001 -0
- atlas/mcp/pdfbasic/main.py +394 -0
- atlas/mcp/pptx_generator/main.py +760 -0
- atlas/mcp/pptx_generator/requirements.txt +13 -0
- atlas/mcp/pptx_generator/run_test.sh +1 -0
- atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
- atlas/mcp/progress_demo/main.py +167 -0
- atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
- atlas/mcp/progress_updates_demo/README.md +120 -0
- atlas/mcp/progress_updates_demo/main.py +497 -0
- atlas/mcp/prompts/main.py +222 -0
- atlas/mcp/public_demo/main.py +189 -0
- atlas/mcp/sampling_demo/README.md +169 -0
- atlas/mcp/sampling_demo/main.py +234 -0
- atlas/mcp/thinking/main.py +77 -0
- atlas/mcp/tool_planner/main.py +240 -0
- atlas/mcp/ui-demo/badmesh.png +0 -0
- atlas/mcp/ui-demo/main.py +383 -0
- atlas/mcp/ui-demo/templates/button_demo.html +32 -0
- atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
- atlas/mcp/ui-demo/templates/form_demo.html +28 -0
- atlas/mcp/username-override-demo/README.md +320 -0
- atlas/mcp/username-override-demo/main.py +308 -0
- atlas/modules/__init__.py +0 -0
- atlas/modules/config/__init__.py +34 -0
- atlas/modules/config/cli.py +231 -0
- atlas/modules/config/config_manager.py +1096 -0
- atlas/modules/file_storage/__init__.py +22 -0
- atlas/modules/file_storage/cli.py +330 -0
- atlas/modules/file_storage/content_extractor.py +290 -0
- atlas/modules/file_storage/manager.py +295 -0
- atlas/modules/file_storage/mock_s3_client.py +402 -0
- atlas/modules/file_storage/s3_client.py +417 -0
- atlas/modules/llm/__init__.py +19 -0
- atlas/modules/llm/caller.py +287 -0
- atlas/modules/llm/litellm_caller.py +675 -0
- atlas/modules/llm/models.py +19 -0
- atlas/modules/mcp_tools/__init__.py +17 -0
- atlas/modules/mcp_tools/client.py +2123 -0
- atlas/modules/mcp_tools/token_storage.py +556 -0
- atlas/modules/prompts/prompt_provider.py +130 -0
- atlas/modules/rag/__init__.py +24 -0
- atlas/modules/rag/atlas_rag_client.py +336 -0
- atlas/modules/rag/client.py +129 -0
- atlas/routes/admin_routes.py +865 -0
- atlas/routes/config_routes.py +484 -0
- atlas/routes/feedback_routes.py +361 -0
- atlas/routes/files_routes.py +274 -0
- atlas/routes/health_routes.py +40 -0
- atlas/routes/mcp_auth_routes.py +223 -0
- atlas/server_cli.py +164 -0
- atlas/tests/conftest.py +20 -0
- atlas/tests/integration/test_mcp_auth_integration.py +152 -0
- atlas/tests/manual_test_sampling.py +87 -0
- atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
- atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
- atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
- atlas/tests/test_agent_roa.py +135 -0
- atlas/tests/test_app_factory_smoke.py +47 -0
- atlas/tests/test_approval_manager.py +439 -0
- atlas/tests/test_atlas_client.py +188 -0
- atlas/tests/test_atlas_rag_client.py +447 -0
- atlas/tests/test_atlas_rag_integration.py +224 -0
- atlas/tests/test_attach_file_flow.py +287 -0
- atlas/tests/test_auth_utils.py +165 -0
- atlas/tests/test_backend_public_url.py +185 -0
- atlas/tests/test_banner_logging.py +287 -0
- atlas/tests/test_capability_tokens_and_injection.py +203 -0
- atlas/tests/test_compliance_level.py +54 -0
- atlas/tests/test_compliance_manager.py +253 -0
- atlas/tests/test_config_manager.py +617 -0
- atlas/tests/test_config_manager_paths.py +12 -0
- atlas/tests/test_core_auth.py +18 -0
- atlas/tests/test_core_utils.py +190 -0
- atlas/tests/test_docker_env_sync.py +202 -0
- atlas/tests/test_domain_errors.py +329 -0
- atlas/tests/test_domain_whitelist.py +359 -0
- atlas/tests/test_elicitation_manager.py +408 -0
- atlas/tests/test_elicitation_routing.py +296 -0
- atlas/tests/test_env_demo_server.py +88 -0
- atlas/tests/test_error_classification.py +113 -0
- atlas/tests/test_error_flow_integration.py +116 -0
- atlas/tests/test_feedback_routes.py +333 -0
- atlas/tests/test_file_content_extraction.py +1134 -0
- atlas/tests/test_file_extraction_routes.py +158 -0
- atlas/tests/test_file_library.py +107 -0
- atlas/tests/test_file_manager_unit.py +18 -0
- atlas/tests/test_health_route.py +49 -0
- atlas/tests/test_http_client_stub.py +8 -0
- atlas/tests/test_imports_smoke.py +30 -0
- atlas/tests/test_interfaces_llm_response.py +9 -0
- atlas/tests/test_issue_access_denied_fix.py +136 -0
- atlas/tests/test_llm_env_expansion.py +836 -0
- atlas/tests/test_log_level_sensitive_data.py +285 -0
- atlas/tests/test_mcp_auth_routes.py +341 -0
- atlas/tests/test_mcp_client_auth.py +331 -0
- atlas/tests/test_mcp_data_injection.py +270 -0
- atlas/tests/test_mcp_get_authorized_servers.py +95 -0
- atlas/tests/test_mcp_hot_reload.py +512 -0
- atlas/tests/test_mcp_image_content.py +424 -0
- atlas/tests/test_mcp_logging.py +172 -0
- atlas/tests/test_mcp_progress_updates.py +313 -0
- atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
- atlas/tests/test_mcp_prompts_server.py +39 -0
- atlas/tests/test_mcp_tool_result_parsing.py +296 -0
- atlas/tests/test_metrics_logger.py +56 -0
- atlas/tests/test_middleware_auth.py +379 -0
- atlas/tests/test_prompt_risk_and_acl.py +141 -0
- atlas/tests/test_rag_mcp_aggregator.py +204 -0
- atlas/tests/test_rag_mcp_service.py +224 -0
- atlas/tests/test_rate_limit_middleware.py +45 -0
- atlas/tests/test_routes_config_smoke.py +60 -0
- atlas/tests/test_routes_files_download_token.py +41 -0
- atlas/tests/test_routes_files_health.py +18 -0
- atlas/tests/test_runtime_imports.py +53 -0
- atlas/tests/test_sampling_integration.py +482 -0
- atlas/tests/test_security_admin_routes.py +61 -0
- atlas/tests/test_security_capability_tokens.py +65 -0
- atlas/tests/test_security_file_stats_scope.py +21 -0
- atlas/tests/test_security_header_injection.py +191 -0
- atlas/tests/test_security_headers_and_filename.py +63 -0
- atlas/tests/test_shared_session_repository.py +101 -0
- atlas/tests/test_system_prompt_loading.py +181 -0
- atlas/tests/test_token_storage.py +505 -0
- atlas/tests/test_tool_approval_config.py +93 -0
- atlas/tests/test_tool_approval_utils.py +356 -0
- atlas/tests/test_tool_authorization_group_filtering.py +223 -0
- atlas/tests/test_tool_details_in_config.py +108 -0
- atlas/tests/test_tool_planner.py +300 -0
- atlas/tests/test_unified_rag_service.py +398 -0
- atlas/tests/test_username_override_in_approval.py +258 -0
- atlas/tests/test_websocket_auth_header.py +168 -0
- atlas/version.py +6 -0
- atlas_chat-0.1.0.data/data/.env.example +253 -0
- atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
- atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
- atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
- atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
- atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
- atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
- atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
- atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
- atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
- atlas_chat-0.1.0.dist-info/METADATA +236 -0
- atlas_chat-0.1.0.dist-info/RECORD +250 -0
- atlas_chat-0.1.0.dist-info/WHEEL +5 -0
- atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
- atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
"""Unit tests for MCP client authentication methods.
|
|
2
|
+
|
|
3
|
+
Tests the per-user authentication functionality in MCPToolManager:
|
|
4
|
+
- _requires_user_auth: Check if server requires user authentication
|
|
5
|
+
- _get_user_client: Get or create user-specific client with token
|
|
6
|
+
- Cache validation and invalidation
|
|
7
|
+
|
|
8
|
+
Updated: 2025-01-23
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from unittest.mock import MagicMock, patch
|
|
12
|
+
|
|
13
|
+
import pytest
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TestRequiresUserAuth:
    """Test _requires_user_auth method.

    Per-user credential auth types (jwt, oauth, bearer, api_key) must require
    user auth; an explicit ``"none"``, a missing ``auth_type`` and a server that
    is not configured at all must not.
    """

    @staticmethod
    def _make_manager(servers_config):
        """Return an MCPToolManager created via ``__new__`` (``__init__`` is
        skipped) with only ``servers_config`` populated."""
        from atlas.modules.mcp_tools.client import MCPToolManager

        manager = MCPToolManager.__new__(MCPToolManager)
        manager.servers_config = servers_config
        return manager

    @pytest.mark.parametrize("auth_type", ["jwt", "oauth", "bearer", "api_key"])
    def test_requires_user_auth_for(self, auth_type):
        """JWT, OAuth, bearer and API key auth types should require user auth."""
        manager = self._make_manager({"test-server": {"auth_type": auth_type}})

        assert manager._requires_user_auth("test-server") is True

    def test_no_user_auth_for_none(self):
        """None auth_type should not require user auth."""
        manager = self._make_manager({"test-server": {"auth_type": "none"}})

        assert manager._requires_user_auth("test-server") is False

    def test_no_user_auth_when_missing(self):
        """Missing auth_type should not require user auth (defaults to none)."""
        manager = self._make_manager({"test-server": {}})

        assert manager._requires_user_auth("test-server") is False

    def test_no_user_auth_for_unknown_server(self):
        """Unknown server should not require user auth."""
        manager = self._make_manager({})

        assert manager._requires_user_auth("unknown-server") is False
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class TestGetUserClient:
    """Test _get_user_client method."""

    # Patch targets shared by every test in this class.
    _TOKEN_STORAGE = "atlas.modules.mcp_tools.token_storage.get_token_storage"
    _CLIENT = "atlas.modules.mcp_tools.client.Client"
    _TRANSPORT = "atlas.modules.mcp_tools.client.StreamableHttpTransport"

    @staticmethod
    def _build_manager(server_name, server_config):
        """Return a bare MCPToolManager configured with a single server.

        ``__new__`` skips ``__init__``; only the attributes these tests rely on
        are set: the server config, an empty per-user client cache plus its
        lock, and the handler factories stubbed to return None.
        """
        import asyncio

        from atlas.modules.mcp_tools.client import MCPToolManager

        manager = MCPToolManager.__new__(MCPToolManager)
        manager.servers_config = {server_name: server_config}
        manager._user_clients = {}
        manager._user_clients_lock = asyncio.Lock()
        manager._create_log_handler = MagicMock(return_value=None)
        manager._create_elicitation_handler = MagicMock(return_value=None)
        manager._create_sampling_handler = MagicMock(return_value=None)
        return manager

    @staticmethod
    def _stub_token_storage(mock_get_token_storage, token_value):
        """Make the patched get_token_storage return a storage yielding ``token_value``.

        With ``token_value=None``, ``get_valid_token`` returns None (the user has
        no valid token). The storage mock is returned so a test can change
        ``get_valid_token.return_value`` between calls.
        """
        token = None
        if token_value is not None:
            token = MagicMock()
            token.token_value = token_value
        storage = MagicMock()
        storage.get_valid_token.return_value = token
        mock_get_token_storage.return_value = storage
        return storage

    @pytest.fixture
    def manager(self):
        """Create a mock MCPToolManager with one api_key server for testing."""
        return self._build_manager(
            "test-server",
            {"auth_type": "api_key", "url": "http://localhost:8080"},
        )

    @pytest.mark.asyncio
    async def test_returns_none_without_token(self, manager):
        """Should return None when user has no token stored."""
        with patch(self._TOKEN_STORAGE) as mock_storage:
            self._stub_token_storage(mock_storage, None)

            result = await manager._get_user_client("test-server", "user@example.com")

            assert result is None

    @pytest.mark.asyncio
    async def test_creates_client_with_api_key_header(self, manager):
        """Should create client with custom header for API key auth type."""
        with patch(self._TOKEN_STORAGE) as mock_storage, \
                patch(self._CLIENT) as mock_client_class, \
                patch(self._TRANSPORT) as mock_transport_class:
            self._stub_token_storage(mock_storage, "test-api-key-123")
            mock_transport = MagicMock()
            mock_transport_class.return_value = mock_transport
            mock_client = MagicMock()
            mock_client_class.return_value = mock_client

            result = await manager._get_user_client("test-server", "user@example.com")

            assert result is mock_client
            # api_key auth: the key is sent as an HTTP header on the transport...
            mock_transport_class.assert_called_once()
            assert mock_transport_class.call_args.kwargs["headers"] == {
                "X-API-Key": "test-api-key-123"
            }
            # ...and the Client is built around that transport.
            mock_client_class.assert_called_once()
            assert mock_client_class.call_args.kwargs["transport"] is mock_transport

    @pytest.mark.asyncio
    async def test_creates_client_with_bearer_token(self):
        """Should create client with auth parameter for bearer auth type."""
        manager = self._build_manager(
            "bearer-server",
            {"auth_type": "bearer", "url": "http://localhost:8080"},
        )

        with patch(self._TOKEN_STORAGE) as mock_storage, \
                patch(self._CLIENT) as mock_client_class:
            self._stub_token_storage(mock_storage, "bearer-token-123")
            mock_client = MagicMock()
            mock_client_class.return_value = mock_client

            result = await manager._get_user_client("bearer-server", "user@example.com")

            assert result is mock_client
            mock_client_class.assert_called_once()
            # Token is passed as the ``auth`` argument (not via a transport),
            # with the server URL as the first positional argument.
            client_call = mock_client_class.call_args
            assert client_call.args[0] == "http://localhost:8080"
            assert client_call.kwargs["auth"] == "bearer-token-123"

    @pytest.mark.asyncio
    async def test_uses_custom_auth_header_name(self):
        """Should use custom auth_header from config for API key auth."""
        manager = self._build_manager(
            "custom-header-server",
            {
                "auth_type": "api_key",
                "auth_header": "X-Custom-Auth",
                "url": "http://localhost:8080",
            },
        )

        with patch(self._TOKEN_STORAGE) as mock_storage, \
                patch(self._CLIENT) as mock_client_class, \
                patch(self._TRANSPORT) as mock_transport_class:
            self._stub_token_storage(mock_storage, "custom-key-456")
            mock_transport_class.return_value = MagicMock()
            mock_client = MagicMock()
            mock_client_class.return_value = mock_client

            result = await manager._get_user_client("custom-header-server", "user@example.com")

            assert result is mock_client
            # The configured auth_header replaces the default header name.
            assert mock_transport_class.call_args.kwargs["headers"] == {
                "X-Custom-Auth": "custom-key-456"
            }

    @pytest.mark.asyncio
    async def test_caches_client(self, manager):
        """Should cache client for subsequent calls."""
        with patch(self._TOKEN_STORAGE) as mock_storage, \
                patch(self._CLIENT) as mock_client_class, \
                patch(self._TRANSPORT) as mock_transport_class:
            self._stub_token_storage(mock_storage, "test-api-key")
            mock_transport_class.return_value = MagicMock()
            mock_client_class.return_value = MagicMock()

            # First call creates client
            result1 = await manager._get_user_client("test-server", "user@example.com")
            # Second call should use cache
            result2 = await manager._get_user_client("test-server", "user@example.com")

            assert result1 is result2
            # Client constructor only called once
            assert mock_client_class.call_count == 1

    @pytest.mark.asyncio
    async def test_invalidates_cache_on_expired_token(self, manager):
        """Should invalidate cached client when token expires."""
        cache_key = ("user@example.com", "test-server")

        with patch(self._TOKEN_STORAGE) as mock_storage, \
                patch(self._CLIENT) as mock_client_class, \
                patch(self._TRANSPORT) as mock_transport_class:
            storage = self._stub_token_storage(mock_storage, "test-api-key")
            mock_transport_class.return_value = MagicMock()
            mock_client = MagicMock()
            mock_client_class.return_value = mock_client

            # First call - token valid, so a client is created and cached.
            result1 = await manager._get_user_client("test-server", "user@example.com")
            assert result1 is mock_client
            # Precondition: without it, the final "not in" check would pass
            # vacuously if the client were never cached under this key.
            assert cache_key in manager._user_clients

            # Second call - token expired (returns None)
            storage.get_valid_token.return_value = None
            result2 = await manager._get_user_client("test-server", "user@example.com")

            # Should return None and cache should be invalidated
            assert result2 is None
            assert cache_key not in manager._user_clients
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
class TestInvalidateUserClient:
    """Tests for evicting entries from the per-user client cache."""

    @staticmethod
    def _manager_with_clients(cached_clients):
        """Return a bare MCPToolManager (``__init__`` skipped) whose client
        cache is ``cached_clients``, guarded by a fresh asyncio lock."""
        import asyncio

        from atlas.modules.mcp_tools.client import MCPToolManager

        instance = MCPToolManager.__new__(MCPToolManager)
        instance._user_clients = cached_clients
        instance._user_clients_lock = asyncio.Lock()
        return instance

    @pytest.mark.asyncio
    async def test_removes_cached_client(self):
        """Should remove client from cache."""
        target_key = ("user@example.com", "test-server")
        bystander_key = ("other@example.com", "test-server")
        manager = self._manager_with_clients(
            {target_key: MagicMock(), bystander_key: MagicMock()}
        )

        await manager._invalidate_user_client("user@example.com", "test-server")

        assert target_key not in manager._user_clients
        # Only the named user's entry is evicted; other users keep theirs.
        assert bystander_key in manager._user_clients

    @pytest.mark.asyncio
    async def test_handles_missing_cache_entry(self):
        """Should not error when cache entry doesn't exist."""
        manager = self._manager_with_clients({})

        # Evicting an absent key must not raise.
        await manager._invalidate_user_client("user@example.com", "test-server")
|
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for _mcp_data injection into MCP tool arguments.
|
|
3
|
+
|
|
4
|
+
Validates that tools declaring an _mcp_data parameter in their schema
|
|
5
|
+
receive structured metadata about all available MCP tools, following
|
|
6
|
+
the same pattern as the username injection feature.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from unittest.mock import MagicMock
|
|
10
|
+
|
|
11
|
+
from atlas.application.chat.utilities.tool_executor import (
|
|
12
|
+
build_mcp_data,
|
|
13
|
+
inject_context_into_args,
|
|
14
|
+
tool_accepts_mcp_data,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class FakeTool:
    """Minimal stand-in for an MCP tool exposing name, description and inputSchema.

    Any falsy ``inputSchema`` (None, ``{}``) is replaced by a freshly built empty
    object schema, so instances never share a default dict.
    """

    def __init__(self, name, description="", inputSchema=None):
        fallback_schema = {"type": "object", "properties": {}}
        self.name = name
        self.description = description
        # camelCase kept to match the MCP tool interface this fake mimics.
        self.inputSchema = inputSchema if inputSchema else fallback_schema
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _make_tool_manager(available_tools=None, schema_override=None):
    """Build a MagicMock tool manager backed by ``available_tools``.

    ``available_tools`` maps server name -> ``{"tools": [...], "config": {...}}``.
    The mock's ``get_tools_schema(tool_names)`` returns ``schema_override``
    verbatim when given; otherwise it returns an OpenAI-style function schema
    for every tool whose ``"<server>_<tool>"`` name appears in ``tool_names``,
    in server-then-tool iteration order.
    """
    fake_manager = MagicMock()
    fake_manager.available_tools = available_tools or {}

    def _schemas_for(tool_names):
        if schema_override is not None:
            return schema_override
        return [
            {
                "type": "function",
                "function": {
                    "name": f"{server}_{tool.name}",
                    "description": tool.description or "",
                    "parameters": tool.inputSchema or {},
                },
            }
            for server, entry in fake_manager.available_tools.items()
            for tool in entry.get("tools", [])
            if f"{server}_{tool.name}" in tool_names
        ]

    fake_manager.get_tools_schema = MagicMock(side_effect=_schemas_for)
    return fake_manager
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# -- tool_accepts_mcp_data tests --
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class TestToolAcceptsMcpData:
    """Tests for the tool_accepts_mcp_data schema-detection helper."""

    @staticmethod
    def _manager_for_demo_tool(tool):
        """Register ``tool`` as the sole tool of a ``demo`` server in a mock manager."""
        return _make_tool_manager({"demo": {"tools": [tool], "config": {}}})

    def test_returns_true_when_schema_has_mcp_data(self):
        schema = {
            "type": "object",
            "properties": {
                "task": {"type": "string"},
                "_mcp_data": {"type": "object"},
            },
        }
        manager = self._manager_for_demo_tool(FakeTool("planner", inputSchema=schema))

        # The mock manager exposes tools as "<server>_<tool>".
        assert tool_accepts_mcp_data("demo_planner", manager) is True

    def test_returns_false_when_schema_lacks_mcp_data(self):
        schema = {"type": "object", "properties": {"query": {"type": "string"}}}
        manager = self._manager_for_demo_tool(FakeTool("search", inputSchema=schema))

        assert tool_accepts_mcp_data("demo_search", manager) is False

    def test_returns_false_with_no_tool_name(self):
        assert tool_accepts_mcp_data("", MagicMock()) is False

    def test_returns_false_with_no_tool_manager(self):
        assert tool_accepts_mcp_data("some_tool", None) is False

    def test_returns_false_when_schema_lookup_fails(self):
        # A schema lookup that raises must still yield False.
        broken_manager = MagicMock()
        broken_manager.get_tools_schema = MagicMock(side_effect=RuntimeError("fail"))

        assert tool_accepts_mcp_data("any_tool", broken_manager) is False
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# -- build_mcp_data tests --
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class TestBuildMcpData:
    """Tests for build_mcp_data output structure."""

    def test_returns_empty_when_no_tools(self):
        """A manager with no servers yields an empty server list."""
        manager = _make_tool_manager({})
        result = build_mcp_data(manager)
        assert result == {"available_servers": []}

    def test_returns_empty_when_no_manager(self):
        """A missing (``None``) manager yields an empty server list."""
        result = build_mcp_data(None)
        assert result == {"available_servers": []}

    def test_skips_canvas_server(self):
        """The ``canvas`` server is not listed in ``available_servers``."""
        tool = FakeTool("canvas")
        manager = _make_tool_manager({
            "canvas": {"tools": [tool], "config": {}},
        })
        result = build_mcp_data(manager)
        assert result["available_servers"] == []

    def test_includes_server_and_tool_metadata(self):
        """Server name/description and each tool's name, description and schema are emitted."""
        tool_a = FakeTool(
            "search",
            description="Search documents",
            inputSchema={
                "type": "object",
                "properties": {"query": {"type": "string"}},
            },
        )
        tool_b = FakeTool("list", description="List items")
        manager = _make_tool_manager({
            "myserver": {
                "tools": [tool_a, tool_b],
                "config": {"description": "My Server"},
            },
        })

        result = build_mcp_data(manager)
        assert len(result["available_servers"]) == 1
        server = result["available_servers"][0]
        assert server["server_name"] == "myserver"
        assert server["description"] == "My Server"
        assert len(server["tools"]) == 2

        tool_entry = server["tools"][0]
        assert tool_entry["name"] == "myserver_search"
        assert tool_entry["description"] == "Search documents"
        assert "properties" in tool_entry["parameters"]

        # Every tool must be described, not just the first one.
        second_entry = server["tools"][1]
        assert second_entry["name"] == "myserver_list"
        assert second_entry["description"] == "List items"

    def test_multiple_servers(self):
        """Each non-canvas server appears exactly once."""
        tool1 = FakeTool("t1")
        tool2 = FakeTool("t2")
        manager = _make_tool_manager({
            "server_a": {"tools": [tool1], "config": {}},
            "server_b": {"tools": [tool2], "config": {}},
        })
        result = build_mcp_data(manager)
        names = [s["server_name"] for s in result["available_servers"]]
        # Compare the full sorted list: plain membership checks would not
        # catch duplicated or unexpected extra server entries.
        assert sorted(names) == ["server_a", "server_b"]

    def test_handles_missing_description(self):
        """Missing descriptions become "" and a missing schema becomes an empty object schema."""
        tool = FakeTool("t", description=None, inputSchema=None)
        manager = _make_tool_manager({
            "s": {"tools": [tool], "config": {}},
        })
        result = build_mcp_data(manager)
        server = result["available_servers"][0]
        assert server["description"] == ""
        assert server["tools"][0]["description"] == ""
        assert server["tools"][0]["parameters"] == {"type": "object", "properties": {}}

    def test_uses_short_description_fallback(self):
        """``short_description`` is used when the server config has no ``description``."""
        tool = FakeTool("t")
        manager = _make_tool_manager({
            "s": {"tools": [tool], "config": {"short_description": "Short desc"}},
        })
        result = build_mcp_data(manager)
        assert result["available_servers"][0]["description"] == "Short desc"
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
# -- inject_context_into_args with _mcp_data tests --
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class TestInjectMcpData:
    """Check how inject_context_into_args adds (or omits) the ``_mcp_data`` argument."""

    @staticmethod
    def _planner_tool():
        """Build a fresh ``planner`` FakeTool whose schema declares ``_mcp_data``."""
        return FakeTool(
            "planner",
            inputSchema={
                "type": "object",
                "properties": {
                    "task": {"type": "string"},
                    "_mcp_data": {"type": "object"},
                },
            },
        )

    def test_injects_mcp_data_when_tool_accepts_it(self):
        """A tool declaring ``_mcp_data`` receives a populated server list."""
        manager = _make_tool_manager(
            {"demo": {"tools": [self._planner_tool()], "config": {"description": "Demo"}}}
        )

        injected = inject_context_into_args(
            {"task": "do something"}, {"user_email": "user@test.com"}, "demo_planner", manager
        )

        assert "_mcp_data" in injected
        mcp_data = injected["_mcp_data"]
        assert "available_servers" in mcp_data
        assert len(mcp_data["available_servers"]) == 1

    def test_does_not_inject_mcp_data_when_tool_lacks_param(self):
        """A tool without an ``_mcp_data`` property gets its arguments without it."""
        search = FakeTool(
            "search",
            inputSchema={"type": "object", "properties": {"query": {"type": "string"}}},
        )
        manager = _make_tool_manager({"demo": {"tools": [search], "config": {}}})

        injected = inject_context_into_args(
            {"query": "hello"}, {"user_email": "user@test.com"}, "demo_search", manager
        )

        assert "_mcp_data" not in injected

    def test_mcp_data_reinjected_after_edit(self):
        """User-edited arguments that dropped ``_mcp_data`` get it injected again."""
        manager = _make_tool_manager({"demo": {"tools": [self._planner_tool()], "config": {}}})

        # The edited arguments no longer carry ``_mcp_data``.
        edited_args = {"task": "edited task"}
        injected = inject_context_into_args(
            edited_args, {"user_email": "user@test.com"}, "demo_planner", manager
        )

        # ``_mcp_data`` is present again and the edited value is kept.
        assert "_mcp_data" in injected
        assert injected["task"] == "edited task"

    def test_mcp_data_not_injected_without_tool_manager(self):
        """With no tool manager, nothing is injected under ``_mcp_data``."""
        injected = inject_context_into_args(
            {"task": "test"}, {"user_email": "user@test.com"}, "some_tool", None
        )
        assert "_mcp_data" not in injected
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Test get_authorized_servers with async auth function."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from atlas.modules.mcp_tools.client import MCPToolManager
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@pytest.mark.asyncio
async def test_get_authorized_servers_with_async_auth():
    """Test that get_authorized_servers properly handles async auth_check_func."""

    # A real MCPToolManager (constructed with None) whose servers_config is
    # overwritten with the test fixtures below.
    mcp_manager = MCPToolManager(None)
    mcp_manager.servers_config = {
        "server1": {
            "enabled": True,
            "groups": ["admin", "users"]
        },
        "server2": {
            "enabled": True,
            "groups": ["admin"]
        },
        "server3": {
            "enabled": True,
            "groups": []  # No groups required
        },
        "server4": {
            "enabled": False,
            "groups": ["admin"]
        }
    }

    # Record every (user_email, group) pair the coroutine actually sees. An
    # un-awaited coroutine object is truthy, so without this record the
    # assertion on `authorized` alone would also pass if the auth function
    # were called but never awaited.
    auth_calls = []

    async def mock_auth_check(user_email: str, group: str) -> bool:
        """Mock auth check that returns True for admin group."""
        auth_calls.append((user_email, group))
        return group == "admin"

    # Test with user who has admin access
    authorized = await mcp_manager.get_authorized_servers("admin@test.com", mock_auth_check)

    # Should include server1 (has admin), server2 (has admin), server3 (no groups required)
    # Should NOT include server4 (disabled)
    assert set(authorized) == {"server1", "server2", "server3"}

    # The coroutine body must have run, and with the caller's email.
    assert auth_calls
    assert all(email == "admin@test.com" for email, _ in auth_calls)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@pytest.mark.asyncio
async def test_get_authorized_servers_with_multiple_groups():
    """Only servers whose required groups the user belongs to are authorized."""
    manager = MCPToolManager(None)
    manager.servers_config = {
        "server1": {"enabled": True, "groups": ["users", "developers"]},
        "server2": {"enabled": True, "groups": ["admin"]},
    }

    member_of = ["users", "developers"]

    async def mock_auth_check(user_email: str, group: str) -> bool:
        # The user belongs to 'users' and 'developers' but not 'admin'.
        return group in member_of

    result = await manager.get_authorized_servers("user@test.com", mock_auth_check)

    # server1 matches via the user's groups; server2 requires 'admin', which the user lacks.
    assert result == ["server1"]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@pytest.mark.asyncio
async def test_get_authorized_servers_no_access():
    """A user in none of the required groups is authorized for no server."""
    manager = MCPToolManager(None)
    manager.servers_config = {
        "server1": {"enabled": True, "groups": ["admin"]},
        "server2": {"enabled": True, "groups": ["superusers"]},
    }

    async def deny_everything(user_email: str, group: str) -> bool:
        # The user has no group memberships at all.
        return False

    result = await manager.get_authorized_servers("user@test.com", deny_everything)
    assert result == []
|