awslabs.openapi-mcp-server 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- awslabs/__init__.py +16 -0
- awslabs/openapi_mcp_server/__init__.py +69 -0
- awslabs/openapi_mcp_server/api/__init__.py +18 -0
- awslabs/openapi_mcp_server/api/config.py +200 -0
- awslabs/openapi_mcp_server/auth/__init__.py +27 -0
- awslabs/openapi_mcp_server/auth/api_key_auth.py +185 -0
- awslabs/openapi_mcp_server/auth/auth_cache.py +190 -0
- awslabs/openapi_mcp_server/auth/auth_errors.py +206 -0
- awslabs/openapi_mcp_server/auth/auth_factory.py +146 -0
- awslabs/openapi_mcp_server/auth/auth_protocol.py +63 -0
- awslabs/openapi_mcp_server/auth/auth_provider.py +160 -0
- awslabs/openapi_mcp_server/auth/base_auth.py +218 -0
- awslabs/openapi_mcp_server/auth/basic_auth.py +171 -0
- awslabs/openapi_mcp_server/auth/bearer_auth.py +108 -0
- awslabs/openapi_mcp_server/auth/cognito_auth.py +538 -0
- awslabs/openapi_mcp_server/auth/register.py +100 -0
- awslabs/openapi_mcp_server/patch/__init__.py +17 -0
- awslabs/openapi_mcp_server/prompts/__init__.py +18 -0
- awslabs/openapi_mcp_server/prompts/generators/__init__.py +22 -0
- awslabs/openapi_mcp_server/prompts/generators/operation_prompts.py +642 -0
- awslabs/openapi_mcp_server/prompts/generators/workflow_prompts.py +257 -0
- awslabs/openapi_mcp_server/prompts/models.py +70 -0
- awslabs/openapi_mcp_server/prompts/prompt_manager.py +150 -0
- awslabs/openapi_mcp_server/server.py +511 -0
- awslabs/openapi_mcp_server/utils/__init__.py +18 -0
- awslabs/openapi_mcp_server/utils/cache_provider.py +249 -0
- awslabs/openapi_mcp_server/utils/config.py +35 -0
- awslabs/openapi_mcp_server/utils/error_handler.py +349 -0
- awslabs/openapi_mcp_server/utils/http_client.py +263 -0
- awslabs/openapi_mcp_server/utils/metrics_provider.py +503 -0
- awslabs/openapi_mcp_server/utils/openapi.py +217 -0
- awslabs/openapi_mcp_server/utils/openapi_validator.py +253 -0
- awslabs_openapi_mcp_server-0.1.1.dist-info/METADATA +418 -0
- awslabs_openapi_mcp_server-0.1.1.dist-info/RECORD +38 -0
- awslabs_openapi_mcp_server-0.1.1.dist-info/WHEEL +4 -0
- awslabs_openapi_mcp_server-0.1.1.dist-info/entry_points.txt +2 -0
- awslabs_openapi_mcp_server-0.1.1.dist-info/licenses/LICENSE +175 -0
- awslabs_openapi_mcp_server-0.1.1.dist-info/licenses/NOTICE +2 -0
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Cache provider for the OpenAPI MCP Server.
|
|
15
|
+
|
|
16
|
+
This module provides a pluggable caching system with different backends.
|
|
17
|
+
The default is a simple in-memory implementation, but it can be switched
|
|
18
|
+
to use external caching systems via environment variables.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import time
|
|
22
|
+
from abc import ABC, abstractmethod
|
|
23
|
+
from awslabs.openapi_mcp_server import logger
|
|
24
|
+
from awslabs.openapi_mcp_server.utils.config import CACHE_MAXSIZE, CACHE_TTL, USE_CACHETOOLS
|
|
25
|
+
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Type variable for generic cache implementation
|
|
29
|
+
T = TypeVar('T')
|
|
30
|
+
|
|
31
|
+
# Optional dependency: cachetools is only imported when the configuration asks for it.
# CACHETOOLS_AVAILABLE stays False (and cachetools stays None) unless the import succeeds.
CACHETOOLS_AVAILABLE = False
cachetools = None
if USE_CACHETOOLS:
    try:
        import cachetools
    except ImportError:
        logger.warning(
            'cachetools requested but not installed. Install with: pip install cachetools'
        )
    else:
        CACHETOOLS_AVAILABLE = True
        logger.info('cachetools caching enabled')
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class CacheProvider(Generic[T], ABC):
    """Interface implemented by every cache backend in this module.

    Keys are strings and values are of the generic type ``T``.
    """

    @abstractmethod
    def get(self, key: str) -> Optional[T]:
        """Return the value stored under ``key``, or ``None`` if nothing is available."""
        ...

    @abstractmethod
    def set(self, key: str, value: T) -> None:
        """Store ``value`` under ``key``."""
        ...

    @abstractmethod
    def invalidate(self, key: str) -> bool:
        """Drop the entry for ``key`` and report whether an entry was dropped."""
        ...

    @abstractmethod
    def clear(self) -> None:
        """Remove every entry from the cache."""
        ...
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class InMemoryCacheProvider(CacheProvider[T]):
    """Dictionary-backed cache whose entries expire after a fixed time-to-live.

    Expired entries are dropped lazily by ``get`` and in bulk by ``cleanup``.
    The number of entries is not bounded.
    """

    def __init__(self, ttl_seconds: Optional[int] = None):
        """Create an empty cache.

        Args:
            ttl_seconds: Lifetime of each entry in seconds; ``None`` selects ``CACHE_TTL``

        """
        # key -> (value, absolute expiry timestamp on the time.time() clock)
        self._cache: Dict[str, tuple[T, float]] = {}
        if ttl_seconds is None:
            ttl_seconds = CACHE_TTL
        self._ttl_seconds = ttl_seconds
        logger.debug(f'Created in-memory cache provider with TTL of {self._ttl_seconds} seconds')

    def get(self, key: str) -> Optional[T]:
        """Return the live value for ``key``; ``None`` when it is missing or has expired."""
        entry = self._cache.get(key)
        if entry is None:
            return None

        value, expires_at = entry
        if time.time() <= expires_at:
            logger.debug(f'Cache hit: {key}')
            return value

        # The entry outlived its TTL: evict it and report a miss.
        del self._cache[key]
        logger.debug(f'Cache entry expired: {key}')
        return None

    def set(self, key: str, value: T) -> None:
        """Store ``value`` under ``key`` with a fresh expiry of now + TTL."""
        self._cache[key] = (value, time.time() + self._ttl_seconds)
        logger.debug(f'Cache set: {key} (expires in {self._ttl_seconds} seconds)')

    def invalidate(self, key: str) -> bool:
        """Remove ``key`` from the cache; return ``True`` only if it was present."""
        try:
            del self._cache[key]
        except KeyError:
            return False
        logger.debug(f'Cache invalidated: {key}')
        return True

    def clear(self) -> None:
        """Remove every entry, expired or not."""
        removed = len(self._cache)
        self._cache.clear()
        logger.debug(f'Cache cleared ({removed} entries removed)')

    def cleanup(self) -> int:
        """Evict every entry whose expiry time has passed.

        Returns:
            int: Number of entries removed

        """
        cutoff = time.time()
        # Collect first, delete afterwards: the dict must not change while it is iterated.
        stale = [key for key, entry in self._cache.items() if cutoff > entry[1]]
        for key in stale:
            self._cache.pop(key)

        removed = len(stale)
        if removed:
            logger.debug(f'Cache cleanup: removed {removed} expired entries')
        return removed
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class CachetoolsProvider(CacheProvider[T]):
    """Cache provider that stores its entries in a ``cachetools.TTLCache``."""

    def __init__(self, ttl_seconds: Optional[int] = None, maxsize: Optional[int] = None):
        """Create the underlying ``TTLCache``.

        Args:
            ttl_seconds: Lifetime of each entry in seconds; ``None`` selects ``CACHE_TTL``
            maxsize: Maximum number of entries; ``None`` selects ``CACHE_MAXSIZE``

        Raises:
            ImportError: If cachetools was not enabled or could not be imported

        """
        if cachetools is None or not CACHETOOLS_AVAILABLE:
            raise ImportError('cachetools not available')

        # Fall back to the configured defaults for anything not given explicitly.
        effective_ttl = CACHE_TTL if ttl_seconds is None else ttl_seconds
        capacity = CACHE_MAXSIZE if maxsize is None else maxsize

        self._cache = cachetools.TTLCache(maxsize=capacity, ttl=effective_ttl)
        logger.debug(
            f'Created cachetools cache provider with TTL of {effective_ttl} seconds and maxsize of {capacity}'
        )

    def get(self, key: str) -> Optional[T]:
        """Return the value for ``key``, or ``None`` when the cache has no such entry."""
        try:
            value = self._cache[key]
        except KeyError:
            logger.debug(f'Cache miss: {key}')
            return None
        logger.debug(f'Cache hit: {key}')
        return value

    def set(self, key: str, value: T) -> None:
        """Store ``value`` under ``key``."""
        self._cache[key] = value
        logger.debug(f'Cache set: {key}')

    def invalidate(self, key: str) -> bool:
        """Remove ``key`` from the cache; return ``True`` only if it was present."""
        try:
            del self._cache[key]
        except KeyError:
            return False
        logger.debug(f'Cache invalidated: {key}')
        return True

    def clear(self) -> None:
        """Remove every entry from the cache."""
        removed = len(self._cache)
        self._cache.clear()
        logger.debug(f'Cache cleared ({removed} entries removed)')
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
# Create the appropriate cache provider based on configuration
|
|
193
|
+
def create_cache_provider(ttl_seconds: Optional[int] = None) -> CacheProvider:
    """Build the cache backend selected by the configuration.

    A ``CachetoolsProvider`` is used when cachetools is both enabled and importable;
    in every other case, including a failure while constructing it, an
    ``InMemoryCacheProvider`` is returned.

    Args:
        ttl_seconds: Lifetime of each entry in seconds; ``None`` selects ``CACHE_TTL``

    Returns:
        CacheProvider: The cache provider

    """
    effective_ttl = CACHE_TTL if ttl_seconds is None else ttl_seconds

    cachetools_usable = USE_CACHETOOLS and CACHETOOLS_AVAILABLE
    if not cachetools_usable:
        return InMemoryCacheProvider(ttl_seconds=effective_ttl)

    try:
        return CachetoolsProvider(ttl_seconds=effective_ttl, maxsize=CACHE_MAXSIZE)
    except Exception as exc:
        # Best effort: a broken optional backend must not prevent caching altogether.
        logger.error(f'Failed to create cachetools cache provider: {exc}')
        logger.info('Falling back to in-memory cache provider')

    return InMemoryCacheProvider(ttl_seconds=effective_ttl)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def cached(ttl_seconds: Optional[int] = None) -> Callable:
    """Cache function results.

    One cache provider is created per ``cached(...)`` call and shared by every
    function that decorator object is applied to. The cache key is built from the
    function name and the ``repr`` of each positional and keyword argument, so
    arguments that merely print alike (``1`` and ``'1'``, ``'a:b'`` and
    ``'a', 'b'``) no longer share an entry.

    Notes:
        * A result of ``None`` is indistinguishable from a cache miss, so functions
          returning ``None`` are re-executed on every call.
        * The wrapper is synchronous and stores whatever ``func`` returns; applied
          to a coroutine function it would cache the coroutine object itself.

    Args:
        ttl_seconds: Time-to-live in seconds for cache entries (defaults to config value)

    Returns:
        Callable: Decorated function with caching

    """
    # Local import keeps this block self-contained without touching the file's import header.
    import functools

    cache = create_cache_provider(ttl_seconds=ttl_seconds)

    def decorator(func: Callable) -> Callable:
        # wraps() preserves __name__, __doc__ and __wrapped__ so introspection still
        # sees the original function rather than the generic wrapper.
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # repr() keeps type information (1 vs '1') and quotes strings, which makes
            # the ':'-joined key unambiguous for the common built-in argument types.
            key_parts = [func.__name__]
            key_parts.extend(repr(arg) for arg in args)
            key_parts.extend(f'{name}={value!r}' for name, value in sorted(kwargs.items()))
            cache_key = ':'.join(key_parts)

            # Try to get from cache first
            cached_result = cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            # Call the function and cache the result
            result = func(*args, **kwargs)
            cache.set(cache_key, result)
            return result

        return wrapper

    return decorator
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Configuration utilities for the OpenAPI MCP Server."""
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _env_int(name: str, default: str) -> int:
    """Read environment variable ``name`` (or ``default`` when unset) and convert it to int."""
    return int(os.environ.get(name, default))


def _env_flag(name: str, default: str) -> bool:
    """Return True when environment variable ``name`` (or ``default``) is 'true', ignoring case."""
    return os.environ.get(name, default).lower() == 'true'


# Metrics configuration
METRICS_MAX_HISTORY = _env_int('METRICS_MAX_HISTORY', '100')
USE_PROMETHEUS = _env_flag('ENABLE_PROMETHEUS', 'false')
PROMETHEUS_PORT = _env_int('PROMETHEUS_PORT', '9090')

# Operation prompts configuration
ENABLE_OPERATION_PROMPTS = _env_flag('ENABLE_OPERATION_PROMPTS', 'true')

# HTTP client configuration
HTTP_MAX_CONNECTIONS = _env_int('HTTP_MAX_CONNECTIONS', '100')
HTTP_MAX_KEEPALIVE = _env_int('HTTP_MAX_KEEPALIVE', '20')
USE_TENACITY = _env_flag('USE_TENACITY', 'true')

# Cache configuration
CACHE_MAXSIZE = _env_int('CACHE_MAXSIZE', '1000')
CACHE_TTL = _env_int('CACHE_TTL', '3600')  # 1 hour default
USE_CACHETOOLS = _env_flag('USE_CACHETOOLS', 'true')
|
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Utilities for error handling in the OpenAPI MCP Server."""
|
|
15
|
+
|
|
16
|
+
import httpx
|
|
17
|
+
import json
|
|
18
|
+
from awslabs.openapi_mcp_server import logger
|
|
19
|
+
from typing import Any, Dict, Optional, Type
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class APIError(Exception):
    """Root of the exception hierarchy used for failed API calls."""

    def __init__(
        self,
        status_code: int,
        message: str,
        details: Any = None,
        original_error: Optional[Exception] = None,
    ):
        """Record the failure information on the exception instance.

        Args:
            status_code: HTTP status code
            message: Human-readable error message (also passed to ``Exception``)
            details: Extra error details; stored as an empty dict when ``None``
            original_error: The exception this error was created from, if any

        """
        super().__init__(message)
        self.status_code = status_code
        self.message = message
        self.details = details if details is not None else {}
        self.original_error = original_error

    def __str__(self) -> str:
        """Format the error as ``<status_code>: <message>``."""
        return f'{self.status_code}: {self.message}'

    def __repr__(self) -> str:
        """Format the error as ``ClassName(status_code, message, details)``."""
        return f'{type(self).__name__}({self.status_code}, {self.message!r}, {self.details!r})'


