awslabs.openapi-mcp-server 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- awslabs/__init__.py +16 -0
- awslabs/openapi_mcp_server/__init__.py +69 -0
- awslabs/openapi_mcp_server/api/__init__.py +18 -0
- awslabs/openapi_mcp_server/api/config.py +200 -0
- awslabs/openapi_mcp_server/auth/__init__.py +27 -0
- awslabs/openapi_mcp_server/auth/api_key_auth.py +185 -0
- awslabs/openapi_mcp_server/auth/auth_cache.py +190 -0
- awslabs/openapi_mcp_server/auth/auth_errors.py +206 -0
- awslabs/openapi_mcp_server/auth/auth_factory.py +146 -0
- awslabs/openapi_mcp_server/auth/auth_protocol.py +63 -0
- awslabs/openapi_mcp_server/auth/auth_provider.py +160 -0
- awslabs/openapi_mcp_server/auth/base_auth.py +218 -0
- awslabs/openapi_mcp_server/auth/basic_auth.py +171 -0
- awslabs/openapi_mcp_server/auth/bearer_auth.py +108 -0
- awslabs/openapi_mcp_server/auth/cognito_auth.py +538 -0
- awslabs/openapi_mcp_server/auth/register.py +100 -0
- awslabs/openapi_mcp_server/patch/__init__.py +17 -0
- awslabs/openapi_mcp_server/prompts/__init__.py +18 -0
- awslabs/openapi_mcp_server/prompts/generators/__init__.py +22 -0
- awslabs/openapi_mcp_server/prompts/generators/operation_prompts.py +642 -0
- awslabs/openapi_mcp_server/prompts/generators/workflow_prompts.py +257 -0
- awslabs/openapi_mcp_server/prompts/models.py +70 -0
- awslabs/openapi_mcp_server/prompts/prompt_manager.py +150 -0
- awslabs/openapi_mcp_server/server.py +511 -0
- awslabs/openapi_mcp_server/utils/__init__.py +18 -0
- awslabs/openapi_mcp_server/utils/cache_provider.py +249 -0
- awslabs/openapi_mcp_server/utils/config.py +35 -0
- awslabs/openapi_mcp_server/utils/error_handler.py +349 -0
- awslabs/openapi_mcp_server/utils/http_client.py +263 -0
- awslabs/openapi_mcp_server/utils/metrics_provider.py +503 -0
- awslabs/openapi_mcp_server/utils/openapi.py +217 -0
- awslabs/openapi_mcp_server/utils/openapi_validator.py +253 -0
- awslabs_openapi_mcp_server-0.1.1.dist-info/METADATA +418 -0
- awslabs_openapi_mcp_server-0.1.1.dist-info/RECORD +38 -0
- awslabs_openapi_mcp_server-0.1.1.dist-info/WHEEL +4 -0
- awslabs_openapi_mcp_server-0.1.1.dist-info/entry_points.txt +2 -0
- awslabs_openapi_mcp_server-0.1.1.dist-info/licenses/LICENSE +175 -0
- awslabs_openapi_mcp_server-0.1.1.dist-info/licenses/NOTICE +2 -0
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Authentication caching utilities.
|
|
15
|
+
|
|
16
|
+
This module provides caching mechanisms for authentication tokens and other data.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import time
|
|
20
|
+
from typing import Any, Callable, Dict, Optional, TypeVar, cast
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Generic return type threaded through cached_auth_data so a decorated
# function keeps its original return annotation.
T = TypeVar('T')
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TokenCache:
    """Cache for authentication tokens and related data.

    A small in-memory, time-based cache. Every entry stores an absolute expiry
    timestamp computed from ``time.time()``. Stale entries are dropped lazily by
    :meth:`get` or eagerly by :meth:`cleanup`. When the cache is full and a new
    key is inserted, the entry that expires soonest is evicted.
    """

    def __init__(self, max_size: int = 100, ttl: int = 300):
        """Initialize the token cache.

        Args:
            max_size: Maximum number of items to store in the cache. A value of
                zero or less disables storage: ``set`` becomes a no-op.
            ttl: Default time-to-live in seconds for cached items

        """
        self._cache: Dict[str, Dict[str, Any]] = {}
        self._max_size = max_size
        self._ttl = ttl

    def get(self, key: str) -> Optional[Any]:
        """Get a value from the cache.

        Args:
            key: Cache key

        Returns:
            Any: Cached value, or None if the key is missing or expired. A stored
            value of None is therefore indistinguishable from a miss.

        """
        if key not in self._cache:
            return None

        item = self._cache[key]
        if time.time() > item['expires_at']:
            # Item has expired; drop it so it no longer occupies a slot.
            del self._cache[key]
            return None

        return item['value']

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Set a value in the cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time-to-live in seconds; overrides the default when not None

        """
        if self._max_size <= 0:
            # Nothing may be stored. Without this guard the eviction below would
            # call min() on an empty dict and raise ValueError.
            return

        # Make room only when inserting a *new* key into a full cache; updating
        # an existing key never grows the cache.
        if len(self._cache) >= self._max_size and key not in self._cache:
            # Evict the entry with the earliest expiry (already-expired entries
            # sort first). This is soonest-to-expire eviction, not LRU: reads do
            # not affect the order.
            victim = min(self._cache, key=lambda k: self._cache[k]['expires_at'])
            del self._cache[victim]

        # Calculate expiration time
        expires_at = time.time() + (ttl if ttl is not None else self._ttl)

        # Store the item
        self._cache[key] = {
            'value': value,
            'expires_at': expires_at,
        }

    def delete(self, key: str) -> bool:
        """Delete a value from the cache.

        Args:
            key: Cache key

        Returns:
            bool: True if the key was found and deleted, False otherwise

        """
        if key in self._cache:
            del self._cache[key]
            return True
        return False

    def clear(self) -> None:
        """Clear the entire cache."""
        self._cache.clear()

    def cleanup(self) -> int:
        """Remove expired items from the cache.

        Returns:
            int: Number of items removed

        """
        now = time.time()
        expired_keys = [k for k, v in self._cache.items() if now > v['expires_at']]
        for key in expired_keys:
            del self._cache[key]
        return len(expired_keys)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# Single process-wide TokenCache; handed out by get_token_cache() and used by
# the cached_auth_data decorator.
_TOKEN_CACHE = TokenCache()


def get_token_cache() -> TokenCache:
    """Return the module-level TokenCache singleton.

    Returns:
        TokenCache: The same shared instance on every call

    """
    shared = _TOKEN_CACHE
    return shared
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def cached_auth_data(ttl: int = 300) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Cache authentication data.

    Decorator factory that memoizes a function's results in the global token
    cache. The key is built from the function name plus the ``str()`` of each
    positional argument and each ``name=value`` keyword argument.

    Args:
        ttl: Time-to-live in seconds for cached items

    Returns:
        Callable: Decorator function

    Note:
        The cache reports a miss as None, so a function that returns None is
        re-invoked on every call rather than served from the cache.

    """
    from functools import wraps

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        """Decorate function with caching.

        Args:
            func: Function to decorate

        Returns:
            Callable: Wrapped function

        """
        cache = get_token_cache()

        # wraps() copies __name__/__qualname__/__doc__/__module__ from func.
        # Without it every decorated function is named 'wrapper'. Stacking this
        # decorator, or any code keying on __name__, would then make unrelated
        # functions share cache keys.
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            """Wrap function with caching logic.

            Args:
                *args: Positional arguments
                **kwargs: Keyword arguments

            Returns:
                T: Function result

            """
            # Create a cache key from the function name and arguments
            key_parts = [func.__name__]
            key_parts.extend(str(arg) for arg in args)
            key_parts.extend(f'{k}={v}' for k, v in sorted(kwargs.items()))
            cache_key = ':'.join(key_parts)

            # Check if we have a cached result
            cached_result = cache.get(cache_key)
            if cached_result is not None:
                return cast(T, cached_result)

            # Call the function and cache the result
            result = func(*args, **kwargs)
            cache.set(cache_key, result, ttl)
            return result

        return wrapper

    return decorator
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Authentication error handling.
|
|
15
|
+
|
|
16
|
+
This module provides centralized error handling for authentication providers.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from enum import Enum
|
|
20
|
+
from typing import Dict, Optional, Type
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AuthErrorType(Enum):
    """Categories of authentication failure, each with a stable string value."""

    # Problems with the credentials themselves
    MISSING_CREDENTIALS = 'missing_credentials'
    INVALID_CREDENTIALS = 'invalid_credentials'
    EXPIRED_TOKEN = 'expired_token'  # nosec B105
    # Authenticated but not allowed
    INSUFFICIENT_PERMISSIONS = 'insufficient_permissions'
    # Setup and transport problems
    CONFIGURATION_ERROR = 'configuration_error'
    NETWORK_ERROR = 'network_error'
    # Catch-all default (used by AuthError when no type is given)
    UNKNOWN_ERROR = 'unknown_error'
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class AuthError(Exception):
    """Root of the authentication error hierarchy.

    Besides the human-readable message, each instance carries a machine-readable
    ``error_type`` and a ``details`` dict (empty when none is supplied).
    """

    def __init__(
        self,
        message: str,
        error_type: AuthErrorType = AuthErrorType.UNKNOWN_ERROR,
        details: Optional[Dict] = None,
    ):
        """Store the message, category and extra details.

        Args:
            message: Human-readable description of the failure
            error_type: Category of the failure; defaults to UNKNOWN_ERROR
            details: Optional extra context; a falsy value becomes a new empty dict

        """
        super().__init__(message)
        self.message = message
        self.error_type = error_type
        self.details = {} if not details else details

    def __str__(self) -> str:
        """Render the error as ``<error_type value>: <message>``.

        Returns:
            str: Category-prefixed message

        """
        return '{}: {}'.format(self.error_type.value, self.message)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class MissingCredentialsError(AuthError):
    """Raised when a credential the provider needs was not supplied."""

    def __init__(self, message: str, details: Optional[Dict] = None):
        """Create the error with its category fixed to MISSING_CREDENTIALS.

        Args:
            message: Human-readable description of the failure
            details: Optional extra context

        """
        super().__init__(message, AuthErrorType.MISSING_CREDENTIALS, details)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class InvalidCredentialsError(AuthError):
    """Raised when supplied credentials are rejected as invalid."""

    def __init__(self, message: str, details: Optional[Dict] = None):
        """Create the error with its category fixed to INVALID_CREDENTIALS.

        Args:
            message: Human-readable description of the failure
            details: Optional extra context

        """
        super().__init__(message, AuthErrorType.INVALID_CREDENTIALS, details)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class ExpiredTokenError(AuthError):
    """Raised when a token is no longer valid because it has expired."""

    def __init__(self, message: str, details: Optional[Dict] = None):
        """Create the error with its category fixed to EXPIRED_TOKEN.

        Args:
            message: Human-readable description of the failure
            details: Optional extra context

        """
        super().__init__(message, AuthErrorType.EXPIRED_TOKEN, details)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class InsufficientPermissionsError(AuthError):
    """Raised when the authenticated identity lacks the required permissions."""

    def __init__(self, message: str, details: Optional[Dict] = None):
        """Create the error with its category fixed to INSUFFICIENT_PERMISSIONS.

        Args:
            message: Human-readable description of the failure
            details: Optional extra context

        """
        super().__init__(message, AuthErrorType.INSUFFICIENT_PERMISSIONS, details)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class ConfigurationError(AuthError):
    """Raised when authentication settings are incomplete or inconsistent."""

    def __init__(self, message: str, details: Optional[Dict] = None):
        """Create the error with its category fixed to CONFIGURATION_ERROR.

        Args:
            message: Human-readable description of the failure
            details: Optional extra context

        """
        super().__init__(message, AuthErrorType.CONFIGURATION_ERROR, details)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class NetworkError(AuthError):
    """Raised when authentication fails because of a network problem."""

    def __init__(self, message: str, details: Optional[Dict] = None):
        """Create the error with its category fixed to NETWORK_ERROR.

        Args:
            message: Human-readable description of the failure
            details: Optional extra context

        """
        super().__init__(message, AuthErrorType.NETWORK_ERROR, details)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# Lookup table from error category to the exception class create_auth_error
# instantiates; UNKNOWN_ERROR maps to the base AuthError.
ERROR_CLASSES: Dict[AuthErrorType, Type[AuthError]] = dict(
    [
        (AuthErrorType.MISSING_CREDENTIALS, MissingCredentialsError),
        (AuthErrorType.INVALID_CREDENTIALS, InvalidCredentialsError),
        (AuthErrorType.EXPIRED_TOKEN, ExpiredTokenError),
        (AuthErrorType.INSUFFICIENT_PERMISSIONS, InsufficientPermissionsError),
        (AuthErrorType.CONFIGURATION_ERROR, ConfigurationError),
        (AuthErrorType.NETWORK_ERROR, NetworkError),
        (AuthErrorType.UNKNOWN_ERROR, AuthError),
    ]
)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def create_auth_error(
    error_type: AuthErrorType, message: str, details: Optional[Dict] = None
) -> AuthError:
    """Create an authentication error of the specified type.

    Args:
        error_type: Type of authentication error
        message: Error message
        details: Additional error details

    Returns:
        AuthError: An instance of the class ERROR_CLASSES maps ``error_type`` to,
        or a base AuthError carrying ``error_type`` when no specific class exists.

    """
    error_class = ERROR_CLASSES.get(error_type, AuthError)
    # Compare classes by identity, not equality.
    if error_class is AuthError:
        # The base class has no fixed category, so pass error_type explicitly.
        # This also covers types absent from ERROR_CLASSES.
        return AuthError(message=message, error_type=error_type, details=details)
    # Subclasses set their own error_type in their constructor.
    return error_class(message=message, details=details)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def format_error_message(provider_name: str, error_type: AuthErrorType, message: str) -> str:
    """Render an error as ``[PROVIDER] <error_type value>: <message>``.

    Args:
        provider_name: Name of the authentication provider; shown upper-cased
        error_type: Category of the error; its ``value`` is embedded
        message: Human-readable description

    Returns:
        str: The formatted message

    """
    tag = provider_name.upper()
    return '[{}] {}: {}'.format(tag, error_type.value, message)
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Authentication provider factory."""
|
|
15
|
+
|
|
16
|
+
from awslabs.openapi_mcp_server import logger
|
|
17
|
+
from awslabs.openapi_mcp_server.api.config import Config
|
|
18
|
+
from awslabs.openapi_mcp_server.auth.auth_protocol import AuthProviderProtocol
|
|
19
|
+
from awslabs.openapi_mcp_server.auth.auth_provider import NullAuthProvider
|
|
20
|
+
from typing import Any, Dict, Type
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Registry of provider classes keyed by lower-cased auth type name. 'none' is
# always present so get_auth_provider has something to fall back to.
_AUTH_PROVIDERS: Dict[str, Type[Any]] = dict(none=NullAuthProvider)

# Provider instances memoized by a hash of the auth-related config values.
_PROVIDER_CACHE: Dict[int, AuthProviderProtocol] = dict()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def register_auth_provider(auth_type: str, provider_class: Type[Any]) -> None:
    """Add a provider class to the registry under a case-insensitive name.

    Args:
        auth_type: Name to register under; stored lower-cased
        provider_class: Class that will later be called with the Config to
            build a provider instance

    Raises:
        ValueError: If a provider is already registered for ``auth_type``

    """
    normalized = auth_type.lower()
    if normalized not in _AUTH_PROVIDERS:
        _AUTH_PROVIDERS[normalized] = provider_class
        logger.debug(
            f"Registered authentication provider for type '{normalized}': {provider_class.__name__}"
        )
        return
    raise ValueError(f"Authentication provider for type '{normalized}' already registered")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _get_provider_instance(
    auth_type: str, config_hash: int, config: Config
) -> AuthProviderProtocol:
    """Return the memoized provider for ``config_hash``, building it on first use.

    Args:
        auth_type: Registered auth type whose class builds the provider
        config_hash: Cache key distinguishing provider instances
        config: Configuration passed to the provider class on a cache miss

    Returns:
        AuthProviderProtocol: The cached or newly created provider

    """
    if config_hash not in _PROVIDER_CACHE:
        # Cache miss: build the provider from the registered class and remember it.
        fresh = _AUTH_PROVIDERS[auth_type](config)
        _PROVIDER_CACHE[config_hash] = fresh
        logger.debug(f'Created new authentication provider: {fresh.provider_name}')
        return fresh

    logger.debug(f'Using cached authentication provider for {auth_type}')
    return _PROVIDER_CACHE[config_hash]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_auth_provider(config: Config) -> AuthProviderProtocol:
    """Resolve the authentication provider described by ``config``.

    Args:
        config: The application configuration

    Returns:
        AuthProviderProtocol: A (possibly cached) provider instance

    Notes:
        An unregistered ``auth_type`` logs a warning and falls back to 'none'.
        Instances are memoized so equivalent configs share one provider.

    """
    kind = config.auth_type.lower()
    if kind not in _AUTH_PROVIDERS:
        logger.warning(f"Unknown authentication type '{kind}'. Falling back to 'none'.")
        kind = 'none'

    # NOTE(review): only these attributes (plus the auth type) feed the cache key.
    # Two configs that agree on them share one provider instance even if they
    # differ elsewhere. Confirm no registered provider depends on other fields.
    fingerprint_fields = (
        'auth_token',
        'auth_username',
        'auth_password',
        'auth_api_key',
        'auth_api_key_name',
        'auth_api_key_in',
    )
    config_hash = hash((kind, *(getattr(config, name, None) for name in fingerprint_fields)))

    provider = _get_provider_instance(kind, config_hash, config)

    # Logged on every call, including cache hits.
    logger.info(f'Created authentication provider: {provider.provider_name}')

    # is_configured() is evaluated before the auth-type check, as in the original.
    configured = provider.is_configured()
    if not configured and kind != 'none':
        logger.warning(
            f"Authentication provider '{provider.provider_name}' is not properly configured"
        )

    return provider
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def is_auth_type_available(auth_type: str) -> bool:
    """Report whether a provider is registered for ``auth_type``.

    Args:
        auth_type: Auth type name; compared case-insensitively

    Returns:
        bool: True if a provider class is registered under that name

    """
    normalized = auth_type.lower()
    return normalized in _AUTH_PROVIDERS
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def clear_provider_cache() -> None:
    """Forget every memoized provider instance.

    Later lookups through get_auth_provider build fresh providers, which is
    handy in tests or after the configuration has changed.
    """
    cache = _PROVIDER_CACHE
    cache.clear()
    logger.debug('Authentication provider cache cleared')
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Authentication provider protocols and type definitions."""
|
|
15
|
+
|
|
16
|
+
import httpx
|
|
17
|
+
from awslabs.openapi_mcp_server.api.config import Config
|
|
18
|
+
from typing import Dict, Optional, Protocol, TypeVar, runtime_checkable
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@runtime_checkable
class AuthProviderProtocol(Protocol):
    """Structural interface implemented by authentication providers.

    Being runtime-checkable, it can be used with ``isinstance`` (which only
    checks that the members exist), and it lets callers type providers
    without casting.
    """

    @property
    def provider_name(self) -> str:
        """Name identifying this authentication provider."""
        ...

    def is_configured(self) -> bool:
        """Return whether the provider has the settings it needs."""
        ...

    def get_auth_headers(self) -> Dict[str, str]:
        """Return headers to attach to outgoing HTTP requests."""
        ...

    def get_auth_params(self) -> Dict[str, str]:
        """Return query parameters to attach to outgoing HTTP requests."""
        ...

    def get_auth_cookies(self) -> Dict[str, str]:
        """Return cookies to attach to outgoing HTTP requests."""
        ...

    def get_httpx_auth(self) -> Optional[httpx.Auth]:
        """Return an HTTPX auth object, or None when not applicable."""
        ...
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Type variable restricted to types satisfying AuthProviderProtocol.
# NOTE(review): not referenced elsewhere in this module; presumably imported by callers.
T = TypeVar('T', bound=AuthProviderProtocol)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class AuthProviderFactory(Protocol):
    """Callable protocol for anything that builds a provider from a Config."""

    def __call__(self, config: Config) -> AuthProviderProtocol:
        """Build and return an authentication provider for ``config``."""
|