class AuthenticationError(APIError):
    """Raised for authentication failures (401)."""


class AuthorizationError(APIError):
    """Raised for authorization failures (403)."""


class ResourceNotFoundError(APIError):
    """Raised when a resource cannot be found (404)."""


class ValidationError(APIError):
    """Raised for validation failures (422)."""


class RateLimitError(APIError):
    """Raised when a rate limit is hit (429)."""


class ServerError(APIError):
    """Raised for server-side failures (5xx)."""


class ConnectionError(APIError):
    """Raised for connection failures.

    Within this module the name shadows the built-in ``ConnectionError``.
    """


class NetworkError(APIError):
    """Raised for network failures."""
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# Map status codes to error classes
|
|
105
|
+
# Lookup table from an HTTP status code, or an httpx exception type, to the APIError
# subclass used to report it. Insertion order matches the original literal.
ERROR_CLASSES: Dict[Any, Type[APIError]] = {
    # HTTP status codes
    400: ValidationError,
    401: AuthenticationError,
    403: AuthorizationError,
    404: ResourceNotFoundError,
    422: ValidationError,
    429: RateLimitError,
    **dict.fromkeys((500, 502, 503, 504), ServerError),
    # Request error types
    **dict.fromkeys((httpx.ConnectTimeout, httpx.ReadTimeout), ConnectionError),
    **dict.fromkeys((httpx.ConnectError, httpx.RequestError), NetworkError),
}
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def extract_error_details(response: httpx.Response) -> Dict[str, Any]:
    """Extract error details from an HTTP response.

    The body is parsed as JSON when the content type starts with
    ``application/json``. A JSON object is returned as-is; any other non-empty
    JSON value (string, list, number) is wrapped as ``{'message': value}`` so the
    result is always a dictionary. When no details could be obtained from JSON,
    the raw response text (if any) is returned under ``'message'``.

    Args:
        response: The HTTP response

    Returns:
        A dictionary of error details (possibly empty)

    """
    details: Dict[str, Any] = {}

    # Try to parse JSON response
    try:
        if response.headers.get('content-type', '').startswith('application/json'):
            parsed = response.json()
            if isinstance(parsed, dict):
                details = parsed
            elif parsed:
                # Valid JSON that is not an object: keep the value but honour the
                # Dict return contract so callers can index by key safely.
                details = {'message': parsed}
    except Exception as e:
        # Deliberately best-effort: an unparsable body falls through to the text fallback.
        logger.debug(f'Failed to parse JSON response: {e}')

    # If we couldn't parse JSON, use the text
    if not details and response.text:
        details = {'message': response.text}

    return details
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def format_error_message(status_code: int, reason: str, details: Dict[str, Any]) -> str:
    """Format an error message from status code, reason, and details.

    The message starts with ``'<status_code> <reason>'``. A detail text is appended
    after ``': '`` when one can be found, looked up in this order:
    ``details['message']``, ``details['error']`` (when it is a string), then
    ``details['error']['message']``. For robustness a non-empty plain string passed
    as ``details`` is appended directly and any other non-dict value is ignored,
    instead of raising while an error is already being reported.

    For status 401 and 403 a troubleshooting hint is appended.

    Args:
        status_code: HTTP status code
        reason: HTTP reason phrase
        details: Additional error details

    Returns:
        A formatted error message

    """
    # Start with the status code and reason
    message = f'{status_code} {reason}'

    # Add details if available
    if isinstance(details, dict):
        if 'message' in details:
            message += f': {details["message"]}'
        else:
            nested = details.get('error')
            if isinstance(nested, str):
                message += f': {nested}'
            elif isinstance(nested, dict) and 'message' in nested:
                message += f': {nested["message"]}'
    elif isinstance(details, str) and details:
        # Indexing a str with 'message' would raise TypeError; use the text itself.
        message += f': {details}'

    # Add troubleshooting tips based on status code
    if status_code == 401:
        message += '\n\nTROUBLESHOOTING: Authentication error. Please check your credentials or ensure your token is valid. You may need to refresh your authentication tokens.'
    elif status_code == 403:
        message += "\n\nTROUBLESHOOTING: Authorization error. You don't have permission to access this resource. Please check your IAM permissions or API key scope."

    return message
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def handle_http_error(error: httpx.HTTPStatusError) -> APIError:
    """Convert an HTTPX status error to an appropriate APIError subclass.

    The error class is looked up in ``ERROR_CLASSES`` by status code, falling back
    to ``APIError``. For 401 responses whose request carried a ``Bearer`` token,
    the token's expiry is additionally logged to help diagnose expired credentials.

    Args:
        error: The HTTPX error raised for a 4xx/5xx response

    Returns:
        An APIError (or subclass) carrying the status code, formatted message,
        extracted details and the original error

    """
    status_code = error.response.status_code
    details = extract_error_details(error.response)
    message = format_error_message(status_code, error.response.reason_phrase, details)

    # Enhanced logging for auth errors
    if status_code == 401:
        request = error.request
        if request and hasattr(request, 'headers') and 'Authorization' in request.headers:
            auth_header = request.headers['Authorization']
            if auth_header.startswith('Bearer '):
                _log_bearer_token_expiry(auth_header[7:])

    # Use the appropriate error class based on status code
    error_class = ERROR_CLASSES.get(status_code, APIError)
    return error_class(status_code, message, details=details, original_error=error)


def _log_bearer_token_expiry(token: str) -> None:
    """Log how long ago a JWT expired, or how long it remains valid.

    The payload is decoded without verifying the signature and only its ``exp``
    claim is read. Tokens that do not have three dot-separated parts are ignored;
    any decoding problem is logged as a warning and otherwise swallowed, because
    this is diagnostic output only.
    """
    import base64
    import time

    try:
        parts = token.split('.')
        if len(parts) != 3:
            return

        # JWT segments are base64url-encoded with the padding stripped: restore the
        # padding and decode with the URL-safe alphabet ('-' and '_').
        payload = parts[1]
        payload += '=' * (-len(payload) % 4)
        payload_data = json.loads(base64.urlsafe_b64decode(payload))

        # Check for expiration
        if isinstance(payload_data, dict) and 'exp' in payload_data:
            exp_time = payload_data['exp']
            now = int(time.time())
            remaining = exp_time - now

            if remaining < 0:
                logger.warning(f'Token expired {abs(remaining)} seconds ago')
            else:
                logger.debug(
                    f'Token expiration: {exp_time}, Current time: {now}, Remaining: {remaining}s'
                )
    except Exception as e:
        logger.warning(f'Could not decode token for debugging: {e}')
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def handle_request_error(error: httpx.RequestError) -> APIError:
    """Convert an HTTPX request error to an appropriate APIError subclass.

    The error is matched against the known httpx exception types from most to
    least specific. The first match selects both the message and, through
    ``ERROR_CLASSES``, the error class. The resulting error always carries status
    code 500.
    """
    matched_type = None
    if isinstance(error, httpx.ConnectTimeout):
        matched_type = httpx.ConnectTimeout
        message = 'Connection timed out: The server took too long to respond'
    elif isinstance(error, httpx.ReadTimeout):
        matched_type = httpx.ReadTimeout
        message = 'Read timed out: The server took too long to send a response'
    elif isinstance(error, httpx.ConnectError):
        matched_type = httpx.ConnectError
        message = f'Connection error: Could not connect to the server: {error}'
    else:
        if isinstance(error, httpx.RequestError):
            matched_type = httpx.RequestError
        message = f'Request error: {error}'

    # An error matching none of the known types is reported as a ConnectionError.
    if matched_type is None:
        error_class = ConnectionError
    else:
        error_class = ERROR_CLASSES.get(matched_type, ConnectionError)

    return error_class(500, message, original_error=error)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
async def safe_request(
    client: httpx.AsyncClient, method: str, url: str, **kwargs
) -> httpx.Response:
    """Execute an HTTP request with comprehensive error handling.

    The request is logged at DEBUG level with the Authorization header masked.
    Building that log entry is best-effort: whatever shape ``headers``/``params``
    have, a logging problem never prevents the request from being sent.

    Args:
        client: The HTTPX client to use for the request
        method: The HTTP method to use
        url: The URL to request
        **kwargs: Additional arguments to pass to the client's request method

    Returns:
        The HTTP response

    Raises:
        APIError: If an error occurs during the request (the causing exception is
            available as ``original_error`` and as ``__cause__``)

    """
    try:
        # Log request details at DEBUG level
        try:
            request_details = _describe_request_for_log(method, url, kwargs)
            logger.debug(f'Making HTTP request: {request_details}')
        except Exception as log_error:
            logger.debug(f'Could not log request details: {log_error}')

        # Make the request
        response = await client.request(method=method, url=url, **kwargs)

        # Log response details at DEBUG level
        logger.debug(f'Response: {response.status_code} {response.reason_phrase}')

        # Raise an exception for 4xx/5xx responses
        response.raise_for_status()

        return response

    except httpx.HTTPStatusError as e:
        # Handle HTTP errors (4xx, 5xx)
        logger.error(
            f'HTTP error when accessing {url}: {e.response.status_code} {e.response.reason_phrase}'
        )
        raise handle_http_error(e) from e

    except httpx.RequestError as e:
        # Handle request errors (connection, timeout, etc.)
        logger.error(f'Request error when accessing {url}: {e}')

        # Create a more specific error
        raise handle_request_error(e) from e

    except Exception as e:
        # Handle unexpected errors
        logger.error(f'Unexpected error when accessing {url}: {e}')
        raise APIError(500, f'Unexpected error: {e}', original_error=e) from e


def _iter_pairs(pairs_or_mapping: Any) -> Any:
    """Return (key, value) pairs from a mapping or from an iterable of pairs."""
    if hasattr(pairs_or_mapping, 'items'):
        return pairs_or_mapping.items()
    return pairs_or_mapping


def _mask_authorization_value(value: Any) -> str:
    """Return a loggable stand-in for an Authorization header value.

    Bearer values longer than 30 characters keep the 8 characters after
    ``'Bearer '`` and the last 8; shorter Bearer values become ``'Bearer ****'``;
    everything else becomes ``'[MASKED]'``.
    """
    text = value.decode('latin-1') if isinstance(value, bytes) else str(value)
    if text.startswith('Bearer ') and len(text) > 15:
        if len(text) > 30:
            return 'Bearer ' + text[7:15] + '...' + text[-8:]
        return 'Bearer ****'
    return '[MASKED]'


def _describe_request_for_log(method: str, url: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Build the dictionary logged before a request, with credentials masked."""
    summary: Dict[str, Any] = {
        'method': method,
        'url': url,
    }

    # Log headers (safely) if present in kwargs
    headers = kwargs.get('headers')
    if headers:
        sanitized_headers = {}
        for header, value in _iter_pairs(headers):
            header_name = header.decode('latin-1') if isinstance(header, bytes) else str(header)
            if header_name.lower() == 'authorization' and value:
                # Mask authorization header for security
                sanitized_headers[header_name] = _mask_authorization_value(value)
            else:
                sanitized_headers[header_name] = str(value)
        summary['headers'] = sanitized_headers

    # Log query params if present
    params = kwargs.get('params')
    if params:
        if isinstance(params, (str, bytes)):
            # A raw query string has no key/value pairs to iterate over.
            summary['params'] = str(params)
        else:
            # Convert all values to strings to avoid type issues
            summary['params'] = {
                k: (str(v) if v is not None else None) for k, v in _iter_pairs(params)
            }

    return summary
|